实装--check-only参数, 实现类似httpx的批量请求的功能

This commit is contained in:
M09Ic 2022-11-18 20:27:29 +08:00
parent a5966355ae
commit be0fc35cab
8 changed files with 293 additions and 85 deletions

View File

@ -34,5 +34,9 @@ func Spray() {
return return
} }
if runner.CheckOnly {
runner.RunWithCheck(ctx)
} else {
runner.Run(ctx) runner.Run(ctx)
}
} }

View File

@ -1 +0,0 @@
package internal

123
internal/checkpool.go Normal file
View File

@ -0,0 +1,123 @@
package internal
import (
"context"
"github.com/chainreactors/logs"
"github.com/chainreactors/spray/pkg"
"github.com/chainreactors/spray/pkg/ihttp"
"github.com/chainreactors/words"
"github.com/panjf2000/ants/v2"
"github.com/valyala/fasthttp"
"sync"
)
func NewCheckPool(ctx context.Context, config *pkg.Config) (*CheckPool, error) {
pctx, cancel := context.WithCancel(ctx)
pool := &CheckPool{
Config: config,
ctx: pctx,
cancel: cancel,
client: ihttp.NewClient(config.Thread, 2, config.ClientType),
worder: words.NewWorder(config.Wordlist),
wg: sync.WaitGroup{},
reqCount: 1,
failedCount: 1,
}
switch config.Mod {
case pkg.PathSpray:
pool.genReq = func(s string) (*ihttp.Request, error) {
return ihttp.BuildPathRequest(pool.ClientType, s, "")
}
case pkg.HostSpray:
pool.genReq = func(s string) (*ihttp.Request, error) {
return ihttp.BuildHostRequest(pool.ClientType, s, "")
}
}
p, _ := ants.NewPoolWithFunc(config.Thread, func(i interface{}) {
unit := i.(*Unit)
req, err := pool.genReq(unit.path)
if err != nil {
logs.Log.Error(err.Error())
}
var bl *pkg.Baseline
resp, reqerr := pool.client.Do(pctx, req)
if pool.ClientType == ihttp.FAST {
defer fasthttp.ReleaseResponse(resp.FastResponse)
defer fasthttp.ReleaseRequest(req.FastRequest)
}
if reqerr != nil && reqerr != fasthttp.ErrBodyTooLarge {
pool.failedCount++
bl = &pkg.Baseline{Url: pool.BaseURL + unit.path, IsValid: false, Err: reqerr.Error(), Reason: ErrRequestFailed.Error()}
} else {
bl = pkg.NewBaseline(req.URI(), req.Host(), resp)
bl.Collect()
}
// 异步进行性能消耗较大的深度对比
pool.OutputCh <- bl
pool.reqCount++
pool.wg.Done()
pool.bar.Done()
})
pool.pool = p
return pool, nil
}
type CheckPool struct {
*pkg.Config
client *ihttp.Client
pool *ants.PoolWithFunc
bar *pkg.Bar
ctx context.Context
cancel context.CancelFunc
reqCount int
failedCount int
genReq func(s string) (*ihttp.Request, error)
worder *words.Worder
wg sync.WaitGroup
}
func (p *CheckPool) Close() {
p.bar.Close()
}
func (p *CheckPool) Run(ctx context.Context, offset, limit int) {
Loop:
for {
select {
case u, ok := <-p.worder.C:
if !ok {
break Loop
}
if p.reqCount < offset {
p.reqCount++
continue
}
if p.reqCount > limit {
break Loop
}
for _, fn := range p.Fns {
u = fn(u)
}
if u == "" {
continue
}
p.wg.Add(1)
_ = p.pool.Invoke(newUnit(u, WordSource))
case <-ctx.Done():
break Loop
case <-p.ctx.Done():
break Loop
}
}
p.wg.Wait()
p.Close()
}

View File

@ -61,11 +61,11 @@ type ModeOptions struct {
ErrPeriod int `long:"error-period" default:"10"` ErrPeriod int `long:"error-period" default:"10"`
BreakThreshold int `long:"error-threshold" default:"20"` BreakThreshold int `long:"error-threshold" default:"20"`
BlackStatus string `long:"black-status" default:"default"` BlackStatus string `long:"black-status" default:"default"`
WhiteStatus string `long:"black-status" ` WhiteStatus string `long:"white-status" `
} }
type MiscOptions struct { type MiscOptions struct {
Deadline int `long:"deadline" default:"99999" description:"Int, deadline (seconds)"` // todo 总的超时时间,适配云函数的deadline Deadline int `long:"deadline" default:"999999" description:"Int, deadline (seconds)"` // todo 总的超时时间,适配云函数的deadline
Timeout int `long:"timeout" default:"2" description:"Int, timeout with request (seconds)"` Timeout int `long:"timeout" default:"2" description:"Int, timeout with request (seconds)"`
PoolSize int `short:"p" long:"pool" default:"5" description:"Int, Pool size"` PoolSize int `short:"p" long:"pool" default:"5" description:"Int, Pool size"`
Threads int `short:"t" long:"thread" default:"20" description:"Int, number of threads per pool (seconds)"` Threads int `short:"t" long:"thread" default:"20" description:"Int, number of threads per pool (seconds)"`
@ -90,7 +90,7 @@ func (opt *Option) PrepareRunner() (*Runner, error) {
Deadline: opt.Deadline, Deadline: opt.Deadline,
Offset: opt.Offset, Offset: opt.Offset,
Limit: opt.Limit, Limit: opt.Limit,
URLList: make(chan string), URLCh: make(chan string),
OutputCh: make(chan *pkg.Baseline, 100), OutputCh: make(chan *pkg.Baseline, 100),
FuzzyCh: make(chan *pkg.Baseline, 100), FuzzyCh: make(chan *pkg.Baseline, 100),
Fuzzy: opt.Fuzzy, Fuzzy: opt.Fuzzy,
@ -177,12 +177,16 @@ func (opt *Option) PrepareRunner() (*Runner, error) {
urls[i] = strings.TrimSpace(u) urls[i] = strings.TrimSpace(u)
} }
logs.Log.Importantf("load %d urls from %s", len(urls), urlfrom) logs.Log.Importantf("load %d urls from %s", len(urls), urlfrom)
if !opt.CheckOnly {
go func() { go func() {
for _, u := range urls { for _, u := range urls {
r.URLList <- u r.URLCh <- u
} }
close(r.URLList) close(r.URLCh)
}() }()
} else {
r.URLList = urls
}
// prepare word // prepare word
dicts := make([][]string, len(opt.Dictionaries)) dicts := make([][]string, len(opt.Dictionaries))
@ -222,7 +226,11 @@ func (opt *Option) PrepareRunner() (*Runner, error) {
return nil, err return nil, err
} }
if r.Limit == 0 { if r.Limit == 0 {
if r.CheckOnly {
r.Limit = len(r.URLList)
} else {
r.Limit = len(r.Wordlist) r.Limit = len(r.Wordlist)
}
} else { } else {
r.Limit = r.Offset + opt.Limit r.Limit = r.Offset + opt.Limit
} }

View File

@ -9,7 +9,6 @@ import (
"github.com/chainreactors/words" "github.com/chainreactors/words"
"github.com/panjf2000/ants/v2" "github.com/panjf2000/ants/v2"
"github.com/valyala/fasthttp" "github.com/valyala/fasthttp"
"net/http"
"sync" "sync"
"time" "time"
) )
@ -39,7 +38,7 @@ func NewPool(ctx context.Context, config *pkg.Config) (*Pool, error) {
switch config.Mod { switch config.Mod {
case pkg.PathSpray: case pkg.PathSpray:
pool.genReq = func(s string) (*ihttp.Request, error) { pool.genReq = func(s string) (*ihttp.Request, error) {
return pool.buildPathRequest(s) return ihttp.BuildPathRequest(pool.ClientType, pool.BaseURL, s)
} }
pool.check = func() { pool.check = func() {
_ = pool.pool.Invoke(newUnit(pkg.RandPath(), CheckSource)) _ = pool.pool.Invoke(newUnit(pkg.RandPath(), CheckSource))
@ -52,7 +51,7 @@ func NewPool(ctx context.Context, config *pkg.Config) (*Pool, error) {
} }
case pkg.HostSpray: case pkg.HostSpray:
pool.genReq = func(s string) (*ihttp.Request, error) { pool.genReq = func(s string) (*ihttp.Request, error) {
return pool.buildHostRequest(s) return ihttp.BuildHostRequest(pool.ClientType, pool.BaseURL, s)
} }
pool.check = func() { pool.check = func() {
@ -247,7 +246,6 @@ Loop:
break Loop break Loop
} }
} }
//p.wg.Add(100)
p.wg.Wait() p.wg.Wait()
p.Close() p.Close()
} }
@ -361,27 +359,3 @@ func (p *Pool) Close() {
close(p.tempCh) close(p.tempCh)
p.bar.Close() p.bar.Close()
} }
func (p *Pool) buildPathRequest(path string) (*ihttp.Request, error) {
if p.Config.ClientType == ihttp.FAST {
req := fasthttp.AcquireRequest()
req.SetRequestURI(p.BaseURL + path)
return &ihttp.Request{FastRequest: req, ClientType: p.ClientType}, nil
} else {
req, err := http.NewRequest("GET", p.BaseURL+path, nil)
return &ihttp.Request{StandardRequest: req, ClientType: p.ClientType}, err
}
}
func (p *Pool) buildHostRequest(host string) (*ihttp.Request, error) {
if p.Config.ClientType == ihttp.FAST {
req := fasthttp.AcquireRequest()
req.SetRequestURI(p.BaseURL)
req.SetHost(host)
return &ihttp.Request{FastRequest: req, ClientType: p.ClientType}, nil
} else {
req, err := http.NewRequest("GET", p.BaseURL, nil)
req.Host = host
return &ihttp.Request{StandardRequest: req, ClientType: p.ClientType}, err
}
}

View File

@ -21,7 +21,8 @@ var (
) )
type Runner struct { type Runner struct {
URLList chan string URLCh chan string
URLList []string
Wordlist []string Wordlist []string
Headers http.Header Headers http.Header
Fns []func(string) string Fns []func(string) string
@ -48,14 +49,8 @@ type Runner struct {
CheckOnly bool CheckOnly bool
} }
func (r *Runner) Prepare(ctx context.Context) error { func (r *Runner) PrepareConfig() *pkg.Config {
var err error
r.Pools, err = ants.NewPoolWithFunc(r.PoolSize, func(i interface{}) {
u := i.(string)
config := &pkg.Config{ config := &pkg.Config{
BaseURL: u,
Wordlist: r.Wordlist,
Thread: r.Threads, Thread: r.Threads,
Timeout: r.Timeout, Timeout: r.Timeout,
Headers: r.Headers, Headers: r.Headers,
@ -67,13 +62,37 @@ func (r *Runner) Prepare(ctx context.Context) error {
ErrPeriod: r.ErrPeriod, ErrPeriod: r.ErrPeriod,
BreakThreshold: r.BreakThreshold, BreakThreshold: r.BreakThreshold,
} }
if config.Mod == pkg.PathSpray { if config.Mod == pkg.PathSpray {
config.ClientType = ihttp.FAST config.ClientType = ihttp.FAST
} else if config.Mod == pkg.HostSpray { } else if config.Mod == pkg.HostSpray {
config.ClientType = ihttp.STANDARD config.ClientType = ihttp.STANDARD
} }
return config
}
func (r *Runner) Prepare(ctx context.Context) error {
var err error
if r.CheckOnly {
r.Pools, err = ants.NewPoolWithFunc(1, func(i interface{}) {
config := r.PrepareConfig()
config.Wordlist = r.URLList
pool, err := NewCheckPool(ctx, config)
if err != nil {
logs.Log.Error(err.Error())
pool.cancel()
r.poolwg.Done()
return
}
pool.bar = pkg.NewBar("check", r.Limit-r.Offset, r.Progress)
pool.Run(ctx, r.Offset, r.Limit)
r.poolwg.Done()
})
} else {
r.Pools, err = ants.NewPoolWithFunc(r.PoolSize, func(i interface{}) {
u := i.(string)
config := r.PrepareConfig()
config.BaseURL = u
config.Wordlist = r.Wordlist
pool, err := NewPool(ctx, config) pool, err := NewPool(ctx, config)
if err != nil { if err != nil {
logs.Log.Error(err.Error()) logs.Log.Error(err.Error())
@ -97,6 +116,8 @@ func (r *Runner) Prepare(ctx context.Context) error {
r.poolwg.Done() r.poolwg.Done()
}) })
}
if err != nil { if err != nil {
return err return err
} }
@ -111,7 +132,7 @@ Loop:
case <-ctx.Done(): case <-ctx.Done():
logs.Log.Error("cancel with deadline") logs.Log.Error("cancel with deadline")
break Loop break Loop
case u, ok := <-r.URLList: case u, ok := <-r.URLCh:
if !ok { if !ok {
break Loop break Loop
} }
@ -137,6 +158,38 @@ Loop:
time.Sleep(100) // 延迟100ms, 等所有数据处理完毕 time.Sleep(100) // 延迟100ms, 等所有数据处理完毕
} }
func (r *Runner) RunWithCheck(ctx context.Context) {
stopCh := make(chan struct{})
r.poolwg.Add(1)
err := r.Pools.Invoke(struct{}{})
if err != nil {
return
}
go func() {
r.poolwg.Wait()
stopCh <- struct{}{}
}()
Loop:
for {
select {
case <-ctx.Done():
logs.Log.Error("cancel with deadline")
break Loop
case <-stopCh:
break Loop
}
}
for {
if len(r.OutputCh) == 0 {
close(r.OutputCh)
break
}
}
time.Sleep(100) // 延迟100ms, 等所有数据处理完毕
}
func (r *Runner) Outputting() { func (r *Runner) Outputting() {
go func() { go func() {
var outFunc func(*pkg.Baseline) var outFunc func(*pkg.Baseline)

View File

@ -35,6 +35,7 @@ func NewClient(thread int, timeout int, clientType int) *Client {
MaxResponseBodySize: DefaultMaxBodySize, MaxResponseBodySize: DefaultMaxBodySize,
}, },
timeout: time.Duration(timeout) * time.Second, timeout: time.Duration(timeout) * time.Second,
clientType: clientType,
} }
} else { } else {
return &Client{ return &Client{
@ -53,6 +54,7 @@ func NewClient(thread int, timeout int, clientType int) *Client {
CheckRedirect: checkRedirect, CheckRedirect: checkRedirect,
}, },
timeout: time.Duration(timeout) * time.Second, timeout: time.Duration(timeout) * time.Second,
clientType: clientType,
} }
} }
} }
@ -60,9 +62,18 @@ func NewClient(thread int, timeout int, clientType int) *Client {
type Client struct { type Client struct {
fastClient *fasthttp.Client fastClient *fasthttp.Client
standardClient *http.Client standardClient *http.Client
clientType int
timeout time.Duration timeout time.Duration
} }
func (c *Client) TransToCheck() {
if c.fastClient != nil {
c.fastClient.MaxConnsPerHost = 1
} else if c.standardClient != nil {
}
}
func (c *Client) FastDo(ctx context.Context, req *fasthttp.Request) (*fasthttp.Response, error) { func (c *Client) FastDo(ctx context.Context, req *fasthttp.Request) (*fasthttp.Response, error) {
resp := fasthttp.AcquireResponse() resp := fasthttp.AcquireResponse()
return resp, c.fastClient.Do(req, resp) return resp, c.fastClient.Do(req, resp)

View File

@ -5,12 +5,48 @@ import (
"net/http" "net/http"
) )
func BuildPathRequest(clientType int, base, path string) (*Request, error) {
if clientType == FAST {
req := fasthttp.AcquireRequest()
req.SetRequestURI(base + path)
return &Request{FastRequest: req, ClientType: FAST}, nil
} else {
req, err := http.NewRequest("GET", base+path, nil)
return &Request{StandardRequest: req, ClientType: STANDARD}, err
}
}
func BuildHostRequest(clientType int, base, host string) (*Request, error) {
if clientType == FAST {
req := fasthttp.AcquireRequest()
req.SetRequestURI(base)
req.SetHost(host)
return &Request{FastRequest: req, ClientType: FAST}, nil
} else {
req, err := http.NewRequest("GET", base, nil)
req.Host = host
return &Request{StandardRequest: req, ClientType: STANDARD}, err
}
}
type Request struct { type Request struct {
StandardRequest *http.Request StandardRequest *http.Request
FastRequest *fasthttp.Request FastRequest *fasthttp.Request
ClientType int ClientType int
} }
func (r *Request) SetHeader(header map[string]string) {
if r.StandardRequest != nil {
for k, v := range header {
r.StandardRequest.Header.Set(k, v)
}
} else if r.FastRequest != nil {
for k, v := range header {
r.FastRequest.Header.Set(k, v)
}
}
}
func (r *Request) URI() string { func (r *Request) URI() string {
if r.FastRequest != nil { if r.FastRequest != nil {
return r.FastRequest.URI().String() return r.FastRequest.URI().String()