spray/internal/pool.go

375 lines
8.8 KiB
Go
Raw Normal View History

2022-09-08 15:57:17 +08:00
package internal
import (
"context"
"fmt"
"github.com/chainreactors/logs"
"github.com/chainreactors/spray/pkg"
"github.com/chainreactors/spray/pkg/ihttp"
2022-09-15 19:27:07 +08:00
"github.com/chainreactors/words"
2022-09-08 15:57:17 +08:00
"github.com/panjf2000/ants/v2"
"github.com/valyala/fasthttp"
"net/http"
2022-09-08 15:57:17 +08:00
"sync"
2022-09-15 19:27:07 +08:00
"time"
2022-09-08 15:57:17 +08:00
)
// Package-level hooks consulted by (*Pool).PreCompare before a response is
// turned into a full baseline.
var (
	// CheckStatusCode reports whether a response status code is worth keeping;
	// PreCompare rejects the response with ErrBadStatus when it returns false.
	CheckStatusCode func(int) bool
	// CheckRedirect receives the response's Location header and reports whether
	// the redirect is interesting; (*Pool).Init installs one when the baseline
	// itself redirects. A nil hook is skipped.
	CheckRedirect func(string) bool
	// CheckWaf reports whether a response passed the WAF check; PreCompare
	// currently calls it with a nil body. A nil hook is skipped.
	CheckWaf func([]byte) bool

	// breakThreshold is the failure count above which a check request stops
	// the whole task.
	breakThreshold = 20
)
func NewPool(ctx context.Context, config *pkg.Config) (*Pool, error) {
2022-09-19 14:42:29 +08:00
pctx, cancel := context.WithCancel(ctx)
2022-09-08 15:57:17 +08:00
pool := &Pool{
Config: config,
ctx: pctx,
2022-09-23 01:47:24 +08:00
cancel: cancel,
client: ihttp.NewClient(config.Thread, 2, config.ClientType),
worder: words.NewWorder(config.Wordlist),
baselines: make(map[int]*pkg.Baseline),
tempCh: make(chan *pkg.Baseline, config.Thread),
2022-09-23 11:20:41 +08:00
wg: sync.WaitGroup{},
initwg: sync.WaitGroup{},
checkPeriod: 100,
errPeriod: 10,
reqCount: 1,
failedCount: 1,
2022-09-08 15:57:17 +08:00
}
switch config.Mod {
case pkg.PathSpray:
pool.genReq = func(s string) (*ihttp.Request, error) {
2022-09-23 11:20:41 +08:00
return pool.buildPathRequest(s)
2022-09-08 15:57:17 +08:00
}
pool.check = func() {
pool.wg.Add(1)
_ = pool.pool.Invoke(newUnit(pkg.RandPath(), CheckSource))
if pool.failedCount > breakThreshold {
// 当报错次数超过上限是, 结束任务
pool.recover()
pool.cancel()
}
}
2022-09-08 15:57:17 +08:00
case pkg.HostSpray:
pool.genReq = func(s string) (*ihttp.Request, error) {
2022-09-23 11:20:41 +08:00
return pool.buildHostRequest(s)
2022-09-08 15:57:17 +08:00
}
pool.check = func() {
pool.wg.Add(1)
_ = pool.pool.Invoke(newUnit(pkg.RandHost(), CheckSource))
if pool.failedCount > breakThreshold {
// 当报错次数超过上限是, 结束任务
pool.recover()
pool.cancel()
}
}
2022-09-08 15:57:17 +08:00
}
p, _ := ants.NewPoolWithFunc(config.Thread, func(i interface{}) {
unit := i.(*Unit)
req, err := pool.genReq(unit.path)
if err != nil {
logs.Log.Error(err.Error())
return
}
var bl *pkg.Baseline
2022-09-23 11:20:41 +08:00
resp, reqerr := pool.client.Do(pctx, req)
if pool.ClientType == ihttp.FAST {
defer fasthttp.ReleaseResponse(resp.FastResponse)
defer fasthttp.ReleaseRequest(req.FastRequest)
}
2022-09-23 11:20:41 +08:00
if reqerr != nil && reqerr != fasthttp.ErrBodyTooLarge {
2022-10-19 16:38:23 +08:00
pool.failedCount++
bl = &pkg.Baseline{Url: pool.BaseURL + unit.path, Err: reqerr.Error(), Reason: ErrRequestFailed.Error()}
pool.failedBaselines = append(pool.failedBaselines, bl)
2022-09-08 15:57:17 +08:00
} else {
if err = pool.PreCompare(resp); unit.source == CheckSource || unit.source == InitSource || err == nil {
2022-09-08 15:57:17 +08:00
// 通过预对比跳过一些无用数据, 减少性能消耗
bl = pkg.NewBaseline(req.URI(), req.Host(), resp)
pool.addFuzzyBaseline(bl)
2022-09-08 15:57:17 +08:00
} else {
bl = pkg.NewInvalidBaseline(req.URI(), req.Host(), resp, err.Error())
2022-09-08 15:57:17 +08:00
}
}
switch unit.source {
2022-10-27 23:40:15 +08:00
case InitSource:
pool.base = bl
pool.addFuzzyBaseline(bl)
2022-10-27 23:40:15 +08:00
pool.initwg.Done()
return
2022-09-20 18:09:06 +08:00
case CheckSource:
if bl.Err != "" {
logs.Log.Warnf("[check.error] maybe ip had banned by waf, break (%d/%d), error: %s", pool.failedCount, breakThreshold, bl.Err)
pool.failedBaselines = append(pool.failedBaselines, bl)
} else if i := pool.base.Compare(bl); i < 1 {
if i == 0 {
logs.Log.Debug("[check.fuzzy] maybe trigger risk control, " + bl.String())
} else {
logs.Log.Warn("[check.failed] maybe trigger risk control, " + bl.String())
}
pool.failedBaselines = append(pool.failedBaselines, bl)
} else {
pool.resetFailed() // 如果后续访问正常, 重置错误次数
logs.Log.Debug("[check.pass] " + bl.String())
}
2022-09-08 15:57:17 +08:00
case WordSource:
// 异步进行性能消耗较大的深度对比
pool.tempCh <- bl
2022-11-11 10:37:30 +08:00
pool.reqCount++
2022-09-23 11:20:41 +08:00
if pool.reqCount%pool.checkPeriod == 0 {
pool.reqCount++
2022-09-23 11:20:41 +08:00
go pool.check()
2022-11-09 17:28:51 +08:00
} else if pool.failedCount%pool.errPeriod == 0 {
pool.failedCount++
2022-09-23 11:20:41 +08:00
go pool.check()
}
pool.bar.Done()
2022-09-08 15:57:17 +08:00
}
2022-10-19 16:38:23 +08:00
2022-09-08 15:57:17 +08:00
pool.wg.Done()
})
pool.pool = p
go func() {
for bl := range pool.tempCh {
if pool.customCompare != nil {
if pool.customCompare(bl) {
pool.OutputCh <- bl
}
} else {
pool.BaseCompare(bl)
}
}
pool.analyzeDone = true
}()
2022-09-08 15:57:17 +08:00
return pool, nil
}
type Pool struct {
*pkg.Config
client *ihttp.Client
pool *ants.PoolWithFunc
bar *pkg.Bar
ctx context.Context
cancel context.CancelFunc
tempCh chan *pkg.Baseline // 待处理的baseline
reqCount int
failedCount int
checkPeriod int
errPeriod int
failedBaselines []*pkg.Baseline
base *pkg.Baseline
baselines map[int]*pkg.Baseline
analyzeDone bool
genReq func(s string) (*ihttp.Request, error)
check func()
customCompare func(*pkg.Baseline) bool
worder *words.Worder
wg sync.WaitGroup
initwg sync.WaitGroup // 初始化用, 之后改成锁
2022-09-08 15:57:17 +08:00
}
func (p *Pool) Init() error {
2022-09-23 11:20:41 +08:00
p.initwg.Add(1)
2022-10-28 00:46:54 +08:00
p.pool.Invoke(newUnit(pkg.RandPath(), InitSource))
2022-09-23 11:20:41 +08:00
p.initwg.Wait()
2022-09-08 15:57:17 +08:00
// todo 分析baseline
// 检测基本访问能力
if p.base.Err != "" {
2022-09-23 01:47:24 +08:00
p.cancel()
return fmt.Errorf(p.base.String())
2022-09-08 15:57:17 +08:00
}
2022-09-23 11:20:41 +08:00
p.base.Collect()
logs.Log.Important("[baseline.init] " + p.base.String())
2022-09-23 11:20:41 +08:00
if p.base.RedirectURL != "" {
CheckRedirect = func(redirectURL string) bool {
2022-09-23 11:20:41 +08:00
if redirectURL == p.base.RedirectURL {
2022-09-08 15:57:17 +08:00
// 相同的RedirectURL将被认为是无效数据
return false
} else {
// path为3xx, 且与baseline中的RedirectURL不同时, 为有效数据
return true
2022-09-08 15:57:17 +08:00
}
}
}
return nil
}
2022-11-10 15:48:38 +08:00
func (p *Pool) Run(ctx context.Context, offset, limit int) {
2022-09-15 19:27:07 +08:00
Loop:
for {
select {
case u, ok := <-p.worder.C:
if !ok {
break Loop
}
2022-11-10 15:48:38 +08:00
if p.reqCount < offset {
p.reqCount++
continue
}
if p.reqCount > limit {
2022-11-10 15:48:38 +08:00
break Loop
}
for _, fn := range p.Fns {
u = fn(u)
}
2022-11-10 04:48:07 +08:00
if u == "" {
continue
}
p.wg.Add(1)
_ = p.pool.Invoke(newUnit(u, WordSource))
2022-09-15 19:27:07 +08:00
case <-ctx.Done():
break Loop
2022-09-19 14:42:29 +08:00
case <-p.ctx.Done():
break Loop
2022-09-15 19:27:07 +08:00
}
2022-09-08 15:57:17 +08:00
}
2022-09-23 11:20:41 +08:00
p.Close()
2022-09-08 15:57:17 +08:00
}
func (p *Pool) PreCompare(resp *ihttp.Response) error {
if p.base != nil && p.base.Status != 200 && p.base.Status == resp.StatusCode() {
return ErrSameStatus
}
if !CheckStatusCode(resp.StatusCode()) {
2022-09-15 19:27:07 +08:00
return ErrBadStatus
2022-09-08 15:57:17 +08:00
}
if CheckRedirect != nil && !CheckRedirect(string(resp.GetHeader("Location"))) {
2022-09-15 19:27:07 +08:00
return ErrRedirect
2022-09-08 15:57:17 +08:00
}
2022-09-26 17:19:08 +08:00
if CheckWaf != nil && !CheckWaf(nil) {
// todo check waf
return ErrWaf
}
2022-09-08 15:57:17 +08:00
2022-09-15 19:27:07 +08:00
return nil
2022-09-08 15:57:17 +08:00
}
func (p *Pool) BaseCompare(bl *pkg.Baseline) {
if !bl.IsValid {
// precompare 确认无效数据直接送入管道
p.OutputCh <- bl
return
}
2022-11-11 11:55:49 +08:00
var status = -1
base, ok := p.baselines[bl.Status] // 挑选对应状态码的baseline进行compare
if !ok && p.base.Status == bl.Status {
// 当other的状态码与base相同时, 会使用base
ok = true
base = p.base
}
2022-11-11 11:55:49 +08:00
if ok {
if status = base.Compare(bl); status == 1 {
p.PutToInvalid(bl, "compare failed")
return
}
}
2022-11-11 11:55:49 +08:00
if status == 0 {
bl.Collect()
for _, f := range bl.Frameworks {
if f.Tag == "waf/cdn" {
p.PutToInvalid(bl, "waf")
return
}
}
if ok && base.FuzzyCompare(bl) {
p.PutToInvalid(bl, "fuzzy compare failed")
p.PutToFuzzy(bl)
return
}
}
2022-09-23 11:20:41 +08:00
p.OutputCh <- bl
}
2022-09-26 17:19:08 +08:00
// addFuzzyBaseline records bl as the reference baseline for its status code.
// This happens only when no baseline for that code exists yet and the code is
// listed in FuzzyStatus. A newly recorded baseline is collected and logged.
func (p *Pool) addFuzzyBaseline(bl *pkg.Baseline) {
	if _, seen := p.baselines[bl.Status]; seen {
		return
	}
	if !IntsContains(FuzzyStatus, bl.Status) {
		return
	}
	bl.Collect()
	p.baselines[bl.Status] = bl
	logs.Log.Importantf("[baseline.%dinit] %s", bl.Status, bl.String())
}
2022-11-10 21:18:26 +08:00
func (p *Pool) PutToInvalid(bl *pkg.Baseline, reason string) {
bl.IsValid = false
bl.Reason = reason
p.OutputCh <- bl
2022-11-10 21:18:26 +08:00
}
func (p *Pool) PutToFuzzy(bl *pkg.Baseline) {
bl.IsFuzzy = true
p.FuzzyCh <- bl
2022-11-10 21:18:26 +08:00
}
func (p *Pool) resetFailed() {
2022-11-11 10:37:30 +08:00
p.failedCount = 1
p.failedBaselines = nil
}
func (p *Pool) recover() {
logs.Log.Errorf("failed request exceeds the threshold , task will exit. Breakpoint %d", p.reqCount)
logs.Log.Error("collecting failed check")
for i, bl := range p.failedBaselines {
logs.Log.Errorf("[failed.%d] %s", i, bl.String())
}
}
2022-09-23 11:20:41 +08:00
func (p *Pool) Close() {
p.wg.Wait()
p.bar.Close()
close(p.tempCh)
for !p.analyzeDone {
time.Sleep(time.Duration(100) * time.Millisecond)
}
}
2022-09-26 17:19:08 +08:00
func (p *Pool) buildPathRequest(path string) (*ihttp.Request, error) {
if p.Config.ClientType == ihttp.FAST {
req := fasthttp.AcquireRequest()
req.SetRequestURI(p.BaseURL + path)
2022-10-27 23:40:15 +08:00
return &ihttp.Request{FastRequest: req, ClientType: p.ClientType}, nil
} else {
req, err := http.NewRequest("GET", p.BaseURL+path, nil)
2022-10-27 23:40:15 +08:00
return &ihttp.Request{StandardRequest: req, ClientType: p.ClientType}, err
}
}
func (p *Pool) buildHostRequest(host string) (*ihttp.Request, error) {
if p.Config.ClientType == ihttp.FAST {
req := fasthttp.AcquireRequest()
req.SetRequestURI(p.BaseURL)
req.SetHost(host)
2022-10-27 23:40:15 +08:00
return &ihttp.Request{FastRequest: req, ClientType: p.ClientType}, nil
} else {
req, err := http.NewRequest("GET", p.BaseURL, nil)
req.Host = host
2022-10-27 23:40:15 +08:00
return &ihttp.Request{StandardRequest: req, ClientType: p.ClientType}, err
}
}