This commit is contained in:
M09Ic 2024-09-10 15:41:48 +08:00
parent 5cf02cbbcb
commit 29db702744
7 changed files with 37 additions and 36 deletions

View File

@ -1,7 +1,6 @@
package ihttp package ihttp
import ( import (
"context"
"crypto/tls" "crypto/tls"
"fmt" "fmt"
"github.com/chainreactors/logs" "github.com/chainreactors/logs"
@ -56,7 +55,7 @@ func NewClient(config *ClientConfig) *Client {
DisablePathNormalizing: true, DisablePathNormalizing: true,
DisableHeaderNamesNormalizing: true, DisableHeaderNamesNormalizing: true,
}, },
Config: config, ClientConfig: config,
} }
} else { } else {
client = &Client{ client = &Client{
@ -76,7 +75,7 @@ func NewClient(config *ClientConfig) *Client {
return http.ErrUseLastResponse return http.ErrUseLastResponse
}, },
}, },
Config: config, ClientConfig: config,
} }
if config.ProxyAddr != "" { if config.ProxyAddr != "" {
client.standardClient.Transport.(*http.Transport).Proxy = func(_ *http.Request) (*url.URL, error) { client.standardClient.Transport.(*http.Transport).Proxy = func(_ *http.Request) (*url.URL, error) {
@ -97,7 +96,7 @@ type ClientConfig struct {
type Client struct { type Client struct {
fastClient *fasthttp.Client fastClient *fasthttp.Client
standardClient *http.Client standardClient *http.Client
Config *ClientConfig *ClientConfig
} }
func (c *Client) TransToCheck() { func (c *Client) TransToCheck() {
@ -108,22 +107,22 @@ func (c *Client) TransToCheck() {
} }
} }
func (c *Client) FastDo(ctx context.Context, req *fasthttp.Request) (*fasthttp.Response, error) { func (c *Client) FastDo(req *fasthttp.Request) (*fasthttp.Response, error) {
resp := fasthttp.AcquireResponse() resp := fasthttp.AcquireResponse()
err := c.fastClient.Do(req, resp) err := c.fastClient.DoTimeout(req, resp, c.Timeout)
return resp, err return resp, err
} }
func (c *Client) StandardDo(ctx context.Context, req *http.Request) (*http.Response, error) { func (c *Client) StandardDo(req *http.Request) (*http.Response, error) {
return c.standardClient.Do(req) return c.standardClient.Do(req)
} }
func (c *Client) Do(ctx context.Context, req *Request) (*Response, error) { func (c *Client) Do(req *Request) (*Response, error) {
if c.fastClient != nil { if c.fastClient != nil {
resp, err := c.FastDo(ctx, req.FastRequest) resp, err := c.FastDo(req.FastRequest)
return &Response{FastResponse: resp, ClientType: FAST}, err return &Response{FastResponse: resp, ClientType: FAST}, err
} else if c.standardClient != nil { } else if c.standardClient != nil {
resp, err := c.StandardDo(ctx, req.StandardRequest) resp, err := c.StandardDo(req.StandardRequest)
return &Response{StandardResponse: resp, ClientType: STANDARD}, err return &Response{StandardResponse: resp, ClientType: STANDARD}, err
} else { } else {
return nil, fmt.Errorf("not found client") return nil, fmt.Errorf("not found client")

View File

@ -1,11 +1,12 @@
package ihttp package ihttp
import ( import (
"context"
"github.com/valyala/fasthttp" "github.com/valyala/fasthttp"
"net/http" "net/http"
) )
func BuildRequest(clientType int, base, path, host, method string) (*Request, error) { func BuildRequest(ctx context.Context, clientType int, base, path, host, method string) (*Request, error) {
if clientType == FAST { if clientType == FAST {
req := fasthttp.AcquireRequest() req := fasthttp.AcquireRequest()
req.Header.SetMethod(method) req.Header.SetMethod(method)
@ -15,7 +16,7 @@ func BuildRequest(clientType int, base, path, host, method string) (*Request, er
} }
return &Request{FastRequest: req, ClientType: FAST}, nil return &Request{FastRequest: req, ClientType: FAST}, nil
} else { } else {
req, err := http.NewRequest(method, base+path, nil) req, err := http.NewRequestWithContext(ctx, method, base+path, nil)
if host != "" { if host != "" {
req.Host = host req.Host = host
} }

View File

@ -48,7 +48,7 @@ func NewBrutePool(ctx context.Context, config *Config) (*BrutePool, error) {
client: ihttp.NewClient(&ihttp.ClientConfig{ client: ihttp.NewClient(&ihttp.ClientConfig{
Thread: config.Thread, Thread: config.Thread,
Type: config.ClientType, Type: config.ClientType,
Timeout: time.Duration(config.Timeout) * time.Second, Timeout: config.Timeout,
ProxyAddr: config.ProxyAddr, ProxyAddr: config.ProxyAddr,
}), }),
additionCh: make(chan *Unit, config.Thread), additionCh: make(chan *Unit, config.Thread),
@ -167,7 +167,7 @@ func (pool *BrutePool) Init() error {
return nil return nil
} }
func (pool *BrutePool) Run(ctx context.Context, offset, limit int) { func (pool *BrutePool) Run(offset, limit int) {
pool.Worder.Run() pool.Worder.Run()
if pool.Active { if pool.Active {
pool.wg.Add(1) pool.wg.Add(1)
@ -250,8 +250,6 @@ Loop:
} }
case <-pool.closeCh: case <-pool.closeCh:
break Loop break Loop
case <-ctx.Done():
break Loop
case <-pool.ctx.Done(): case <-pool.ctx.Done():
break Loop break Loop
} }
@ -271,7 +269,7 @@ func (pool *BrutePool) Invoke(v interface{}) {
var req *ihttp.Request var req *ihttp.Request
var err error var err error
req, err = ihttp.BuildRequest(pool.ClientType, pool.BaseURL, unit.path, unit.host, pool.Method) req, err = ihttp.BuildRequest(pool.ctx, pool.ClientType, pool.BaseURL, unit.path, unit.host, pool.Method)
if err != nil { if err != nil {
logs.Log.Error(err.Error()) logs.Log.Error(err.Error())
return return
@ -283,7 +281,7 @@ func (pool *BrutePool) Invoke(v interface{}) {
} }
start := time.Now() start := time.Now()
resp, reqerr := pool.client.Do(pool.ctx, req) resp, reqerr := pool.client.Do(req)
if pool.ClientType == ihttp.FAST { if pool.ClientType == ihttp.FAST {
defer fasthttp.ReleaseResponse(resp.FastResponse) defer fasthttp.ReleaseResponse(resp.FastResponse)
defer fasthttp.ReleaseRequest(req.FastRequest) defer fasthttp.ReleaseRequest(req.FastRequest)
@ -397,14 +395,14 @@ func (pool *BrutePool) Invoke(v interface{}) {
func (pool *BrutePool) NoScopeInvoke(v interface{}) { func (pool *BrutePool) NoScopeInvoke(v interface{}) {
defer pool.wg.Done() defer pool.wg.Done()
unit := v.(*Unit) unit := v.(*Unit)
req, err := ihttp.BuildRequest(pool.ClientType, unit.path, "", "", "GET") req, err := ihttp.BuildRequest(pool.ctx, pool.ClientType, unit.path, "", "", "GET")
if err != nil { if err != nil {
logs.Log.Error(err.Error()) logs.Log.Error(err.Error())
return return
} }
req.SetHeaders(pool.Headers) req.SetHeaders(pool.Headers)
req.SetHeader("User-Agent", pkg.RandomUA()) req.SetHeader("User-Agent", pkg.RandomUA())
resp, reqerr := pool.client.Do(pool.ctx, req) resp, reqerr := pool.client.Do(req)
if pool.ClientType == ihttp.FAST { if pool.ClientType == ihttp.FAST {
defer fasthttp.ReleaseResponse(resp.FastResponse) defer fasthttp.ReleaseResponse(resp.FastResponse)
defer fasthttp.ReleaseRequest(req.FastRequest) defer fasthttp.ReleaseRequest(req.FastRequest)
@ -728,7 +726,8 @@ func (pool *BrutePool) Close() {
close(pool.additionCh) // 关闭addition管道 close(pool.additionCh) // 关闭addition管道
close(pool.checkCh) // 关闭check管道 close(pool.checkCh) // 关闭check管道
pool.Statistor.EndTime = time.Now().Unix() pool.Statistor.EndTime = time.Now().Unix()
pool.Bar.Close() pool.reqPool.Release()
pool.scopePool.Release()
} }
func (pool *BrutePool) safePath(u string) string { func (pool *BrutePool) safePath(u string) string {

View File

@ -18,7 +18,7 @@ func NewCheckPool(ctx context.Context, config *Config) (*CheckPool, error) {
pctx, cancel := context.WithCancel(ctx) pctx, cancel := context.WithCancel(ctx)
config.ClientType = ihttp.STANDARD config.ClientType = ihttp.STANDARD
pool := &CheckPool{ pool := &CheckPool{
&BasePool{ BasePool: &BasePool{
Config: config, Config: config,
Statistor: pkg.NewStatistor(""), Statistor: pkg.NewStatistor(""),
ctx: pctx, ctx: pctx,
@ -38,13 +38,14 @@ func NewCheckPool(ctx context.Context, config *Config) (*CheckPool, error) {
pool.Headers = map[string]string{"Connection": "close"} pool.Headers = map[string]string{"Connection": "close"}
p, _ := ants.NewPoolWithFunc(config.Thread, pool.Invoke) p, _ := ants.NewPoolWithFunc(config.Thread, pool.Invoke)
pool.BasePool.Pool = p pool.Pool = p
go pool.Handler() go pool.Handler()
return pool, nil return pool, nil
} }
type CheckPool struct { type CheckPool struct {
*BasePool *BasePool
Pool *ants.PoolWithFunc
} }
func (pool *CheckPool) Run(ctx context.Context, offset, limit int) { func (pool *CheckPool) Run(ctx context.Context, offset, limit int) {
@ -82,12 +83,12 @@ Loop:
} }
pool.wg.Add(1) pool.wg.Add(1)
_ = pool.BasePool.Pool.Invoke(newUnit(u, parsers.CheckSource)) _ = pool.Pool.Invoke(newUnit(u, parsers.CheckSource))
case u, ok := <-pool.additionCh: case u, ok := <-pool.additionCh:
if !ok { if !ok {
continue continue
} }
_ = pool.BasePool.Pool.Invoke(u) _ = pool.Pool.Invoke(u)
case <-pool.closeCh: case <-pool.closeCh:
break Loop break Loop
case <-ctx.Done(): case <-ctx.Done():
@ -99,6 +100,10 @@ Loop:
pool.Close() pool.Close()
} }
func (pool *CheckPool) Close() {
pool.Bar.Close()
pool.Pool.Release()
}
func (pool *CheckPool) Invoke(v interface{}) { func (pool *CheckPool) Invoke(v interface{}) {
defer func() { defer func() {
@ -107,7 +112,7 @@ func (pool *CheckPool) Invoke(v interface{}) {
}() }()
unit := v.(*Unit) unit := v.(*Unit)
req, err := ihttp.BuildRequest(pool.ClientType, unit.path, "", "", "GET") req, err := ihttp.BuildRequest(pool.ctx, pool.ClientType, unit.path, "", "", "GET")
if err != nil { if err != nil {
logs.Log.Debug(err.Error()) logs.Log.Debug(err.Error())
bl := &pkg.Baseline{ bl := &pkg.Baseline{
@ -125,7 +130,7 @@ func (pool *CheckPool) Invoke(v interface{}) {
req.SetHeaders(pool.Headers) req.SetHeaders(pool.Headers)
start := time.Now() start := time.Now()
var bl *pkg.Baseline var bl *pkg.Baseline
resp, reqerr := pool.client.Do(pool.ctx, req) resp, reqerr := pool.client.Do(req)
if reqerr != nil { if reqerr != nil {
pool.failedCount++ pool.failedCount++
bl = &pkg.Baseline{ bl = &pkg.Baseline{

View File

@ -5,6 +5,7 @@ import (
"github.com/chainreactors/words/rule" "github.com/chainreactors/words/rule"
"github.com/expr-lang/expr/vm" "github.com/expr-lang/expr/vm"
"sync" "sync"
"time"
) )
type SprayMod int type SprayMod int
@ -26,7 +27,7 @@ type Config struct {
ProxyAddr string ProxyAddr string
Thread int Thread int
Wordlist []string Wordlist []string
Timeout int Timeout time.Duration
ProcessCh chan *pkg.Baseline ProcessCh chan *pkg.Baseline
OutputCh chan *pkg.Baseline OutputCh chan *pkg.Baseline
FuzzyCh chan *pkg.Baseline FuzzyCh chan *pkg.Baseline

View File

@ -6,14 +6,12 @@ import (
"github.com/chainreactors/spray/internal/ihttp" "github.com/chainreactors/spray/internal/ihttp"
"github.com/chainreactors/spray/pkg" "github.com/chainreactors/spray/pkg"
"github.com/chainreactors/words" "github.com/chainreactors/words"
"github.com/panjf2000/ants/v2"
"sync" "sync"
) )
type BasePool struct { type BasePool struct {
*Config *Config
Statistor *pkg.Statistor Statistor *pkg.Statistor
Pool *ants.PoolWithFunc
Bar *pkg.Bar Bar *pkg.Bar
Worder *words.Worder Worder *words.Worder
Cancel context.CancelFunc Cancel context.CancelFunc
@ -72,10 +70,6 @@ func (pool *BasePool) addAddition(u *Unit) {
pool.additionCh <- u pool.additionCh <- u
} }
func (pool *BasePool) Close() {
pool.Bar.Close()
}
func (pool *BasePool) putToOutput(bl *pkg.Baseline) { func (pool *BasePool) putToOutput(bl *pkg.Baseline) {
if bl.IsValid || bl.IsFuzzy { if bl.IsValid || bl.IsFuzzy {
bl.Collect() bl.Collect()

View File

@ -15,6 +15,7 @@ import (
"github.com/vbauerster/mpb/v8/decor" "github.com/vbauerster/mpb/v8/decor"
"strings" "strings"
"sync" "sync"
"time"
) )
var ( var (
@ -61,7 +62,7 @@ type Runner struct {
func (r *Runner) PrepareConfig() *pool.Config { func (r *Runner) PrepareConfig() *pool.Config {
config := &pool.Config{ config := &pool.Config{
Thread: r.Threads, Thread: r.Threads,
Timeout: r.Timeout, Timeout: time.Duration(r.Timeout) * time.Second,
RateLimit: r.RateLimit, RateLimit: r.RateLimit,
Headers: r.Headers, Headers: r.Headers,
Method: r.Method, Method: r.Method,
@ -206,7 +207,7 @@ func (r *Runner) Prepare(ctx context.Context) error {
} }
} }
brutePool.Run(ctx, brutePool.Statistor.Offset, limit) brutePool.Run(brutePool.Statistor.Offset, limit)
if brutePool.IsFailed && len(brutePool.FailedBaselines) > 0 { if brutePool.IsFailed && len(brutePool.FailedBaselines) > 0 {
// 如果因为错误积累退出, end将指向第一个错误发生时, 防止resume时跳过大量目标 // 如果因为错误积累退出, end将指向第一个错误发生时, 防止resume时跳过大量目标
@ -229,6 +230,7 @@ Loop:
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
// 如果超过了deadline, 尚未开始的任务都将被记录到stat中
if len(r.taskCh) > 0 { if len(r.taskCh) > 0 {
for t := range r.taskCh { for t := range r.taskCh {
stat := pkg.NewStatistor(t.baseUrl) stat := pkg.NewStatistor(t.baseUrl)