diff --git a/internal/baseline.go b/internal/baseline.go index 0375323..0442430 100644 --- a/internal/baseline.go +++ b/internal/baseline.go @@ -12,11 +12,14 @@ import ( func NewBaseline(u, host string, resp *ihttp.Response) *baseline { bl := &baseline{ Url: u, - Host: host, Status: resp.StatusCode(), IsValid: true, } + if resp.ClientType == ihttp.STANDARD { + bl.Host = host + } + bl.Body = resp.Body() bl.BodyLength = resp.ContentLength() bl.Header = resp.Header() @@ -29,11 +32,14 @@ func NewBaseline(u, host string, resp *ihttp.Response) *baseline { func NewInvalidBaseline(u, host string, resp *ihttp.Response) *baseline { bl := &baseline{ Url: u, - Host: host, Status: resp.StatusCode(), IsValid: false, } + if resp.ClientType == ihttp.STANDARD { + bl.Host = host + } + bl.RedirectURL = string(resp.GetHeader("Location")) return bl @@ -80,8 +86,6 @@ func (bl *baseline) Equal(other *baseline) bool { // 如果body length相等且md5相等, 则说明是同一个页面 if bl.BodyMd5 == parsers.Md5Hash(other.Raw) { return true - } else { - return true } } @@ -131,22 +135,24 @@ func (bl *baseline) Get(key string) string { return bl.Frameworks.ToString() default: return "" - } } func (bl *baseline) Additional(key string) string { if v := bl.Get(key); v != "" { - return "[" + v + "]" + return "[" + v + "] " } else { - return "" + return " " } } + func (bl *baseline) String() string { var line strings.Builder //line.WriteString("[+] ") line.WriteString(bl.Url) - line.WriteString(" (" + bl.Host + ")") + if bl.Host != "" { + line.WriteString(" (" + bl.Host + ")") + } line.WriteString(" - ") line.WriteString(strconv.Itoa(bl.Status)) line.WriteString(" - ") diff --git a/internal/pool.go b/internal/pool.go index 10f966b..794147a 100644 --- a/internal/pool.go +++ b/internal/pool.go @@ -86,25 +86,24 @@ func NewPool(ctx context.Context, config *pkg.Config, outputCh chan *baseline) ( pool.failedCount++ bl = &baseline{Url: pool.BaseURL + unit.path, Err: reqerr} } else { - pool.failedCount = 0 + pool.failedCount = 0 // 如果后续访问正常, 重置错误次数 if err = pool.PreCompare(resp); err == nil || unit.source == CheckSource { // 通过预对比跳过一些无用数据, 减少性能消耗 bl = NewBaseline(req.URI(), req.Host(), resp) } else { bl = NewInvalidBaseline(req.URI(), req.Host(), resp) } - bl.Err = reqerr } switch unit.source { + case InitSource: + pool.base = bl + pool.initwg.Done() + return case CheckSource: logs.Log.Debugf("check: " + bl.String()) - if pool.base == nil { - //初次check覆盖baseline - pool.base = bl - pool.initwg.Done() - } else if bl.Err != nil { - logs.Log.Warn("maybe ip banned by waf") + if bl.Err != nil { + logs.Log.Warnf("maybe ip has banned by waf, break (%d/%d), error: %s", pool.failedCount, breakThreshold, bl.Err.Error()) } else if !pool.base.Equal(bl) { logs.Log.Warn("maybe trigger risk control") } @@ -155,12 +154,12 @@ type Pool struct { func (p *Pool) Init() error { p.initwg.Add(1) - p.check() + p.pool.Invoke(newUnit(pkg.RandHost(), InitSource)) p.initwg.Wait() // todo 分析baseline // 检测基本访问能力 - if p.base != nil && p.base.Err != nil { + if p.base.Err != nil { p.cancel() return p.base.Err } @@ -255,10 +254,10 @@ func (p *Pool) buildPathRequest(path string) (*ihttp.Request, error) { if p.Config.ClientType == ihttp.FAST { req := fasthttp.AcquireRequest() req.SetRequestURI(p.BaseURL + path) - return &ihttp.Request{FastRequest: req}, nil + return &ihttp.Request{FastRequest: req, ClientType: p.ClientType}, nil } else { req, err := http.NewRequest("GET", p.BaseURL+path, nil) - return &ihttp.Request{StandardRequest: req}, err + return &ihttp.Request{StandardRequest: req, ClientType: p.ClientType}, err } } @@ -267,10 +266,10 @@ func (p *Pool) buildHostRequest(host string) (*ihttp.Request, error) { req := fasthttp.AcquireRequest() req.SetRequestURI(p.BaseURL) req.SetHost(host) - return &ihttp.Request{FastRequest: req}, nil + return &ihttp.Request{FastRequest: req, ClientType: p.ClientType}, nil } else { req, err := http.NewRequest("GET", p.BaseURL, nil) req.Host = host - return &ihttp.Request{StandardRequest: req}, err + return &ihttp.Request{StandardRequest: req, ClientType: p.ClientType}, err } } diff --git a/internal/types.go b/internal/types.go index 40f58e8..b931664 100644 --- a/internal/types.go +++ b/internal/types.go @@ -25,6 +25,7 @@ type sourceType int const ( CheckSource sourceType = iota + 1 + InitSource WordSource WafSource ) diff --git a/pkg/ihttp/client.go b/pkg/ihttp/client.go index 572fa74..133fef2 100644 --- a/pkg/ihttp/client.go +++ b/pkg/ihttp/client.go @@ -75,10 +75,10 @@ func (c *Client) StandardDo(ctx context.Context, req *http.Request) (*http.Respo func (c *Client) Do(ctx context.Context, req *Request) (*Response, error) { if c.fastClient != nil { resp, err := c.FastDo(ctx, req.FastRequest) - return &Response{FastResponse: resp}, err + return &Response{FastResponse: resp, ClientType: FAST}, err } else if c.standardClient != nil { resp, err := c.StandardDo(ctx, req.StandardRequest) - return &Response{StandardResponse: resp}, err + return &Response{StandardResponse: resp, ClientType: STANDARD}, err } else { return nil, fmt.Errorf("not found client") } diff --git a/pkg/ihttp/request.go b/pkg/ihttp/request.go index cd23c94..89842a9 100644 --- a/pkg/ihttp/request.go +++ b/pkg/ihttp/request.go @@ -8,6 +8,7 @@ import ( type Request struct { StandardRequest *http.Request FastRequest *fasthttp.Request + ClientType int } func (r *Request) URI() string { diff --git a/pkg/ihttp/response.go b/pkg/ihttp/response.go index 6d251d6..420c37c 100644 --- a/pkg/ihttp/response.go +++ b/pkg/ihttp/response.go @@ -11,6 +11,7 @@ import ( type Response struct { StandardResponse *http.Response FastResponse *fasthttp.Response + ClientType int } func (r *Response) StatusCode() int {