From 411f24d94dbe60f0c940a040ff9fb7b333d8ed31 Mon Sep 17 00:00:00 2001 From: M09Ic Date: Mon, 12 Aug 2024 15:12:43 +0800 Subject: [PATCH] fix the bug of thread pool hanging --- internal/option.go | 2 +- internal/pool/brutepool.go | 6 +++--- internal/pool/checkpool.go | 30 ++++++++++++++++++++++-------- internal/pool/pool.go | 2 +- internal/runner.go | 2 +- 5 files changed, 28 insertions(+), 14 deletions(-) diff --git a/internal/option.go b/internal/option.go index e1d0b32..5783724 100644 --- a/internal/option.go +++ b/internal/option.go @@ -349,7 +349,7 @@ func (opt *Option) NewRunner() (*Runner, error) { } opt.PrintPlugin() - if r.IsCheck == false { + if r.IsCheck == true { logs.Log.Important("enabling brute mod, because of enabled brute plugin") } diff --git a/internal/pool/brutepool.go b/internal/pool/brutepool.go index 146e629..3484f16 100644 --- a/internal/pool/brutepool.go +++ b/internal/pool/brutepool.go @@ -51,7 +51,7 @@ func NewBrutePool(ctx context.Context, config *Config) (*BrutePool, error) { additionCh: make(chan *Unit, config.Thread), closeCh: make(chan struct{}), processCh: make(chan *pkg.Baseline, config.Thread), - wg: sync.WaitGroup{}, + wg: &sync.WaitGroup{}, }, base: u.Scheme + "://" + u.Host, isDir: strings.HasSuffix(u.Path, "/"), @@ -196,7 +196,7 @@ func (pool *BrutePool) Upgrade(bl *pkg.Baseline) error { return nil } -func (pool *BrutePool) Run(offset, limit int) { +func (pool *BrutePool) Run(ctx context.Context, offset, limit int) { pool.Worder.Run() if pool.Active { pool.wg.Add(1) @@ -279,7 +279,7 @@ Loop: } case <-pool.closeCh: break Loop - case <-pool.ctx.Done(): + case <-ctx.Done(): break Loop case <-pool.ctx.Done(): break Loop diff --git a/internal/pool/checkpool.go b/internal/pool/checkpool.go index 45bc087..70f333f 100644 --- a/internal/pool/checkpool.go +++ b/internal/pool/checkpool.go @@ -29,7 +29,7 @@ func NewCheckPool(ctx context.Context, config *Config) (*CheckPool, error) { Timeout: time.Duration(config.Timeout) * time.Second, ProxyAddr: config.ProxyAddr, }), - wg: sync.WaitGroup{}, + wg: &sync.WaitGroup{}, additionCh: make(chan *Unit, 1024), closeCh: make(chan struct{}), processCh: make(chan *pkg.Baseline, config.Thread), @@ -50,21 +50,35 @@ type CheckPool struct { func (pool *CheckPool) Run(ctx context.Context, offset, limit int) { pool.Worder.Run() + var done bool + // 挂起一个监控goroutine, 每100ms判断一次done, 如果已经done, 则关闭closeCh, 然后通过Loop中的select case closeCh去break, 实现退出 + go func() { + for { + if done { + pool.wg.Wait() + close(pool.closeCh) + return + } + time.Sleep(100 * time.Millisecond) + } + }() + Loop: for { select { case u, ok := <-pool.Worder.C: if !ok { - break Loop + done = true + continue } if pool.reqCount < offset { pool.reqCount++ - break Loop + continue } if pool.reqCount > limit { - break Loop + continue } pool.wg.Add(1) @@ -82,7 +96,7 @@ Loop: break Loop } } - pool.wg.Wait() + pool.Close() } @@ -128,6 +142,9 @@ func (pool *CheckPool) Invoke(v interface{}) { } else { bl = pkg.NewBaseline(req.URI(), req.Host(), resp) bl.Collect() + if bl.Status == 400 { + pool.doUpgrade(bl) + } } bl.ReqDepth = unit.depth bl.Source = unit.source @@ -141,9 +158,6 @@ func (pool *CheckPool) Handler() { if bl.RedirectURL != "" { pool.doRedirect(bl, bl.ReqDepth) pool.putToOutput(bl) - } else if bl.Status == 400 { - pool.doUpgrade(bl) - pool.putToOutput(bl) } else { params := map[string]interface{}{ "current": bl, diff --git a/internal/pool/pool.go b/internal/pool/pool.go index b62e70c..3ba1ddf 100644 --- a/internal/pool/pool.go +++ b/internal/pool/pool.go @@ -29,7 +29,7 @@ type BasePool struct { failedCount int additionCh chan *Unit closeCh chan struct{} - wg sync.WaitGroup + wg *sync.WaitGroup } func (pool *BasePool) doRedirect(bl *pkg.Baseline, depth int) { diff --git a/internal/runner.go b/internal/runner.go index aeeae11..a334a40 100644 --- a/internal/runner.go +++ b/internal/runner.go @@ -207,7 +207,7 @@ func (r *Runner) Prepare(ctx context.Context) error { } } - brutePool.Run(brutePool.Statistor.Offset, limit) + brutePool.Run(ctx, brutePool.Statistor.Offset, limit) if brutePool.IsFailed && len(brutePool.FailedBaselines) > 0 { // 如果因为错误积累退出, end将指向第一个错误发生时, 防止resume时跳过大量目标