实装--check-only参数, 实现类似httpx的批量请求的功能

This commit is contained in:
M09Ic 2022-11-18 20:27:29 +08:00
parent a5966355ae
commit be0fc35cab
8 changed files with 293 additions and 85 deletions

View File

@ -34,5 +34,9 @@ func Spray() {
return
}
runner.Run(ctx)
if runner.CheckOnly {
runner.RunWithCheck(ctx)
} else {
runner.Run(ctx)
}
}

View File

@ -1 +0,0 @@
package internal

123
internal/checkpool.go Normal file
View File

@ -0,0 +1,123 @@
package internal
import (
"context"
"github.com/chainreactors/logs"
"github.com/chainreactors/spray/pkg"
"github.com/chainreactors/spray/pkg/ihttp"
"github.com/chainreactors/words"
"github.com/panjf2000/ants/v2"
"github.com/valyala/fasthttp"
"sync"
)
func NewCheckPool(ctx context.Context, config *pkg.Config) (*CheckPool, error) {
pctx, cancel := context.WithCancel(ctx)
pool := &CheckPool{
Config: config,
ctx: pctx,
cancel: cancel,
client: ihttp.NewClient(config.Thread, 2, config.ClientType),
worder: words.NewWorder(config.Wordlist),
wg: sync.WaitGroup{},
reqCount: 1,
failedCount: 1,
}
switch config.Mod {
case pkg.PathSpray:
pool.genReq = func(s string) (*ihttp.Request, error) {
return ihttp.BuildPathRequest(pool.ClientType, s, "")
}
case pkg.HostSpray:
pool.genReq = func(s string) (*ihttp.Request, error) {
return ihttp.BuildHostRequest(pool.ClientType, s, "")
}
}
p, _ := ants.NewPoolWithFunc(config.Thread, func(i interface{}) {
unit := i.(*Unit)
req, err := pool.genReq(unit.path)
if err != nil {
logs.Log.Error(err.Error())
}
var bl *pkg.Baseline
resp, reqerr := pool.client.Do(pctx, req)
if pool.ClientType == ihttp.FAST {
defer fasthttp.ReleaseResponse(resp.FastResponse)
defer fasthttp.ReleaseRequest(req.FastRequest)
}
if reqerr != nil && reqerr != fasthttp.ErrBodyTooLarge {
pool.failedCount++
bl = &pkg.Baseline{Url: pool.BaseURL + unit.path, IsValid: false, Err: reqerr.Error(), Reason: ErrRequestFailed.Error()}
} else {
bl = pkg.NewBaseline(req.URI(), req.Host(), resp)
bl.Collect()
}
// 异步进行性能消耗较大的深度对比
pool.OutputCh <- bl
pool.reqCount++
pool.wg.Done()
pool.bar.Done()
})
pool.pool = p
return pool, nil
}
type CheckPool struct {
*pkg.Config
client *ihttp.Client
pool *ants.PoolWithFunc
bar *pkg.Bar
ctx context.Context
cancel context.CancelFunc
reqCount int
failedCount int
genReq func(s string) (*ihttp.Request, error)
worder *words.Worder
wg sync.WaitGroup
}
func (p *CheckPool) Close() {
p.bar.Close()
}
func (p *CheckPool) Run(ctx context.Context, offset, limit int) {
Loop:
for {
select {
case u, ok := <-p.worder.C:
if !ok {
break Loop
}
if p.reqCount < offset {
p.reqCount++
continue
}
if p.reqCount > limit {
break Loop
}
for _, fn := range p.Fns {
u = fn(u)
}
if u == "" {
continue
}
p.wg.Add(1)
_ = p.pool.Invoke(newUnit(u, WordSource))
case <-ctx.Done():
break Loop
case <-p.ctx.Done():
break Loop
}
}
p.wg.Wait()
p.Close()
}

View File

@ -61,11 +61,11 @@ type ModeOptions struct {
ErrPeriod int `long:"error-period" default:"10"`
BreakThreshold int `long:"error-threshold" default:"20"`
BlackStatus string `long:"black-status" default:"default"`
WhiteStatus string `long:"black-status" `
WhiteStatus string `long:"white-status" `
}
type MiscOptions struct {
Deadline int `long:"deadline" default:"99999" description:"Int, deadline (seconds)"` // todo 总的超时时间,适配云函数的deadline
Deadline int `long:"deadline" default:"999999" description:"Int, deadline (seconds)"` // todo 总的超时时间,适配云函数的deadline
Timeout int `long:"timeout" default:"2" description:"Int, timeout with request (seconds)"`
PoolSize int `short:"p" long:"pool" default:"5" description:"Int, Pool size"`
Threads int `short:"t" long:"thread" default:"20" description:"Int, number of threads per pool (seconds)"`
@ -90,7 +90,7 @@ func (opt *Option) PrepareRunner() (*Runner, error) {
Deadline: opt.Deadline,
Offset: opt.Offset,
Limit: opt.Limit,
URLList: make(chan string),
URLCh: make(chan string),
OutputCh: make(chan *pkg.Baseline, 100),
FuzzyCh: make(chan *pkg.Baseline, 100),
Fuzzy: opt.Fuzzy,
@ -177,12 +177,16 @@ func (opt *Option) PrepareRunner() (*Runner, error) {
urls[i] = strings.TrimSpace(u)
}
logs.Log.Importantf("load %d urls from %s", len(urls), urlfrom)
go func() {
for _, u := range urls {
r.URLList <- u
}
close(r.URLList)
}()
if !opt.CheckOnly {
go func() {
for _, u := range urls {
r.URLCh <- u
}
close(r.URLCh)
}()
} else {
r.URLList = urls
}
// prepare word
dicts := make([][]string, len(opt.Dictionaries))
@ -222,7 +226,11 @@ func (opt *Option) PrepareRunner() (*Runner, error) {
return nil, err
}
if r.Limit == 0 {
r.Limit = len(r.Wordlist)
if r.CheckOnly {
r.Limit = len(r.URLList)
} else {
r.Limit = len(r.Wordlist)
}
} else {
r.Limit = r.Offset + opt.Limit
}

View File

@ -9,7 +9,6 @@ import (
"github.com/chainreactors/words"
"github.com/panjf2000/ants/v2"
"github.com/valyala/fasthttp"
"net/http"
"sync"
"time"
)
@ -39,7 +38,7 @@ func NewPool(ctx context.Context, config *pkg.Config) (*Pool, error) {
switch config.Mod {
case pkg.PathSpray:
pool.genReq = func(s string) (*ihttp.Request, error) {
return pool.buildPathRequest(s)
return ihttp.BuildPathRequest(pool.ClientType, pool.BaseURL, s)
}
pool.check = func() {
_ = pool.pool.Invoke(newUnit(pkg.RandPath(), CheckSource))
@ -52,7 +51,7 @@ func NewPool(ctx context.Context, config *pkg.Config) (*Pool, error) {
}
case pkg.HostSpray:
pool.genReq = func(s string) (*ihttp.Request, error) {
return pool.buildHostRequest(s)
return ihttp.BuildHostRequest(pool.ClientType, pool.BaseURL, s)
}
pool.check = func() {
@ -247,7 +246,6 @@ Loop:
break Loop
}
}
//p.wg.Add(100)
p.wg.Wait()
p.Close()
}
@ -361,27 +359,3 @@ func (p *Pool) Close() {
close(p.tempCh)
p.bar.Close()
}
func (p *Pool) buildPathRequest(path string) (*ihttp.Request, error) {
if p.Config.ClientType == ihttp.FAST {
req := fasthttp.AcquireRequest()
req.SetRequestURI(p.BaseURL + path)
return &ihttp.Request{FastRequest: req, ClientType: p.ClientType}, nil
} else {
req, err := http.NewRequest("GET", p.BaseURL+path, nil)
return &ihttp.Request{StandardRequest: req, ClientType: p.ClientType}, err
}
}
func (p *Pool) buildHostRequest(host string) (*ihttp.Request, error) {
if p.Config.ClientType == ihttp.FAST {
req := fasthttp.AcquireRequest()
req.SetRequestURI(p.BaseURL)
req.SetHost(host)
return &ihttp.Request{FastRequest: req, ClientType: p.ClientType}, nil
} else {
req, err := http.NewRequest("GET", p.BaseURL, nil)
req.Host = host
return &ihttp.Request{StandardRequest: req, ClientType: p.ClientType}, err
}
}

View File

@ -21,7 +21,8 @@ var (
)
type Runner struct {
URLList chan string
URLCh chan string
URLList []string
Wordlist []string
Headers http.Header
Fns []func(string) string
@ -48,54 +49,74 @@ type Runner struct {
CheckOnly bool
}
func (r *Runner) PrepareConfig() *pkg.Config {
config := &pkg.Config{
Thread: r.Threads,
Timeout: r.Timeout,
Headers: r.Headers,
Mod: pkg.ModMap[r.Mod],
Fns: r.Fns,
OutputCh: r.OutputCh,
FuzzyCh: r.FuzzyCh,
CheckPeriod: r.CheckPeriod,
ErrPeriod: r.ErrPeriod,
BreakThreshold: r.BreakThreshold,
}
if config.Mod == pkg.PathSpray {
config.ClientType = ihttp.FAST
} else if config.Mod == pkg.HostSpray {
config.ClientType = ihttp.STANDARD
}
return config
}
func (r *Runner) Prepare(ctx context.Context) error {
var err error
r.Pools, err = ants.NewPoolWithFunc(r.PoolSize, func(i interface{}) {
u := i.(string)
config := &pkg.Config{
BaseURL: u,
Wordlist: r.Wordlist,
Thread: r.Threads,
Timeout: r.Timeout,
Headers: r.Headers,
Mod: pkg.ModMap[r.Mod],
Fns: r.Fns,
OutputCh: r.OutputCh,
FuzzyCh: r.FuzzyCh,
CheckPeriod: r.CheckPeriod,
ErrPeriod: r.ErrPeriod,
BreakThreshold: r.BreakThreshold,
}
if config.Mod == pkg.PathSpray {
config.ClientType = ihttp.FAST
} else if config.Mod == pkg.HostSpray {
config.ClientType = ihttp.STANDARD
}
pool, err := NewPool(ctx, config)
if err != nil {
logs.Log.Error(err.Error())
pool.cancel()
r.poolwg.Done()
return
}
pool.bar = pkg.NewBar(u, r.Limit-r.Offset, r.Progress)
err = pool.Init()
if err != nil {
logs.Log.Error(err.Error())
if !r.Force {
// 如果没开启force, init失败将会关闭pool
if r.CheckOnly {
r.Pools, err = ants.NewPoolWithFunc(1, func(i interface{}) {
config := r.PrepareConfig()
config.Wordlist = r.URLList
pool, err := NewCheckPool(ctx, config)
if err != nil {
logs.Log.Error(err.Error())
pool.cancel()
r.poolwg.Done()
return
}
}
pool.bar = pkg.NewBar("check", r.Limit-r.Offset, r.Progress)
pool.Run(ctx, r.Offset, r.Limit)
r.poolwg.Done()
})
} else {
r.Pools, err = ants.NewPoolWithFunc(r.PoolSize, func(i interface{}) {
u := i.(string)
config := r.PrepareConfig()
config.BaseURL = u
config.Wordlist = r.Wordlist
pool, err := NewPool(ctx, config)
if err != nil {
logs.Log.Error(err.Error())
pool.cancel()
r.poolwg.Done()
return
}
pool.bar = pkg.NewBar(u, r.Limit-r.Offset, r.Progress)
err = pool.Init()
if err != nil {
logs.Log.Error(err.Error())
if !r.Force {
// 如果没开启force, init失败将会关闭pool
pool.cancel()
r.poolwg.Done()
return
}
}
pool.Run(ctx, r.Offset, r.Limit)
r.poolwg.Done()
})
pool.Run(ctx, r.Offset, r.Limit)
r.poolwg.Done()
})
}
if err != nil {
return err
@ -111,7 +132,7 @@ Loop:
case <-ctx.Done():
logs.Log.Error("cancel with deadline")
break Loop
case u, ok := <-r.URLList:
case u, ok := <-r.URLCh:
if !ok {
break Loop
}
@ -137,6 +158,38 @@ Loop:
time.Sleep(100) // 延迟100ms, 等所有数据处理完毕
}
func (r *Runner) RunWithCheck(ctx context.Context) {
stopCh := make(chan struct{})
r.poolwg.Add(1)
err := r.Pools.Invoke(struct{}{})
if err != nil {
return
}
go func() {
r.poolwg.Wait()
stopCh <- struct{}{}
}()
Loop:
for {
select {
case <-ctx.Done():
logs.Log.Error("cancel with deadline")
break Loop
case <-stopCh:
break Loop
}
}
for {
if len(r.OutputCh) == 0 {
close(r.OutputCh)
break
}
}
time.Sleep(100) // 延迟100ms, 等所有数据处理完毕
}
func (r *Runner) Outputting() {
go func() {
var outFunc func(*pkg.Baseline)

View File

@ -34,7 +34,8 @@ func NewClient(thread int, timeout int, clientType int) *Client {
WriteTimeout: time.Duration(timeout) * time.Second,
MaxResponseBodySize: DefaultMaxBodySize,
},
timeout: time.Duration(timeout) * time.Second,
timeout: time.Duration(timeout) * time.Second,
clientType: clientType,
}
} else {
return &Client{
@ -52,7 +53,8 @@ func NewClient(thread int, timeout int, clientType int) *Client {
Timeout: time.Second * time.Duration(timeout),
CheckRedirect: checkRedirect,
},
timeout: time.Duration(timeout) * time.Second,
timeout: time.Duration(timeout) * time.Second,
clientType: clientType,
}
}
}
@ -60,9 +62,18 @@ func NewClient(thread int, timeout int, clientType int) *Client {
type Client struct {
fastClient *fasthttp.Client
standardClient *http.Client
clientType int
timeout time.Duration
}
func (c *Client) TransToCheck() {
if c.fastClient != nil {
c.fastClient.MaxConnsPerHost = 1
} else if c.standardClient != nil {
}
}
func (c *Client) FastDo(ctx context.Context, req *fasthttp.Request) (*fasthttp.Response, error) {
resp := fasthttp.AcquireResponse()
return resp, c.fastClient.Do(req, resp)

View File

@ -5,12 +5,48 @@ import (
"net/http"
)
func BuildPathRequest(clientType int, base, path string) (*Request, error) {
if clientType == FAST {
req := fasthttp.AcquireRequest()
req.SetRequestURI(base + path)
return &Request{FastRequest: req, ClientType: FAST}, nil
} else {
req, err := http.NewRequest("GET", base+path, nil)
return &Request{StandardRequest: req, ClientType: STANDARD}, err
}
}
func BuildHostRequest(clientType int, base, host string) (*Request, error) {
if clientType == FAST {
req := fasthttp.AcquireRequest()
req.SetRequestURI(base)
req.SetHost(host)
return &Request{FastRequest: req, ClientType: FAST}, nil
} else {
req, err := http.NewRequest("GET", base, nil)
req.Host = host
return &Request{StandardRequest: req, ClientType: STANDARD}, err
}
}
type Request struct {
StandardRequest *http.Request
FastRequest *fasthttp.Request
ClientType int
}
func (r *Request) SetHeader(header map[string]string) {
if r.StandardRequest != nil {
for k, v := range header {
r.StandardRequest.Header.Set(k, v)
}
} else if r.FastRequest != nil {
for k, v := range header {
r.FastRequest.Header.Set(k, v)
}
}
}
func (r *Request) URI() string {
if r.FastRequest != nil {
return r.FastRequest.URI().String()