修复在优化stat时一个线程安全问题导致程序阻塞的bug

This commit is contained in:
M09Ic 2023-01-29 18:23:55 +08:00
parent 8152ae1b1d
commit 940c5b9e99
3 changed files with 24 additions and 24 deletions

View File

@ -49,8 +49,8 @@ func NewPool(ctx context.Context, config *pkg.Config) (*Pool, error) {
client: ihttp.NewClient(config.Thread, 2, config.ClientType),
baselines: make(map[int]*pkg.Baseline),
urls: make(map[string]struct{}),
tempCh: make(chan *pkg.Baseline, config.Thread),
checkCh: make(chan int),
tempCh: make(chan *pkg.Baseline, 100),
checkCh: make(chan int, 100),
additionCh: make(chan *Unit, 100),
closeCh: make(chan struct{}),
waiter: sync.WaitGroup{},
@ -77,7 +77,6 @@ func NewPool(ctx context.Context, config *pkg.Config) (*Pool, error) {
if bl.IsValid {
pool.addFuzzyBaseline(bl)
}
if _, ok := pool.Statistor.Counts[bl.Status]; ok {
pool.Statistor.Counts[bl.Status]++
} else {
@ -140,6 +139,7 @@ func NewPool(ctx context.Context, config *pkg.Config) (*Pool, error) {
}
}
if !pool.closed {
// 如果任务被取消, 所有还没处理的请求结果都会被丢弃
pool.OutputCh <- bl
}
pool.waiter.Done()
@ -168,7 +168,7 @@ type Pool struct {
closeCh chan struct{}
closed bool
wordOffset int
failedCount int
failedCount int32
isFailed bool
failedBaselines []*pkg.Baseline
random *pkg.Baseline
@ -244,24 +244,22 @@ func (pool *Pool) Run(ctx context.Context, offset, limit int) {
pool.waiter.Wait()
close(pool.closeCh)
return
} else {
time.Sleep(100)
}
time.Sleep(100 * time.Millisecond)
}
}()
Loop:
for {
select {
case u, ok := <-pool.worder.C:
case w, ok := <-pool.worder.C:
if !ok {
done = true
continue
}
pool.Statistor.End++
pool.wordOffset++
if pool.wordOffset <= offset {
if pool.wordOffset < offset {
continue
}
@ -271,8 +269,8 @@ Loop:
}
pool.waiter.Add(1)
pool.urls[u] = struct{}{}
pool.reqPool.Invoke(newUnitWithNumber(pool.safePath(u), WordSource, pool.wordOffset)) // 原样的目录拼接, 输入了几个"/"就是几个, 适配java的目录解析
pool.urls[w] = struct{}{}
pool.reqPool.Invoke(newUnitWithNumber(pool.safePath(w), WordSource, pool.wordOffset)) // 原样的目录拼接, 输入了几个"/"就是几个, 适配java的目录解析
case source := <-pool.checkCh:
pool.Statistor.CheckNumber++
if pool.Mod == pkg.HostSpray {
@ -327,7 +325,7 @@ func (pool *Pool) Invoke(v interface{}) {
// compare与各种错误处理
var bl *pkg.Baseline
if reqerr != nil && reqerr != fasthttp.ErrBodyTooLarge {
pool.failedCount++
atomic.AddInt32(&pool.failedCount, 1)
atomic.AddInt32(&pool.Statistor.FailedNumber, 1)
bl = &pkg.Baseline{UrlString: pool.base + unit.path, IsValid: false, ErrString: reqerr.Error(), Reason: ErrRequestFailed.Error()}
pool.failedBaselines = append(pool.failedBaselines, bl)
@ -387,7 +385,7 @@ func (pool *Pool) Invoke(v interface{}) {
logs.Log.Warn("[check.fuzzy] maybe trigger risk control, " + bl.String())
}
} else {
pool.failedCount += 2
atomic.AddInt32(&pool.failedCount, 1) //
logs.Log.Warn("[check.failed] maybe trigger risk control, " + bl.String())
pool.failedBaselines = append(pool.failedBaselines, bl)
}
@ -399,10 +397,10 @@ func (pool *Pool) Invoke(v interface{}) {
case WordSource:
// 异步进行性能消耗较大的深度对比
pool.tempCh <- bl
if pool.wordOffset%pool.CheckPeriod == 0 {
if int(pool.Statistor.ReqTotal)%pool.CheckPeriod == 0 {
pool.doCheck()
} else if pool.failedCount%pool.ErrPeriod == 0 {
pool.failedCount++
atomic.AddInt32(&pool.failedCount, 1)
pool.doCheck()
}
pool.bar.Done()

View File

@ -91,8 +91,8 @@ func (r *Runner) PrepareConfig() *pkg.Config {
FuzzyCh: r.FuzzyCh,
Fuzzy: r.Fuzzy,
CheckPeriod: r.CheckPeriod,
ErrPeriod: r.ErrPeriod,
BreakThreshold: r.BreakThreshold,
ErrPeriod: int32(r.ErrPeriod),
BreakThreshold: int32(r.BreakThreshold),
MatchExpr: r.MatchExpr,
FilterExpr: r.FilterExpr,
RecuExpr: r.RecursiveExpr,
@ -251,11 +251,13 @@ Loop:
for {
select {
case <-ctx.Done():
for t := range r.taskCh {
stat := pkg.NewStatistor(t.baseUrl)
r.StatFile.SafeWrite(stat.Json())
if len(r.taskCh) > 0 {
for t := range r.taskCh {
stat := pkg.NewStatistor(t.baseUrl)
r.StatFile.SafeWrite(stat.Json())
}
}
logs.Log.Importantf("save all stat to %s", r.StatFile.Filename)
logs.Log.Importantf("already save all stat to %s", r.StatFile.Filename)
break Loop
case t, ok := <-r.taskCh:
if !ok {
@ -266,7 +268,7 @@ Loop:
}
r.poolwg.Wait()
//time.Sleep(100 * time.Millisecond) // 延迟100ms, 等所有数据处理完毕
time.Sleep(100 * time.Millisecond) // 延迟100ms, 等所有数据处理完毕
for {
if len(r.OutputCh) == 0 {
close(r.OutputCh)

View File

@ -26,8 +26,8 @@ type Config struct {
Timeout int
RateLimit int
CheckPeriod int
ErrPeriod int
BreakThreshold int
ErrPeriod int32
BreakThreshold int32
Method string
Mod SprayMod
Headers map[string]string