diff --git a/go.mod b/go.mod index 93f637f..72cd3a0 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require ( github.com/chainreactors/go-metrics v0.0.0-20220926021830-24787b7a10f8 github.com/chainreactors/gogo/v2 v2.11.1-0.20230327070928-b5ff67ac46c7 github.com/chainreactors/logs v0.7.1-0.20230316032643-ed7d85ca234f - github.com/chainreactors/parsers v0.3.1-0.20230327070646-7dbe644d2b3b + github.com/chainreactors/parsers v0.3.1-0.20230403160559-9ed502452575 github.com/chainreactors/words v0.4.1-0.20230327065326-448a905ac8c2 ) diff --git a/go.sum b/go.sum index 1e46284..4343193 100644 --- a/go.sum +++ b/go.sum @@ -24,6 +24,8 @@ github.com/chainreactors/parsers v0.3.0/go.mod h1:Z9weht+lnFCk7UcwqFu6lXpS7u5vtt github.com/chainreactors/parsers v0.3.1-0.20230313041950-25d5f9059c79/go.mod h1:tA33N6UbYFnIT3k5tufOMfETxmEP20RZFyTSEnVXNUA= github.com/chainreactors/parsers v0.3.1-0.20230327070646-7dbe644d2b3b h1:EubRBdVAj9COEmfkCB2yseejcAhDgndRipp/zDzJ0FU= github.com/chainreactors/parsers v0.3.1-0.20230327070646-7dbe644d2b3b/go.mod h1:tA33N6UbYFnIT3k5tufOMfETxmEP20RZFyTSEnVXNUA= +github.com/chainreactors/parsers v0.3.1-0.20230403160559-9ed502452575 h1:uHE9O8x70FXwge5p68U/lGC9Xs8Leg8hWJR9PHKGzsk= +github.com/chainreactors/parsers v0.3.1-0.20230403160559-9ed502452575/go.mod h1:tA33N6UbYFnIT3k5tufOMfETxmEP20RZFyTSEnVXNUA= github.com/chainreactors/utils v0.0.14-0.20230314084720-a4d745cabc56 h1:1uhvEh7Of4fQJXRMsfGEZGy5NcETsM2yataQ0oYSw0k= github.com/chainreactors/utils v0.0.14-0.20230314084720-a4d745cabc56/go.mod h1:NKSu1V6EC4wa8QHtPfiJHlH9VjGfUQOx5HADK0xry3Y= github.com/chainreactors/words v0.4.1-0.20230327065326-448a905ac8c2 h1:/v8gTORQIRJl2lgNt82OOeP/04QZyNTGKcmjfstVN5E= diff --git a/internal/checkpool.go b/internal/checkpool.go index 8f90fdb..acacb5d 100644 --- a/internal/checkpool.go +++ b/internal/checkpool.go @@ -10,10 +10,13 @@ import ( "github.com/chainreactors/words" "github.com/panjf2000/ants/v2" "github.com/valyala/fasthttp" + "net/url" + "strings" "sync" "time" ) +// 类似httpx的无状态, 无scope, 无并发池的检测模式 func NewCheckPool(ctx context.Context, config *pkg.Config) (*CheckPool, error) { pctx, cancel := context.WithCancel(ctx) pool := &CheckPool{ @@ -22,47 +25,13 @@ func NewCheckPool(ctx context.Context, config *pkg.Config) (*CheckPool, error) { cancel: cancel, client: ihttp.NewClient(config.Thread, 2, config.ClientType), wg: sync.WaitGroup{}, + additionCh: make(chan *Unit, 100), + closeCh: make(chan struct{}), reqCount: 1, failedCount: 1, } - p, _ := ants.NewPoolWithFunc(config.Thread, func(i interface{}) { - unit := i.(*Unit) - req, err := pool.genReq(unit.path) - if err != nil { - logs.Log.Error(err.Error()) - } - - start := time.Now() - var bl *pkg.Baseline - resp, reqerr := pool.client.Do(pctx, req) - if pool.ClientType == ihttp.FAST { - defer fasthttp.ReleaseResponse(resp.FastResponse) - defer fasthttp.ReleaseRequest(req.FastRequest) - } - - if reqerr != nil && reqerr != fasthttp.ErrBodyTooLarge { - pool.failedCount++ - - bl = &pkg.Baseline{ - SprayResult: &parsers.SprayResult{ - UrlString: pool.BaseURL + unit.path, - IsValid: false, - ErrString: reqerr.Error(), - Reason: ErrRequestFailed.Error(), - }, - } - } else { - bl = pkg.NewBaseline(req.URI(), req.Host(), resp) - bl.Collect() - } - bl.Source = unit.source - bl.Spended = time.Since(start).Milliseconds() - pool.OutputCh <- bl - pool.reqCount++ - pool.wg.Done() - pool.bar.Done() - }) + p, _ := ants.NewPoolWithFunc(config.Thread, pool.Invoke) pool.pool = p return pool, nil @@ -77,50 +46,174 @@ type CheckPool struct { cancel context.CancelFunc reqCount int failedCount int + additionCh chan *Unit + closeCh chan struct{} worder *words.Worder wg sync.WaitGroup } -func (p *CheckPool) Close() { - p.bar.Close() +func (pool *CheckPool) Close() { + pool.bar.Close() } -func (p *CheckPool) genReq(s string) (*ihttp.Request, error) { - if p.Mod == pkg.HostSpray { - return ihttp.BuildHostRequest(p.ClientType, p.BaseURL, s) - } else if p.Mod == pkg.PathSpray { - return ihttp.BuildPathRequest(p.ClientType, p.BaseURL, s) +func (pool *CheckPool) genReq(s string) (*ihttp.Request, error) { + if pool.Mod == pkg.HostSpray { + return ihttp.BuildHostRequest(pool.ClientType, pool.BaseURL, s) + } else if pool.Mod == pkg.PathSpray { + return ihttp.BuildPathRequest(pool.ClientType, pool.BaseURL, s) } return nil, fmt.Errorf("unknown mod") } -func (p *CheckPool) Run(ctx context.Context, offset, limit int) { - p.worder.Run() +func (pool *CheckPool) Run(ctx context.Context, offset, limit int) { + pool.worder.Run() + + var done bool + // 挂起一个监控goroutine, 每100ms判断一次done, 如果已经done, 则关闭closeCh, 然后通过Loop中的select case closeCh去break, 实现退出 + go func() { + for { + if done { + pool.wg.Wait() + close(pool.closeCh) + return + } + time.Sleep(100 * time.Millisecond) + } + }() + Loop: for { select { - case u, ok := <-p.worder.C: + case u, ok := <-pool.worder.C: if !ok { - break Loop - } - - if p.reqCount < offset { - p.reqCount++ + done = true continue } - if p.reqCount > limit { - break Loop + if pool.reqCount < offset { + pool.reqCount++ + continue } - p.wg.Add(1) - _ = p.pool.Invoke(newUnit(u, CheckSource)) + if pool.reqCount > limit { + continue + } + + pool.wg.Add(1) + _ = pool.pool.Invoke(newUnit(u, CheckSource)) + case u, ok := <-pool.additionCh: + if !ok { + continue + } + _ = pool.pool.Invoke(u) + case <-pool.closeCh: + break Loop case <-ctx.Done(): break Loop - case <-p.ctx.Done(): + case <-pool.ctx.Done(): break Loop } } - p.wg.Wait() - p.Close() + + pool.Close() +} + +func (pool *CheckPool) Invoke(v interface{}) { + unit := v.(*Unit) + req, err := pool.genReq(unit.path) + if err != nil { + logs.Log.Error(err.Error()) + } + + start := time.Now() + var bl *pkg.Baseline + resp, reqerr := pool.client.Do(pool.ctx, req) + if pool.ClientType == ihttp.FAST { + defer fasthttp.ReleaseResponse(resp.FastResponse) + defer fasthttp.ReleaseRequest(req.FastRequest) + } + + if reqerr != nil && reqerr != fasthttp.ErrBodyTooLarge { + pool.failedCount++ + bl = &pkg.Baseline{ + SprayResult: &parsers.SprayResult{ + UrlString: unit.path, + IsValid: false, + ErrString: reqerr.Error(), + Reason: ErrRequestFailed.Error(), + ReqDepth: unit.depth, + }, + } + pool.doUpgrade(bl) + } else { + bl = pkg.NewBaseline(req.URI(), req.Host(), resp) + bl.Collect() + } + bl.ReqDepth = unit.depth + bl.Source = unit.source + bl.Spended = time.Since(start).Milliseconds() + + // 手动处理重定向 + if bl.IsValid { + if bl.RedirectURL != "" { + pool.doRedirect(bl, unit.depth) + pool.FuzzyCh <- bl + } else if bl.Status == 400 { + pool.doUpgrade(bl) + pool.FuzzyCh <- bl + } else { + pool.OutputCh <- bl + } + } + + pool.reqCount++ + pool.wg.Done() + pool.bar.Done() +} + +func (pool *CheckPool) doRedirect(bl *pkg.Baseline, depth int) { + if depth >= MaxRedirect { + return + } + var reURL string + if strings.HasPrefix(bl.RedirectURL, "http") { + _, err := url.Parse(bl.RedirectURL) + if err != nil { + return + } + reURL = bl.RedirectURL + } else { + reURL = bl.BaseURL() + FormatURL(bl.BaseURL(), bl.RedirectURL) + } + + pool.wg.Add(1) + go func() { + pool.additionCh <- &Unit{ + path: reURL, + source: RedirectSource, + frontUrl: bl.UrlString, + depth: depth + 1, + } + }() +} + +// tcp与400进行协议转换 +func (pool *CheckPool) doUpgrade(bl *pkg.Baseline) { + if bl.ReqDepth >= 1 { + return + } + pool.wg.Add(1) + var reurl string + if strings.HasPrefix(bl.UrlString, "https") { + reurl = strings.Replace(bl.UrlString, "https", "http", 1) + } else { + reurl = strings.Replace(bl.UrlString, "http", "https", 1) + } + go func() { + pool.additionCh <- &Unit{ + path: reurl, + source: UpgradeSource, + depth: bl.ReqDepth + 1, + } + }() } diff --git a/internal/option.go b/internal/option.go index de601b4..6440b42 100644 --- a/internal/option.go +++ b/internal/option.go @@ -536,6 +536,7 @@ func (opt *Option) PrepareRunner() (*Runner, error) { r.Probes = strings.Split(opt.OutputProbe, ",") } + // init output file if opt.OutputFile != "" { r.OutputFile, err = files.NewFile(opt.OutputFile, false, false, true) if err != nil { diff --git a/internal/pool.go b/internal/pool.go index 09446a2..f8015d6 100644 --- a/internal/pool.go +++ b/internal/pool.go @@ -169,6 +169,7 @@ func (pool *Pool) Run(ctx context.Context, offset, limit int) { } var done bool + // 挂起一个监控goroutine, 每100ms判断一次done, 如果已经done, 则关闭closeCh, 然后通过Loop中的select case closeCh去break, 实现退出 go func() { for { if done { @@ -300,7 +301,7 @@ func (pool *Pool) Invoke(v interface{}) { // 手动处理重定向 if bl.IsValid && unit.source != CheckSource && bl.RedirectURL != "" { - pool.waiter.Add(1) + //pool.waiter.Add(1) pool.doRedirect(bl, unit.depth) } @@ -590,19 +591,21 @@ func (pool *Pool) Upgrade(bl *pkg.Baseline) error { } func (pool *Pool) doRedirect(bl *pkg.Baseline, depth int) { - defer pool.waiter.Done() if depth >= MaxRedirect { return } reURL := FormatURL(bl.Url.Path, bl.RedirectURL) pool.waiter.Add(1) - go pool.addAddition(&Unit{ - path: reURL, - source: RedirectSource, - frontUrl: bl.UrlString, - depth: depth + 1, - }) + go func() { + defer pool.waiter.Done() + pool.addAddition(&Unit{ + path: reURL, + source: RedirectSource, + frontUrl: bl.UrlString, + depth: depth + 1, + }) + }() } func (pool *Pool) doCrawl(bl *pkg.Baseline) { diff --git a/internal/runner.go b/internal/runner.go index 8dcf928..cf70ecc 100644 --- a/internal/runner.go +++ b/internal/runner.go @@ -283,14 +283,12 @@ Loop: time.Sleep(100 * time.Millisecond) // 延迟100ms, 等所有数据处理完毕 for { if len(r.OutputCh) == 0 { - close(r.OutputCh) break } } for { if len(r.FuzzyCh) == 0 { - close(r.FuzzyCh) break } } @@ -322,7 +320,6 @@ Loop: for { if len(r.OutputCh) == 0 { - close(r.OutputCh) break } } diff --git a/internal/types.go b/internal/types.go index ea49baf..b54748b 100644 --- a/internal/types.go +++ b/internal/types.go @@ -53,6 +53,7 @@ const ( RuleSource BakSource CommonFileSource + UpgradeSource ) func newUnit(path string, source int) *Unit { diff --git a/internal/utils.go b/internal/utils.go index a650b56..7626cc9 100644 --- a/internal/utils.go +++ b/internal/utils.go @@ -224,19 +224,12 @@ func FormatURL(base, u string) string { if err != nil { return "" } - if len(parsed.Path) <= 1 { - return "" - } return parsed.Path } else if strings.HasPrefix(u, "//") { parsed, err := url.Parse(u) if err != nil { return "" } - if len(parsed.Path) <= 1 { - // 跳过"/"与空目录 - return "" - } return parsed.Path } else if strings.HasPrefix(u, "/") { // 绝对目录拼接 diff --git a/pkg/baseline.go b/pkg/baseline.go index ec76124..93a2138 100644 --- a/pkg/baseline.go +++ b/pkg/baseline.go @@ -117,6 +117,10 @@ func (bl *Baseline) IsDir() bool { return false } +func (bl *Baseline) BaseURL() string { + return bl.Url.Scheme + "://" + bl.Url.Host +} + // Collect 深度收集信息 func (bl *Baseline) Collect() { if bl.ContentType == "html" || bl.ContentType == "json" || bl.ContentType == "txt" {