This commit is contained in:
M09Ic 2024-09-10 15:41:48 +08:00
parent 5cf02cbbcb
commit 29db702744
7 changed files with 37 additions and 36 deletions

View File

@ -1,7 +1,6 @@
package ihttp
import (
"context"
"crypto/tls"
"fmt"
"github.com/chainreactors/logs"
@ -56,7 +55,7 @@ func NewClient(config *ClientConfig) *Client {
DisablePathNormalizing: true,
DisableHeaderNamesNormalizing: true,
},
Config: config,
ClientConfig: config,
}
} else {
client = &Client{
@ -76,7 +75,7 @@ func NewClient(config *ClientConfig) *Client {
return http.ErrUseLastResponse
},
},
Config: config,
ClientConfig: config,
}
if config.ProxyAddr != "" {
client.standardClient.Transport.(*http.Transport).Proxy = func(_ *http.Request) (*url.URL, error) {
@ -97,7 +96,7 @@ type ClientConfig struct {
type Client struct {
fastClient *fasthttp.Client
standardClient *http.Client
Config *ClientConfig
*ClientConfig
}
func (c *Client) TransToCheck() {
@ -108,22 +107,22 @@ func (c *Client) TransToCheck() {
}
}
func (c *Client) FastDo(ctx context.Context, req *fasthttp.Request) (*fasthttp.Response, error) {
func (c *Client) FastDo(req *fasthttp.Request) (*fasthttp.Response, error) {
resp := fasthttp.AcquireResponse()
err := c.fastClient.Do(req, resp)
err := c.fastClient.DoTimeout(req, resp, c.Timeout)
return resp, err
}
func (c *Client) StandardDo(ctx context.Context, req *http.Request) (*http.Response, error) {
func (c *Client) StandardDo(req *http.Request) (*http.Response, error) {
return c.standardClient.Do(req)
}
func (c *Client) Do(ctx context.Context, req *Request) (*Response, error) {
func (c *Client) Do(req *Request) (*Response, error) {
if c.fastClient != nil {
resp, err := c.FastDo(ctx, req.FastRequest)
resp, err := c.FastDo(req.FastRequest)
return &Response{FastResponse: resp, ClientType: FAST}, err
} else if c.standardClient != nil {
resp, err := c.StandardDo(ctx, req.StandardRequest)
resp, err := c.StandardDo(req.StandardRequest)
return &Response{StandardResponse: resp, ClientType: STANDARD}, err
} else {
return nil, fmt.Errorf("not found client")

View File

@ -1,11 +1,12 @@
package ihttp
import (
"context"
"github.com/valyala/fasthttp"
"net/http"
)
func BuildRequest(clientType int, base, path, host, method string) (*Request, error) {
func BuildRequest(ctx context.Context, clientType int, base, path, host, method string) (*Request, error) {
if clientType == FAST {
req := fasthttp.AcquireRequest()
req.Header.SetMethod(method)
@ -15,7 +16,7 @@ func BuildRequest(clientType int, base, path, host, method string) (*Request, er
}
return &Request{FastRequest: req, ClientType: FAST}, nil
} else {
req, err := http.NewRequest(method, base+path, nil)
req, err := http.NewRequestWithContext(ctx, method, base+path, nil)
if host != "" {
req.Host = host
}

View File

@ -48,7 +48,7 @@ func NewBrutePool(ctx context.Context, config *Config) (*BrutePool, error) {
client: ihttp.NewClient(&ihttp.ClientConfig{
Thread: config.Thread,
Type: config.ClientType,
Timeout: time.Duration(config.Timeout) * time.Second,
Timeout: config.Timeout,
ProxyAddr: config.ProxyAddr,
}),
additionCh: make(chan *Unit, config.Thread),
@ -167,7 +167,7 @@ func (pool *BrutePool) Init() error {
return nil
}
func (pool *BrutePool) Run(ctx context.Context, offset, limit int) {
func (pool *BrutePool) Run(offset, limit int) {
pool.Worder.Run()
if pool.Active {
pool.wg.Add(1)
@ -250,8 +250,6 @@ Loop:
}
case <-pool.closeCh:
break Loop
case <-ctx.Done():
break Loop
case <-pool.ctx.Done():
break Loop
}
@ -271,7 +269,7 @@ func (pool *BrutePool) Invoke(v interface{}) {
var req *ihttp.Request
var err error
req, err = ihttp.BuildRequest(pool.ClientType, pool.BaseURL, unit.path, unit.host, pool.Method)
req, err = ihttp.BuildRequest(pool.ctx, pool.ClientType, pool.BaseURL, unit.path, unit.host, pool.Method)
if err != nil {
logs.Log.Error(err.Error())
return
@ -283,7 +281,7 @@ func (pool *BrutePool) Invoke(v interface{}) {
}
start := time.Now()
resp, reqerr := pool.client.Do(pool.ctx, req)
resp, reqerr := pool.client.Do(req)
if pool.ClientType == ihttp.FAST {
defer fasthttp.ReleaseResponse(resp.FastResponse)
defer fasthttp.ReleaseRequest(req.FastRequest)
@ -397,14 +395,14 @@ func (pool *BrutePool) Invoke(v interface{}) {
func (pool *BrutePool) NoScopeInvoke(v interface{}) {
defer pool.wg.Done()
unit := v.(*Unit)
req, err := ihttp.BuildRequest(pool.ClientType, unit.path, "", "", "GET")
req, err := ihttp.BuildRequest(pool.ctx, pool.ClientType, unit.path, "", "", "GET")
if err != nil {
logs.Log.Error(err.Error())
return
}
req.SetHeaders(pool.Headers)
req.SetHeader("User-Agent", pkg.RandomUA())
resp, reqerr := pool.client.Do(pool.ctx, req)
resp, reqerr := pool.client.Do(req)
if pool.ClientType == ihttp.FAST {
defer fasthttp.ReleaseResponse(resp.FastResponse)
defer fasthttp.ReleaseRequest(req.FastRequest)
@ -728,7 +726,8 @@ func (pool *BrutePool) Close() {
close(pool.additionCh) // 关闭addition管道
close(pool.checkCh) // 关闭check管道
pool.Statistor.EndTime = time.Now().Unix()
pool.Bar.Close()
pool.reqPool.Release()
pool.scopePool.Release()
}
func (pool *BrutePool) safePath(u string) string {

View File

@ -18,7 +18,7 @@ func NewCheckPool(ctx context.Context, config *Config) (*CheckPool, error) {
pctx, cancel := context.WithCancel(ctx)
config.ClientType = ihttp.STANDARD
pool := &CheckPool{
&BasePool{
BasePool: &BasePool{
Config: config,
Statistor: pkg.NewStatistor(""),
ctx: pctx,
@ -38,13 +38,14 @@ func NewCheckPool(ctx context.Context, config *Config) (*CheckPool, error) {
pool.Headers = map[string]string{"Connection": "close"}
p, _ := ants.NewPoolWithFunc(config.Thread, pool.Invoke)
pool.BasePool.Pool = p
pool.Pool = p
go pool.Handler()
return pool, nil
}
type CheckPool struct {
*BasePool
Pool *ants.PoolWithFunc
}
func (pool *CheckPool) Run(ctx context.Context, offset, limit int) {
@ -82,12 +83,12 @@ Loop:
}
pool.wg.Add(1)
_ = pool.BasePool.Pool.Invoke(newUnit(u, parsers.CheckSource))
_ = pool.Pool.Invoke(newUnit(u, parsers.CheckSource))
case u, ok := <-pool.additionCh:
if !ok {
continue
}
_ = pool.BasePool.Pool.Invoke(u)
_ = pool.Pool.Invoke(u)
case <-pool.closeCh:
break Loop
case <-ctx.Done():
@ -99,6 +100,10 @@ Loop:
pool.Close()
}
func (pool *CheckPool) Close() {
pool.Bar.Close()
pool.Pool.Release()
}
func (pool *CheckPool) Invoke(v interface{}) {
defer func() {
@ -107,7 +112,7 @@ func (pool *CheckPool) Invoke(v interface{}) {
}()
unit := v.(*Unit)
req, err := ihttp.BuildRequest(pool.ClientType, unit.path, "", "", "GET")
req, err := ihttp.BuildRequest(pool.ctx, pool.ClientType, unit.path, "", "", "GET")
if err != nil {
logs.Log.Debug(err.Error())
bl := &pkg.Baseline{
@ -125,7 +130,7 @@ func (pool *CheckPool) Invoke(v interface{}) {
req.SetHeaders(pool.Headers)
start := time.Now()
var bl *pkg.Baseline
resp, reqerr := pool.client.Do(pool.ctx, req)
resp, reqerr := pool.client.Do(req)
if reqerr != nil {
pool.failedCount++
bl = &pkg.Baseline{

View File

@ -5,6 +5,7 @@ import (
"github.com/chainreactors/words/rule"
"github.com/expr-lang/expr/vm"
"sync"
"time"
)
type SprayMod int
@ -26,7 +27,7 @@ type Config struct {
ProxyAddr string
Thread int
Wordlist []string
Timeout int
Timeout time.Duration
ProcessCh chan *pkg.Baseline
OutputCh chan *pkg.Baseline
FuzzyCh chan *pkg.Baseline

View File

@ -6,14 +6,12 @@ import (
"github.com/chainreactors/spray/internal/ihttp"
"github.com/chainreactors/spray/pkg"
"github.com/chainreactors/words"
"github.com/panjf2000/ants/v2"
"sync"
)
type BasePool struct {
*Config
Statistor *pkg.Statistor
Pool *ants.PoolWithFunc
Bar *pkg.Bar
Worder *words.Worder
Cancel context.CancelFunc
@ -72,10 +70,6 @@ func (pool *BasePool) addAddition(u *Unit) {
pool.additionCh <- u
}
func (pool *BasePool) Close() {
pool.Bar.Close()
}
func (pool *BasePool) putToOutput(bl *pkg.Baseline) {
if bl.IsValid || bl.IsFuzzy {
bl.Collect()

View File

@ -15,6 +15,7 @@ import (
"github.com/vbauerster/mpb/v8/decor"
"strings"
"sync"
"time"
)
var (
@ -61,7 +62,7 @@ type Runner struct {
func (r *Runner) PrepareConfig() *pool.Config {
config := &pool.Config{
Thread: r.Threads,
Timeout: r.Timeout,
Timeout: time.Duration(r.Timeout) * time.Second,
RateLimit: r.RateLimit,
Headers: r.Headers,
Method: r.Method,
@ -206,7 +207,7 @@ func (r *Runner) Prepare(ctx context.Context) error {
}
}
brutePool.Run(ctx, brutePool.Statistor.Offset, limit)
brutePool.Run(brutePool.Statistor.Offset, limit)
if brutePool.IsFailed && len(brutePool.FailedBaselines) > 0 {
// 如果因为错误积累退出, end将指向第一个错误发生时, 防止resume时跳过大量目标
@ -229,6 +230,7 @@ Loop:
for {
select {
case <-ctx.Done():
// 如果超过了deadline, 尚未开始的任务都将被记录到stat中
if len(r.taskCh) > 0 {
for t := range r.taskCh {
stat := pkg.NewStatistor(t.baseUrl)