优化初始化判断

This commit is contained in:
M09Ic 2022-10-27 23:40:15 +08:00
parent 4a4faf4e66
commit 6cc0a71dae
6 changed files with 32 additions and 24 deletions

View File

@ -12,11 +12,14 @@ import (
func NewBaseline(u, host string, resp *ihttp.Response) *baseline { func NewBaseline(u, host string, resp *ihttp.Response) *baseline {
bl := &baseline{ bl := &baseline{
Url: u, Url: u,
Host: host,
Status: resp.StatusCode(), Status: resp.StatusCode(),
IsValid: true, IsValid: true,
} }
if resp.ClientType == ihttp.STANDARD {
bl.Host = host
}
bl.Body = resp.Body() bl.Body = resp.Body()
bl.BodyLength = resp.ContentLength() bl.BodyLength = resp.ContentLength()
bl.Header = resp.Header() bl.Header = resp.Header()
@ -29,11 +32,14 @@ func NewBaseline(u, host string, resp *ihttp.Response) *baseline {
func NewInvalidBaseline(u, host string, resp *ihttp.Response) *baseline { func NewInvalidBaseline(u, host string, resp *ihttp.Response) *baseline {
bl := &baseline{ bl := &baseline{
Url: u, Url: u,
Host: host,
Status: resp.StatusCode(), Status: resp.StatusCode(),
IsValid: false, IsValid: false,
} }
if resp.ClientType == ihttp.STANDARD {
bl.Host = host
}
bl.RedirectURL = string(resp.GetHeader("Location")) bl.RedirectURL = string(resp.GetHeader("Location"))
return bl return bl
@ -80,8 +86,6 @@ func (bl *baseline) Equal(other *baseline) bool {
// 如果body length相等且md5相等, 则说明是同一个页面 // 如果body length相等且md5相等, 则说明是同一个页面
if bl.BodyMd5 == parsers.Md5Hash(other.Raw) { if bl.BodyMd5 == parsers.Md5Hash(other.Raw) {
return true return true
} else {
return true
} }
} }
@ -131,22 +135,24 @@ func (bl *baseline) Get(key string) string {
return bl.Frameworks.ToString() return bl.Frameworks.ToString()
default: default:
return "" return ""
} }
} }
func (bl *baseline) Additional(key string) string { func (bl *baseline) Additional(key string) string {
if v := bl.Get(key); v != "" { if v := bl.Get(key); v != "" {
return "[" + v + "]" return "[" + v + "] "
} else { } else {
return "" return " "
} }
} }
func (bl *baseline) String() string { func (bl *baseline) String() string {
var line strings.Builder var line strings.Builder
//line.WriteString("[+] ") //line.WriteString("[+] ")
line.WriteString(bl.Url) line.WriteString(bl.Url)
line.WriteString(" (" + bl.Host + ")") if bl.Host != "" {
line.WriteString(" (" + bl.Host + ")")
}
line.WriteString(" - ") line.WriteString(" - ")
line.WriteString(strconv.Itoa(bl.Status)) line.WriteString(strconv.Itoa(bl.Status))
line.WriteString(" - ") line.WriteString(" - ")

View File

@ -86,25 +86,24 @@ func NewPool(ctx context.Context, config *pkg.Config, outputCh chan *baseline) (
pool.failedCount++ pool.failedCount++
bl = &baseline{Url: pool.BaseURL + unit.path, Err: reqerr} bl = &baseline{Url: pool.BaseURL + unit.path, Err: reqerr}
} else { } else {
pool.failedCount = 0 pool.failedCount = 0 // 如果后续访问正常, 重置错误次数
if err = pool.PreCompare(resp); err == nil || unit.source == CheckSource { if err = pool.PreCompare(resp); err == nil || unit.source == CheckSource {
// 通过预对比跳过一些无用数据, 减少性能消耗 // 通过预对比跳过一些无用数据, 减少性能消耗
bl = NewBaseline(req.URI(), req.Host(), resp) bl = NewBaseline(req.URI(), req.Host(), resp)
} else { } else {
bl = NewInvalidBaseline(req.URI(), req.Host(), resp) bl = NewInvalidBaseline(req.URI(), req.Host(), resp)
} }
bl.Err = reqerr
} }
switch unit.source { switch unit.source {
case InitSource:
pool.base = bl
pool.initwg.Done()
return
case CheckSource: case CheckSource:
logs.Log.Debugf("check: " + bl.String()) logs.Log.Debugf("check: " + bl.String())
if pool.base == nil { if bl.Err != nil {
//初次check覆盖baseline logs.Log.Warnf("maybe ip has banned by waf, break (%d/%d), error: %s", pool.failedCount, breakThreshold, bl.Err.Error())
pool.base = bl
pool.initwg.Done()
} else if bl.Err != nil {
logs.Log.Warn("maybe ip banned by waf")
} else if !pool.base.Equal(bl) { } else if !pool.base.Equal(bl) {
logs.Log.Warn("maybe trigger risk control") logs.Log.Warn("maybe trigger risk control")
} }
@ -155,12 +154,12 @@ type Pool struct {
func (p *Pool) Init() error { func (p *Pool) Init() error {
p.initwg.Add(1) p.initwg.Add(1)
p.check() p.pool.Invoke(newUnit(pkg.RandHost(), InitSource))
p.initwg.Wait() p.initwg.Wait()
// todo 分析baseline // todo 分析baseline
// 检测基本访问能力 // 检测基本访问能力
if p.base != nil && p.base.Err != nil { if p.base.Err != nil {
p.cancel() p.cancel()
return p.base.Err return p.base.Err
} }
@ -255,10 +254,10 @@ func (p *Pool) buildPathRequest(path string) (*ihttp.Request, error) {
if p.Config.ClientType == ihttp.FAST { if p.Config.ClientType == ihttp.FAST {
req := fasthttp.AcquireRequest() req := fasthttp.AcquireRequest()
req.SetRequestURI(p.BaseURL + path) req.SetRequestURI(p.BaseURL + path)
return &ihttp.Request{FastRequest: req}, nil return &ihttp.Request{FastRequest: req, ClientType: p.ClientType}, nil
} else { } else {
req, err := http.NewRequest("GET", p.BaseURL+path, nil) req, err := http.NewRequest("GET", p.BaseURL+path, nil)
return &ihttp.Request{StandardRequest: req}, err return &ihttp.Request{StandardRequest: req, ClientType: p.ClientType}, err
} }
} }
@ -267,10 +266,10 @@ func (p *Pool) buildHostRequest(host string) (*ihttp.Request, error) {
req := fasthttp.AcquireRequest() req := fasthttp.AcquireRequest()
req.SetRequestURI(p.BaseURL) req.SetRequestURI(p.BaseURL)
req.SetHost(host) req.SetHost(host)
return &ihttp.Request{FastRequest: req}, nil return &ihttp.Request{FastRequest: req, ClientType: p.ClientType}, nil
} else { } else {
req, err := http.NewRequest("GET", p.BaseURL, nil) req, err := http.NewRequest("GET", p.BaseURL, nil)
req.Host = host req.Host = host
return &ihttp.Request{StandardRequest: req}, err return &ihttp.Request{StandardRequest: req, ClientType: p.ClientType}, err
} }
} }

View File

@ -25,6 +25,7 @@ type sourceType int
const ( const (
CheckSource sourceType = iota + 1 CheckSource sourceType = iota + 1
InitSource
WordSource WordSource
WafSource WafSource
) )

View File

@ -75,10 +75,10 @@ func (c *Client) StandardDo(ctx context.Context, req *http.Request) (*http.Respo
func (c *Client) Do(ctx context.Context, req *Request) (*Response, error) { func (c *Client) Do(ctx context.Context, req *Request) (*Response, error) {
if c.fastClient != nil { if c.fastClient != nil {
resp, err := c.FastDo(ctx, req.FastRequest) resp, err := c.FastDo(ctx, req.FastRequest)
return &Response{FastResponse: resp}, err return &Response{FastResponse: resp, ClientType: FAST}, err
} else if c.standardClient != nil { } else if c.standardClient != nil {
resp, err := c.StandardDo(ctx, req.StandardRequest) resp, err := c.StandardDo(ctx, req.StandardRequest)
return &Response{StandardResponse: resp}, err return &Response{StandardResponse: resp, ClientType: STANDARD}, err
} else { } else {
return nil, fmt.Errorf("not found client") return nil, fmt.Errorf("not found client")
} }

View File

@ -8,6 +8,7 @@ import (
type Request struct { type Request struct {
StandardRequest *http.Request StandardRequest *http.Request
FastRequest *fasthttp.Request FastRequest *fasthttp.Request
ClientType int
} }
func (r *Request) URI() string { func (r *Request) URI() string {

View File

@ -11,6 +11,7 @@ import (
type Response struct { type Response struct {
StandardResponse *http.Response StandardResponse *http.Response
FastResponse *fasthttp.Response FastResponse *fasthttp.Response
ClientType int
} }
func (r *Response) StatusCode() int { func (r *Response) StatusCode() int {