mirror of
https://github.com/chainreactors/spray.git
synced 2025-06-22 02:40:41 +00:00
修复在优化stat时一个线程安全问题导致程序阻塞的bug
This commit is contained in:
parent
8152ae1b1d
commit
940c5b9e99
@ -49,8 +49,8 @@ func NewPool(ctx context.Context, config *pkg.Config) (*Pool, error) {
|
|||||||
client: ihttp.NewClient(config.Thread, 2, config.ClientType),
|
client: ihttp.NewClient(config.Thread, 2, config.ClientType),
|
||||||
baselines: make(map[int]*pkg.Baseline),
|
baselines: make(map[int]*pkg.Baseline),
|
||||||
urls: make(map[string]struct{}),
|
urls: make(map[string]struct{}),
|
||||||
tempCh: make(chan *pkg.Baseline, config.Thread),
|
tempCh: make(chan *pkg.Baseline, 100),
|
||||||
checkCh: make(chan int),
|
checkCh: make(chan int, 100),
|
||||||
additionCh: make(chan *Unit, 100),
|
additionCh: make(chan *Unit, 100),
|
||||||
closeCh: make(chan struct{}),
|
closeCh: make(chan struct{}),
|
||||||
waiter: sync.WaitGroup{},
|
waiter: sync.WaitGroup{},
|
||||||
@ -77,7 +77,6 @@ func NewPool(ctx context.Context, config *pkg.Config) (*Pool, error) {
|
|||||||
if bl.IsValid {
|
if bl.IsValid {
|
||||||
pool.addFuzzyBaseline(bl)
|
pool.addFuzzyBaseline(bl)
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, ok := pool.Statistor.Counts[bl.Status]; ok {
|
if _, ok := pool.Statistor.Counts[bl.Status]; ok {
|
||||||
pool.Statistor.Counts[bl.Status]++
|
pool.Statistor.Counts[bl.Status]++
|
||||||
} else {
|
} else {
|
||||||
@ -140,6 +139,7 @@ func NewPool(ctx context.Context, config *pkg.Config) (*Pool, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if !pool.closed {
|
if !pool.closed {
|
||||||
|
// 如果任务被取消, 所有还没处理的请求结果都会被丢弃
|
||||||
pool.OutputCh <- bl
|
pool.OutputCh <- bl
|
||||||
}
|
}
|
||||||
pool.waiter.Done()
|
pool.waiter.Done()
|
||||||
@ -168,7 +168,7 @@ type Pool struct {
|
|||||||
closeCh chan struct{}
|
closeCh chan struct{}
|
||||||
closed bool
|
closed bool
|
||||||
wordOffset int
|
wordOffset int
|
||||||
failedCount int
|
failedCount int32
|
||||||
isFailed bool
|
isFailed bool
|
||||||
failedBaselines []*pkg.Baseline
|
failedBaselines []*pkg.Baseline
|
||||||
random *pkg.Baseline
|
random *pkg.Baseline
|
||||||
@ -244,24 +244,22 @@ func (pool *Pool) Run(ctx context.Context, offset, limit int) {
|
|||||||
pool.waiter.Wait()
|
pool.waiter.Wait()
|
||||||
close(pool.closeCh)
|
close(pool.closeCh)
|
||||||
return
|
return
|
||||||
} else {
|
|
||||||
time.Sleep(100)
|
|
||||||
}
|
}
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
}
|
}
|
||||||
|
|
||||||
}()
|
}()
|
||||||
|
|
||||||
Loop:
|
Loop:
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case u, ok := <-pool.worder.C:
|
case w, ok := <-pool.worder.C:
|
||||||
if !ok {
|
if !ok {
|
||||||
done = true
|
done = true
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
pool.Statistor.End++
|
pool.Statistor.End++
|
||||||
pool.wordOffset++
|
pool.wordOffset++
|
||||||
if pool.wordOffset <= offset {
|
if pool.wordOffset < offset {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -271,8 +269,8 @@ Loop:
|
|||||||
}
|
}
|
||||||
|
|
||||||
pool.waiter.Add(1)
|
pool.waiter.Add(1)
|
||||||
pool.urls[u] = struct{}{}
|
pool.urls[w] = struct{}{}
|
||||||
pool.reqPool.Invoke(newUnitWithNumber(pool.safePath(u), WordSource, pool.wordOffset)) // 原样的目录拼接, 输入了几个"/"就是几个, 适配java的目录解析
|
pool.reqPool.Invoke(newUnitWithNumber(pool.safePath(w), WordSource, pool.wordOffset)) // 原样的目录拼接, 输入了几个"/"就是几个, 适配java的目录解析
|
||||||
case source := <-pool.checkCh:
|
case source := <-pool.checkCh:
|
||||||
pool.Statistor.CheckNumber++
|
pool.Statistor.CheckNumber++
|
||||||
if pool.Mod == pkg.HostSpray {
|
if pool.Mod == pkg.HostSpray {
|
||||||
@ -327,7 +325,7 @@ func (pool *Pool) Invoke(v interface{}) {
|
|||||||
// compare与各种错误处理
|
// compare与各种错误处理
|
||||||
var bl *pkg.Baseline
|
var bl *pkg.Baseline
|
||||||
if reqerr != nil && reqerr != fasthttp.ErrBodyTooLarge {
|
if reqerr != nil && reqerr != fasthttp.ErrBodyTooLarge {
|
||||||
pool.failedCount++
|
atomic.AddInt32(&pool.failedCount, 1)
|
||||||
atomic.AddInt32(&pool.Statistor.FailedNumber, 1)
|
atomic.AddInt32(&pool.Statistor.FailedNumber, 1)
|
||||||
bl = &pkg.Baseline{UrlString: pool.base + unit.path, IsValid: false, ErrString: reqerr.Error(), Reason: ErrRequestFailed.Error()}
|
bl = &pkg.Baseline{UrlString: pool.base + unit.path, IsValid: false, ErrString: reqerr.Error(), Reason: ErrRequestFailed.Error()}
|
||||||
pool.failedBaselines = append(pool.failedBaselines, bl)
|
pool.failedBaselines = append(pool.failedBaselines, bl)
|
||||||
@ -387,7 +385,7 @@ func (pool *Pool) Invoke(v interface{}) {
|
|||||||
logs.Log.Warn("[check.fuzzy] maybe trigger risk control, " + bl.String())
|
logs.Log.Warn("[check.fuzzy] maybe trigger risk control, " + bl.String())
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
pool.failedCount += 2
|
atomic.AddInt32(&pool.failedCount, 1) //
|
||||||
logs.Log.Warn("[check.failed] maybe trigger risk control, " + bl.String())
|
logs.Log.Warn("[check.failed] maybe trigger risk control, " + bl.String())
|
||||||
pool.failedBaselines = append(pool.failedBaselines, bl)
|
pool.failedBaselines = append(pool.failedBaselines, bl)
|
||||||
}
|
}
|
||||||
@ -399,10 +397,10 @@ func (pool *Pool) Invoke(v interface{}) {
|
|||||||
case WordSource:
|
case WordSource:
|
||||||
// 异步进行性能消耗较大的深度对比
|
// 异步进行性能消耗较大的深度对比
|
||||||
pool.tempCh <- bl
|
pool.tempCh <- bl
|
||||||
if pool.wordOffset%pool.CheckPeriod == 0 {
|
if int(pool.Statistor.ReqTotal)%pool.CheckPeriod == 0 {
|
||||||
pool.doCheck()
|
pool.doCheck()
|
||||||
} else if pool.failedCount%pool.ErrPeriod == 0 {
|
} else if pool.failedCount%pool.ErrPeriod == 0 {
|
||||||
pool.failedCount++
|
atomic.AddInt32(&pool.failedCount, 1)
|
||||||
pool.doCheck()
|
pool.doCheck()
|
||||||
}
|
}
|
||||||
pool.bar.Done()
|
pool.bar.Done()
|
||||||
|
@ -91,8 +91,8 @@ func (r *Runner) PrepareConfig() *pkg.Config {
|
|||||||
FuzzyCh: r.FuzzyCh,
|
FuzzyCh: r.FuzzyCh,
|
||||||
Fuzzy: r.Fuzzy,
|
Fuzzy: r.Fuzzy,
|
||||||
CheckPeriod: r.CheckPeriod,
|
CheckPeriod: r.CheckPeriod,
|
||||||
ErrPeriod: r.ErrPeriod,
|
ErrPeriod: int32(r.ErrPeriod),
|
||||||
BreakThreshold: r.BreakThreshold,
|
BreakThreshold: int32(r.BreakThreshold),
|
||||||
MatchExpr: r.MatchExpr,
|
MatchExpr: r.MatchExpr,
|
||||||
FilterExpr: r.FilterExpr,
|
FilterExpr: r.FilterExpr,
|
||||||
RecuExpr: r.RecursiveExpr,
|
RecuExpr: r.RecursiveExpr,
|
||||||
@ -251,11 +251,13 @@ Loop:
|
|||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
|
if len(r.taskCh) > 0 {
|
||||||
for t := range r.taskCh {
|
for t := range r.taskCh {
|
||||||
stat := pkg.NewStatistor(t.baseUrl)
|
stat := pkg.NewStatistor(t.baseUrl)
|
||||||
r.StatFile.SafeWrite(stat.Json())
|
r.StatFile.SafeWrite(stat.Json())
|
||||||
}
|
}
|
||||||
logs.Log.Importantf("save all stat to %s", r.StatFile.Filename)
|
}
|
||||||
|
logs.Log.Importantf("already save all stat to %s", r.StatFile.Filename)
|
||||||
break Loop
|
break Loop
|
||||||
case t, ok := <-r.taskCh:
|
case t, ok := <-r.taskCh:
|
||||||
if !ok {
|
if !ok {
|
||||||
@ -266,7 +268,7 @@ Loop:
|
|||||||
}
|
}
|
||||||
|
|
||||||
r.poolwg.Wait()
|
r.poolwg.Wait()
|
||||||
//time.Sleep(100 * time.Millisecond) // 延迟100ms, 等所有数据处理完毕
|
time.Sleep(100 * time.Millisecond) // 延迟100ms, 等所有数据处理完毕
|
||||||
for {
|
for {
|
||||||
if len(r.OutputCh) == 0 {
|
if len(r.OutputCh) == 0 {
|
||||||
close(r.OutputCh)
|
close(r.OutputCh)
|
||||||
|
@ -26,8 +26,8 @@ type Config struct {
|
|||||||
Timeout int
|
Timeout int
|
||||||
RateLimit int
|
RateLimit int
|
||||||
CheckPeriod int
|
CheckPeriod int
|
||||||
ErrPeriod int
|
ErrPeriod int32
|
||||||
BreakThreshold int
|
BreakThreshold int32
|
||||||
Method string
|
Method string
|
||||||
Mod SprayMod
|
Mod SprayMod
|
||||||
Headers map[string]string
|
Headers map[string]string
|
||||||
|
Loading…
x
Reference in New Issue
Block a user