mirror of
https://github.com/chainreactors/spray.git
synced 2025-09-15 11:40:13 +00:00
优化初始化判断
This commit is contained in:
parent
4a4faf4e66
commit
6cc0a71dae
@ -12,11 +12,14 @@ import (
|
||||
func NewBaseline(u, host string, resp *ihttp.Response) *baseline {
|
||||
bl := &baseline{
|
||||
Url: u,
|
||||
Host: host,
|
||||
Status: resp.StatusCode(),
|
||||
IsValid: true,
|
||||
}
|
||||
|
||||
if resp.ClientType == ihttp.STANDARD {
|
||||
bl.Host = host
|
||||
}
|
||||
|
||||
bl.Body = resp.Body()
|
||||
bl.BodyLength = resp.ContentLength()
|
||||
bl.Header = resp.Header()
|
||||
@ -29,11 +32,14 @@ func NewBaseline(u, host string, resp *ihttp.Response) *baseline {
|
||||
func NewInvalidBaseline(u, host string, resp *ihttp.Response) *baseline {
|
||||
bl := &baseline{
|
||||
Url: u,
|
||||
Host: host,
|
||||
Status: resp.StatusCode(),
|
||||
IsValid: false,
|
||||
}
|
||||
|
||||
if resp.ClientType == ihttp.STANDARD {
|
||||
bl.Host = host
|
||||
}
|
||||
|
||||
bl.RedirectURL = string(resp.GetHeader("Location"))
|
||||
|
||||
return bl
|
||||
@ -80,8 +86,6 @@ func (bl *baseline) Equal(other *baseline) bool {
|
||||
// 如果body length相等且md5相等, 则说明是同一个页面
|
||||
if bl.BodyMd5 == parsers.Md5Hash(other.Raw) {
|
||||
return true
|
||||
} else {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
@ -131,7 +135,6 @@ func (bl *baseline) Get(key string) string {
|
||||
return bl.Frameworks.ToString()
|
||||
default:
|
||||
return ""
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@ -142,11 +145,14 @@ func (bl *baseline) Additional(key string) string {
|
||||
return " "
|
||||
}
|
||||
}
|
||||
|
||||
func (bl *baseline) String() string {
|
||||
var line strings.Builder
|
||||
//line.WriteString("[+] ")
|
||||
line.WriteString(bl.Url)
|
||||
if bl.Host != "" {
|
||||
line.WriteString(" (" + bl.Host + ")")
|
||||
}
|
||||
line.WriteString(" - ")
|
||||
line.WriteString(strconv.Itoa(bl.Status))
|
||||
line.WriteString(" - ")
|
||||
|
@ -86,25 +86,24 @@ func NewPool(ctx context.Context, config *pkg.Config, outputCh chan *baseline) (
|
||||
pool.failedCount++
|
||||
bl = &baseline{Url: pool.BaseURL + unit.path, Err: reqerr}
|
||||
} else {
|
||||
pool.failedCount = 0
|
||||
pool.failedCount = 0 // 如果后续访问正常, 重置错误次数
|
||||
if err = pool.PreCompare(resp); err == nil || unit.source == CheckSource {
|
||||
// 通过预对比跳过一些无用数据, 减少性能消耗
|
||||
bl = NewBaseline(req.URI(), req.Host(), resp)
|
||||
} else {
|
||||
bl = NewInvalidBaseline(req.URI(), req.Host(), resp)
|
||||
}
|
||||
bl.Err = reqerr
|
||||
}
|
||||
|
||||
switch unit.source {
|
||||
case CheckSource:
|
||||
logs.Log.Debugf("check: " + bl.String())
|
||||
if pool.base == nil {
|
||||
//初次check覆盖baseline
|
||||
case InitSource:
|
||||
pool.base = bl
|
||||
pool.initwg.Done()
|
||||
} else if bl.Err != nil {
|
||||
logs.Log.Warn("maybe ip banned by waf")
|
||||
return
|
||||
case CheckSource:
|
||||
logs.Log.Debugf("check: " + bl.String())
|
||||
if bl.Err != nil {
|
||||
logs.Log.Warnf("maybe ip has banned by waf, break (%d/%d), error: %s", pool.failedCount, breakThreshold, bl.Err.Error())
|
||||
} else if !pool.base.Equal(bl) {
|
||||
logs.Log.Warn("maybe trigger risk control")
|
||||
}
|
||||
@ -155,12 +154,12 @@ type Pool struct {
|
||||
|
||||
func (p *Pool) Init() error {
|
||||
p.initwg.Add(1)
|
||||
p.check()
|
||||
p.pool.Invoke(newUnit(pkg.RandHost(), InitSource))
|
||||
p.initwg.Wait()
|
||||
// todo 分析baseline
|
||||
// 检测基本访问能力
|
||||
|
||||
if p.base != nil && p.base.Err != nil {
|
||||
if p.base.Err != nil {
|
||||
p.cancel()
|
||||
return p.base.Err
|
||||
}
|
||||
@ -255,10 +254,10 @@ func (p *Pool) buildPathRequest(path string) (*ihttp.Request, error) {
|
||||
if p.Config.ClientType == ihttp.FAST {
|
||||
req := fasthttp.AcquireRequest()
|
||||
req.SetRequestURI(p.BaseURL + path)
|
||||
return &ihttp.Request{FastRequest: req}, nil
|
||||
return &ihttp.Request{FastRequest: req, ClientType: p.ClientType}, nil
|
||||
} else {
|
||||
req, err := http.NewRequest("GET", p.BaseURL+path, nil)
|
||||
return &ihttp.Request{StandardRequest: req}, err
|
||||
return &ihttp.Request{StandardRequest: req, ClientType: p.ClientType}, err
|
||||
}
|
||||
}
|
||||
|
||||
@ -267,10 +266,10 @@ func (p *Pool) buildHostRequest(host string) (*ihttp.Request, error) {
|
||||
req := fasthttp.AcquireRequest()
|
||||
req.SetRequestURI(p.BaseURL)
|
||||
req.SetHost(host)
|
||||
return &ihttp.Request{FastRequest: req}, nil
|
||||
return &ihttp.Request{FastRequest: req, ClientType: p.ClientType}, nil
|
||||
} else {
|
||||
req, err := http.NewRequest("GET", p.BaseURL, nil)
|
||||
req.Host = host
|
||||
return &ihttp.Request{StandardRequest: req}, err
|
||||
return &ihttp.Request{StandardRequest: req, ClientType: p.ClientType}, err
|
||||
}
|
||||
}
|
||||
|
@ -25,6 +25,7 @@ type sourceType int
|
||||
|
||||
const (
|
||||
CheckSource sourceType = iota + 1
|
||||
InitSource
|
||||
WordSource
|
||||
WafSource
|
||||
)
|
||||
|
@ -75,10 +75,10 @@ func (c *Client) StandardDo(ctx context.Context, req *http.Request) (*http.Respo
|
||||
func (c *Client) Do(ctx context.Context, req *Request) (*Response, error) {
|
||||
if c.fastClient != nil {
|
||||
resp, err := c.FastDo(ctx, req.FastRequest)
|
||||
return &Response{FastResponse: resp}, err
|
||||
return &Response{FastResponse: resp, ClientType: FAST}, err
|
||||
} else if c.standardClient != nil {
|
||||
resp, err := c.StandardDo(ctx, req.StandardRequest)
|
||||
return &Response{StandardResponse: resp}, err
|
||||
return &Response{StandardResponse: resp, ClientType: STANDARD}, err
|
||||
} else {
|
||||
return nil, fmt.Errorf("not found client")
|
||||
}
|
||||
|
@ -8,6 +8,7 @@ import (
|
||||
type Request struct {
|
||||
StandardRequest *http.Request
|
||||
FastRequest *fasthttp.Request
|
||||
ClientType int
|
||||
}
|
||||
|
||||
func (r *Request) URI() string {
|
||||
|
@ -11,6 +11,7 @@ import (
|
||||
type Response struct {
|
||||
StandardResponse *http.Response
|
||||
FastResponse *fasthttp.Response
|
||||
ClientType int
|
||||
}
|
||||
|
||||
func (r *Response) StatusCode() int {
|
||||
|
Loading…
x
Reference in New Issue
Block a user