diff --git a/internal/pool.go b/internal/pool.go index 3d157c5..6e55b70 100644 --- a/internal/pool.go +++ b/internal/pool.go @@ -51,6 +51,7 @@ func NewPool(ctx context.Context, config *pkg.Config) (*Pool, error) { tempCh: make(chan *pkg.Baseline, config.Thread), checkCh: make(chan int), additionCh: make(chan *Unit, 100), + closeCh: make(chan struct{}), waiter: sync.WaitGroup{}, initwg: sync.WaitGroup{}, limiter: rate.NewLimiter(rate.Limit(config.RateLimit), 1), @@ -137,7 +138,9 @@ func NewPool(ctx context.Context, config *pkg.Config) (*Pool, error) { } } } - pool.OutputCh <- bl + if !pool.closed { + pool.OutputCh <- bl + } pool.waiter.Done() } @@ -161,6 +164,8 @@ type Pool struct { tempCh chan *pkg.Baseline // 待处理的baseline checkCh chan int // 独立的check管道, 防止与redirect/crawl冲突 additionCh chan *Unit + closeCh chan struct{} + closed bool wordOffset int failedCount int isFailed bool @@ -231,22 +236,26 @@ func (pool *Pool) Run(ctx context.Context, offset, limit int) { go pool.doCommonFile() } - closeCh := make(chan struct{}) - var worderDone bool - wait := func() { - if !worderDone { - worderDone = true - pool.waiter.Wait() - close(closeCh) + var done bool + go func() { + for { + if done { + pool.waiter.Wait() + close(pool.closeCh) + return + } else { + time.Sleep(100) + } } - } + + }() Loop: for { select { case u, ok := <-pool.worder.C: if !ok { - go wait() + done = true continue } pool.Statistor.End++ @@ -256,7 +265,7 @@ Loop: } if pool.Statistor.End > limit { - go wait() + done = true continue } @@ -271,7 +280,7 @@ Loop: pool.reqPool.Invoke(newUnitWithNumber(pool.safePath(pkg.RandPath()), source, pool.wordOffset)) } case unit, ok := <-pool.additionCh: - if !ok { + if !ok || pool.closed { continue } if _, ok := pool.urls[unit.path]; ok { @@ -282,7 +291,7 @@ Loop: unit.number = pool.wordOffset pool.reqPool.Invoke(unit) } - case <-closeCh: + case <-pool.closeCh: break Loop case <-ctx.Done(): break Loop @@ -290,7 +299,7 @@ Loop: break Loop } } - + pool.closed = true pool.Close() } @@ -665,6 +674,11 @@ func (pool *Pool) doCheck() { } func (pool *Pool) addAddition(u *Unit) { + // 强行屏蔽报错, 防止goroutine泄露 + defer func() { + if err := recover(); err != nil { + } + }() pool.additionCh <- u } @@ -702,11 +716,11 @@ func (pool *Pool) recover() { func (pool *Pool) Close() { for pool.analyzeDone { + // 等待缓存的待处理任务完成 time.Sleep(time.Duration(100) * time.Millisecond) } - - close(pool.additionCh) - close(pool.checkCh) + close(pool.additionCh) // 关闭addition管道 + close(pool.checkCh) // 关闭check管道 pool.Statistor.EndTime = time.Now().Unix() pool.bar.Close() }