diff --git a/internal/ihttp/client.go b/internal/ihttp/client.go index 672ff92..4ee83e2 100644 --- a/internal/ihttp/client.go +++ b/internal/ihttp/client.go @@ -1,7 +1,6 @@ package ihttp import ( - "context" "crypto/tls" "fmt" "github.com/chainreactors/logs" @@ -56,7 +55,7 @@ func NewClient(config *ClientConfig) *Client { DisablePathNormalizing: true, DisableHeaderNamesNormalizing: true, }, - Config: config, + ClientConfig: config, } } else { client = &Client{ @@ -76,7 +75,7 @@ func NewClient(config *ClientConfig) *Client { return http.ErrUseLastResponse }, }, - Config: config, + ClientConfig: config, } if config.ProxyAddr != "" { client.standardClient.Transport.(*http.Transport).Proxy = func(_ *http.Request) (*url.URL, error) { @@ -97,7 +96,7 @@ type ClientConfig struct { type Client struct { fastClient *fasthttp.Client standardClient *http.Client - Config *ClientConfig + *ClientConfig } func (c *Client) TransToCheck() { @@ -108,22 +107,22 @@ func (c *Client) TransToCheck() { } } -func (c *Client) FastDo(ctx context.Context, req *fasthttp.Request) (*fasthttp.Response, error) { +func (c *Client) FastDo(req *fasthttp.Request) (*fasthttp.Response, error) { resp := fasthttp.AcquireResponse() - err := c.fastClient.Do(req, resp) + err := c.fastClient.DoTimeout(req, resp, c.Timeout) return resp, err } -func (c *Client) StandardDo(ctx context.Context, req *http.Request) (*http.Response, error) { +func (c *Client) StandardDo(req *http.Request) (*http.Response, error) { return c.standardClient.Do(req) } -func (c *Client) Do(ctx context.Context, req *Request) (*Response, error) { +func (c *Client) Do(req *Request) (*Response, error) { if c.fastClient != nil { - resp, err := c.FastDo(ctx, req.FastRequest) + resp, err := c.FastDo(req.FastRequest) return &Response{FastResponse: resp, ClientType: FAST}, err } else if c.standardClient != nil { - resp, err := c.StandardDo(ctx, req.StandardRequest) + resp, err := c.StandardDo(req.StandardRequest) return &Response{StandardResponse: resp, ClientType: STANDARD}, err } else { return nil, fmt.Errorf("not found client") diff --git a/internal/ihttp/request.go b/internal/ihttp/request.go index e99e65c..62d7f81 100644 --- a/internal/ihttp/request.go +++ b/internal/ihttp/request.go @@ -1,11 +1,12 @@ package ihttp import ( + "context" "github.com/valyala/fasthttp" "net/http" ) -func BuildRequest(clientType int, base, path, host, method string) (*Request, error) { +func BuildRequest(ctx context.Context, clientType int, base, path, host, method string) (*Request, error) { if clientType == FAST { req := fasthttp.AcquireRequest() req.Header.SetMethod(method) @@ -15,7 +16,7 @@ func BuildRequest(clientType int, base, path, host, method string) (*Request, er } return &Request{FastRequest: req, ClientType: FAST}, nil } else { - req, err := http.NewRequest(method, base+path, nil) + req, err := http.NewRequestWithContext(ctx, method, base+path, nil) if host != "" { req.Host = host } diff --git a/internal/pool/brutepool.go b/internal/pool/brutepool.go index 3970b3d..dad0c4a 100644 --- a/internal/pool/brutepool.go +++ b/internal/pool/brutepool.go @@ -48,7 +48,7 @@ func NewBrutePool(ctx context.Context, config *Config) (*BrutePool, error) { client: ihttp.NewClient(&ihttp.ClientConfig{ Thread: config.Thread, Type: config.ClientType, - Timeout: time.Duration(config.Timeout) * time.Second, + Timeout: config.Timeout, ProxyAddr: config.ProxyAddr, }), additionCh: make(chan *Unit, config.Thread), @@ -167,7 +167,7 @@ func (pool *BrutePool) Init() error { return nil } -func (pool *BrutePool) Run(ctx context.Context, offset, limit int) { +func (pool *BrutePool) Run(offset, limit int) { pool.Worder.Run() if pool.Active { pool.wg.Add(1) @@ -250,8 +250,6 @@ Loop: } case <-pool.closeCh: break Loop - case <-ctx.Done(): - break Loop case <-pool.ctx.Done(): break Loop } @@ -271,7 +269,7 @@ func (pool *BrutePool) Invoke(v interface{}) { var req *ihttp.Request var err error - req, err = ihttp.BuildRequest(pool.ClientType, pool.BaseURL, unit.path, unit.host, pool.Method) + req, err = ihttp.BuildRequest(pool.ctx, pool.ClientType, pool.BaseURL, unit.path, unit.host, pool.Method) if err != nil { logs.Log.Error(err.Error()) return @@ -283,7 +281,7 @@ func (pool *BrutePool) Invoke(v interface{}) { } start := time.Now() - resp, reqerr := pool.client.Do(pool.ctx, req) + resp, reqerr := pool.client.Do(req) if pool.ClientType == ihttp.FAST { defer fasthttp.ReleaseResponse(resp.FastResponse) defer fasthttp.ReleaseRequest(req.FastRequest) @@ -397,14 +395,14 @@ func (pool *BrutePool) Invoke(v interface{}) { func (pool *BrutePool) NoScopeInvoke(v interface{}) { defer pool.wg.Done() unit := v.(*Unit) - req, err := ihttp.BuildRequest(pool.ClientType, unit.path, "", "", "GET") + req, err := ihttp.BuildRequest(pool.ctx, pool.ClientType, unit.path, "", "", "GET") if err != nil { logs.Log.Error(err.Error()) return } req.SetHeaders(pool.Headers) req.SetHeader("User-Agent", pkg.RandomUA()) - resp, reqerr := pool.client.Do(pool.ctx, req) + resp, reqerr := pool.client.Do(req) if pool.ClientType == ihttp.FAST { defer fasthttp.ReleaseResponse(resp.FastResponse) defer fasthttp.ReleaseRequest(req.FastRequest) @@ -728,7 +726,8 @@ func (pool *BrutePool) Close() { close(pool.additionCh) // 关闭addition管道 close(pool.checkCh) // 关闭check管道 pool.Statistor.EndTime = time.Now().Unix() - pool.Bar.Close() + pool.reqPool.Release() + pool.scopePool.Release() } func (pool *BrutePool) safePath(u string) string { diff --git a/internal/pool/checkpool.go b/internal/pool/checkpool.go index 55a5032..a1cbb8b 100644 --- a/internal/pool/checkpool.go +++ b/internal/pool/checkpool.go @@ -18,7 +18,7 @@ func NewCheckPool(ctx context.Context, config *Config) (*CheckPool, error) { pctx, cancel := context.WithCancel(ctx) config.ClientType = ihttp.STANDARD pool := &CheckPool{ - &BasePool{ + BasePool: &BasePool{ Config: config, Statistor: pkg.NewStatistor(""), ctx: pctx, @@ -38,13 +38,14 @@ func NewCheckPool(ctx context.Context, config *Config) (*CheckPool, error) { pool.Headers = map[string]string{"Connection": "close"} p, _ := ants.NewPoolWithFunc(config.Thread, pool.Invoke) - pool.BasePool.Pool = p + pool.Pool = p go pool.Handler() return pool, nil } type CheckPool struct { *BasePool + Pool *ants.PoolWithFunc } func (pool *CheckPool) Run(ctx context.Context, offset, limit int) { @@ -82,12 +83,12 @@ Loop: } pool.wg.Add(1) - _ = pool.BasePool.Pool.Invoke(newUnit(u, parsers.CheckSource)) + _ = pool.Pool.Invoke(newUnit(u, parsers.CheckSource)) case u, ok := <-pool.additionCh: if !ok { continue } - _ = pool.BasePool.Pool.Invoke(u) + _ = pool.Pool.Invoke(u) case <-pool.closeCh: break Loop case <-ctx.Done(): @@ -99,6 +100,10 @@ Loop: pool.Close() } +func (pool *CheckPool) Close() { + pool.Bar.Close() + pool.Pool.Release() +} func (pool *CheckPool) Invoke(v interface{}) { defer func() { @@ -107,7 +112,7 @@ func (pool *CheckPool) Invoke(v interface{}) { }() unit := v.(*Unit) - req, err := ihttp.BuildRequest(pool.ClientType, unit.path, "", "", "GET") + req, err := ihttp.BuildRequest(pool.ctx, pool.ClientType, unit.path, "", "", "GET") if err != nil { logs.Log.Debug(err.Error()) bl := &pkg.Baseline{ @@ -125,7 +130,7 @@ func (pool *CheckPool) Invoke(v interface{}) { req.SetHeaders(pool.Headers) start := time.Now() var bl *pkg.Baseline - resp, reqerr := pool.client.Do(pool.ctx, req) + resp, reqerr := pool.client.Do(req) if reqerr != nil { pool.failedCount++ bl = &pkg.Baseline{ diff --git a/internal/pool/config.go b/internal/pool/config.go index f3cae85..0857b49 100644 --- a/internal/pool/config.go +++ b/internal/pool/config.go @@ -5,6 +5,7 @@ import ( "github.com/chainreactors/words/rule" "github.com/expr-lang/expr/vm" "sync" + "time" ) type SprayMod int @@ -26,7 +27,7 @@ type Config struct { ProxyAddr string Thread int Wordlist []string - Timeout int + Timeout time.Duration ProcessCh chan *pkg.Baseline OutputCh chan *pkg.Baseline FuzzyCh chan *pkg.Baseline diff --git a/internal/pool/pool.go b/internal/pool/pool.go index bed4ebf..7709a58 100644 --- a/internal/pool/pool.go +++ b/internal/pool/pool.go @@ -6,14 +6,12 @@ import ( "github.com/chainreactors/spray/internal/ihttp" "github.com/chainreactors/spray/pkg" "github.com/chainreactors/words" - "github.com/panjf2000/ants/v2" "sync" ) type BasePool struct { *Config Statistor *pkg.Statistor - Pool *ants.PoolWithFunc Bar *pkg.Bar Worder *words.Worder Cancel context.CancelFunc @@ -72,10 +70,6 @@ func (pool *BasePool) addAddition(u *Unit) { pool.additionCh <- u } -func (pool *BasePool) Close() { - pool.Bar.Close() -} - func (pool *BasePool) putToOutput(bl *pkg.Baseline) { if bl.IsValid || bl.IsFuzzy { bl.Collect() diff --git a/internal/runner.go b/internal/runner.go index 15424a5..eeee03a 100644 --- a/internal/runner.go +++ b/internal/runner.go @@ -15,6 +15,7 @@ import ( "github.com/vbauerster/mpb/v8/decor" "strings" "sync" + "time" ) var ( @@ -61,7 +62,7 @@ type Runner struct { func (r *Runner) PrepareConfig() *pool.Config { config := &pool.Config{ Thread: r.Threads, - Timeout: r.Timeout, + Timeout: time.Duration(r.Timeout) * time.Second, RateLimit: r.RateLimit, Headers: r.Headers, Method: r.Method, @@ -206,7 +207,7 @@ func (r *Runner) Prepare(ctx context.Context) error { } } - brutePool.Run(ctx, brutePool.Statistor.Offset, limit) + brutePool.Run(brutePool.Statistor.Offset, limit) if brutePool.IsFailed && len(brutePool.FailedBaselines) > 0 { // 如果因为错误积累退出, end将指向第一个错误发生时, 防止resume时跳过大量目标 @@ -229,6 +230,7 @@ Loop: for { select { case <-ctx.Done(): + // 如果超过了deadline, 尚未开始的任务都将被记录到stat中 if len(r.taskCh) > 0 { for t := range r.taskCh { stat := pkg.NewStatistor(t.baseUrl)