优化初始化判断

This commit is contained in:
M09Ic 2022-10-27 23:40:15 +08:00
parent 4a4faf4e66
commit 6cc0a71dae
6 changed files with 32 additions and 24 deletions

View File

@ -12,11 +12,14 @@ import (
func NewBaseline(u, host string, resp *ihttp.Response) *baseline {
bl := &baseline{
Url: u,
Host: host,
Status: resp.StatusCode(),
IsValid: true,
}
if resp.ClientType == ihttp.STANDARD {
bl.Host = host
}
bl.Body = resp.Body()
bl.BodyLength = resp.ContentLength()
bl.Header = resp.Header()
@ -29,11 +32,14 @@ func NewBaseline(u, host string, resp *ihttp.Response) *baseline {
func NewInvalidBaseline(u, host string, resp *ihttp.Response) *baseline {
bl := &baseline{
Url: u,
Host: host,
Status: resp.StatusCode(),
IsValid: false,
}
if resp.ClientType == ihttp.STANDARD {
bl.Host = host
}
bl.RedirectURL = string(resp.GetHeader("Location"))
return bl
@ -80,8 +86,6 @@ func (bl *baseline) Equal(other *baseline) bool {
// 如果body length相等且md5相等, 则说明是同一个页面
if bl.BodyMd5 == parsers.Md5Hash(other.Raw) {
return true
} else {
return true
}
}
@ -131,7 +135,6 @@ func (bl *baseline) Get(key string) string {
return bl.Frameworks.ToString()
default:
return ""
}
}
@ -142,11 +145,14 @@ func (bl *baseline) Additional(key string) string {
return " "
}
}
func (bl *baseline) String() string {
var line strings.Builder
//line.WriteString("[+] ")
line.WriteString(bl.Url)
if bl.Host != "" {
line.WriteString(" (" + bl.Host + ")")
}
line.WriteString(" - ")
line.WriteString(strconv.Itoa(bl.Status))
line.WriteString(" - ")

View File

@ -86,25 +86,24 @@ func NewPool(ctx context.Context, config *pkg.Config, outputCh chan *baseline) (
pool.failedCount++
bl = &baseline{Url: pool.BaseURL + unit.path, Err: reqerr}
} else {
pool.failedCount = 0
pool.failedCount = 0 // 如果后续访问正常, 重置错误次数
if err = pool.PreCompare(resp); err == nil || unit.source == CheckSource {
// 通过预对比跳过一些无用数据, 减少性能消耗
bl = NewBaseline(req.URI(), req.Host(), resp)
} else {
bl = NewInvalidBaseline(req.URI(), req.Host(), resp)
}
bl.Err = reqerr
}
switch unit.source {
case CheckSource:
logs.Log.Debugf("check: " + bl.String())
if pool.base == nil {
//初次check覆盖baseline
case InitSource:
pool.base = bl
pool.initwg.Done()
} else if bl.Err != nil {
logs.Log.Warn("maybe ip banned by waf")
return
case CheckSource:
logs.Log.Debugf("check: " + bl.String())
if bl.Err != nil {
logs.Log.Warnf("maybe ip has banned by waf, break (%d/%d), error: %s", pool.failedCount, breakThreshold, bl.Err.Error())
} else if !pool.base.Equal(bl) {
logs.Log.Warn("maybe trigger risk control")
}
@ -155,12 +154,12 @@ type Pool struct {
func (p *Pool) Init() error {
p.initwg.Add(1)
p.check()
p.pool.Invoke(newUnit(pkg.RandHost(), InitSource))
p.initwg.Wait()
// todo 分析baseline
// 检测基本访问能力
if p.base != nil && p.base.Err != nil {
if p.base.Err != nil {
p.cancel()
return p.base.Err
}
@ -255,10 +254,10 @@ func (p *Pool) buildPathRequest(path string) (*ihttp.Request, error) {
if p.Config.ClientType == ihttp.FAST {
req := fasthttp.AcquireRequest()
req.SetRequestURI(p.BaseURL + path)
return &ihttp.Request{FastRequest: req}, nil
return &ihttp.Request{FastRequest: req, ClientType: p.ClientType}, nil
} else {
req, err := http.NewRequest("GET", p.BaseURL+path, nil)
return &ihttp.Request{StandardRequest: req}, err
return &ihttp.Request{StandardRequest: req, ClientType: p.ClientType}, err
}
}
@ -267,10 +266,10 @@ func (p *Pool) buildHostRequest(host string) (*ihttp.Request, error) {
req := fasthttp.AcquireRequest()
req.SetRequestURI(p.BaseURL)
req.SetHost(host)
return &ihttp.Request{FastRequest: req}, nil
return &ihttp.Request{FastRequest: req, ClientType: p.ClientType}, nil
} else {
req, err := http.NewRequest("GET", p.BaseURL, nil)
req.Host = host
return &ihttp.Request{StandardRequest: req}, err
return &ihttp.Request{StandardRequest: req, ClientType: p.ClientType}, err
}
}

View File

@ -25,6 +25,7 @@ type sourceType int
const (
CheckSource sourceType = iota + 1
InitSource
WordSource
WafSource
)

View File

@ -75,10 +75,10 @@ func (c *Client) StandardDo(ctx context.Context, req *http.Request) (*http.Respo
func (c *Client) Do(ctx context.Context, req *Request) (*Response, error) {
if c.fastClient != nil {
resp, err := c.FastDo(ctx, req.FastRequest)
return &Response{FastResponse: resp}, err
return &Response{FastResponse: resp, ClientType: FAST}, err
} else if c.standardClient != nil {
resp, err := c.StandardDo(ctx, req.StandardRequest)
return &Response{StandardResponse: resp}, err
return &Response{StandardResponse: resp, ClientType: STANDARD}, err
} else {
return nil, fmt.Errorf("not found client")
}

View File

@ -8,6 +8,7 @@ import (
type Request struct {
StandardRequest *http.Request
FastRequest *fasthttp.Request
ClientType int
}
func (r *Request) URI() string {

View File

@ -11,6 +11,7 @@ import (
type Response struct {
StandardResponse *http.Response
FastResponse *fasthttp.Response
ClientType int
}
func (r *Response) StatusCode() int {