diff --git a/cmd/cmd.go b/cmd/cmd.go index 382ddaa..6218d13 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -1,10 +1,12 @@ package cmd import ( + "context" "fmt" "github.com/chainreactors/logs" "github.com/chainreactors/spray/internal" "github.com/jessevdk/go-flags" + "time" ) func Spray() { @@ -24,10 +26,13 @@ func Spray() { return } - err = runner.Prepare() + ctx, _ := context.WithTimeout(context.Background(), time.Duration(runner.Deadline)*time.Second) + + err = runner.Prepare(ctx) if err != nil { logs.Log.Errorf(err.Error()) return } - runner.Run() + + runner.Run(ctx) } diff --git a/internal/option.go b/internal/option.go index a309fe7..33cecd3 100644 --- a/internal/option.go +++ b/internal/option.go @@ -37,7 +37,7 @@ type InputOptions struct { } type OutputOptions struct { - Matches map[string]string `short:"m" long:"match" description:"String, "` + Matches map[string]string `long:"match" description:"String, "` Filters map[string]string `long:"filter" description:"String, "` Extracts []string `long:"extract" description:"String, "` OutputFile string `short:"f" description:"String, output filename"` @@ -74,6 +74,10 @@ func (opt *Option) PrepareRunner() (*Runner, error) { Mod: opt.Mod, Timeout: opt.Timeout, Probes: strings.Split(opt.OutputProbe, ","), + Deadline: opt.Deadline, + Offset: opt.Offset, + Limit: opt.Limit, + URLList: make(chan string), } err = pkg.LoadTemplates() @@ -91,10 +95,11 @@ func (opt *Option) PrepareRunner() (*Runner, error) { } // prepare url + var urls []string var file *os.File urlfrom := opt.URLFile if opt.URL != "" { - r.URLList = append(r.URLList, opt.URL) + urls = append(urls, opt.URL) urlfrom = "cmd" } else if opt.URLFile != "" { file, err = os.Open(opt.URLFile) @@ -111,13 +116,19 @@ func (opt *Option) PrepareRunner() (*Runner, error) { if err != nil { return nil, err } - r.URLList = strings.Split(string(content), "\n") + urls = strings.Split(string(content), "\n") } - for i, u := range r.URLList { - r.URLList[i] = strings.TrimSpace(u) + for i, u := range urls { + urls[i] = strings.TrimSpace(u) } - logs.Log.Importantf("load %d urls from %s", len(r.URLList), urlfrom) + logs.Log.Importantf("load %d urls from %s", len(urls), urlfrom) + go func() { + for _, u := range urls { + r.URLList <- u + } + close(r.URLList) + }() // prepare word dicts := make([][]string, len(opt.Dictionaries)) @@ -134,10 +145,10 @@ func (opt *Option) PrepareRunner() (*Runner, error) { for i, _ := range dicts { opt.Word += strconv.Itoa(i) } - opt.Word = "}" + opt.Word += "}" } - if opt.Suffixes == nil { + if opt.Suffixes != nil { dicts = append(dicts, opt.Suffixes) opt.Word += fmt.Sprintf("{?%d}", len(dicts)-1) } @@ -156,6 +167,9 @@ func (opt *Option) PrepareRunner() (*Runner, error) { if err != nil { return nil, err } + if r.Limit == 0 { + r.Limit = len(r.Wordlist) + } if opt.Uppercase { r.Fns = append(r.Fns, strings.ToUpper) @@ -184,7 +198,7 @@ func (opt *Option) PrepareRunner() (*Runner, error) { }) } - if opt.Replaces != nil { + if len(opt.Replaces) > 0 { r.Fns = append(r.Fns, func(s string) string { for k, v := range opt.Replaces { s = strings.Replace(s, k, v, -1) diff --git a/internal/runner.go b/internal/runner.go index 1d28afe..f99a4de 100644 --- a/internal/runner.go +++ b/internal/runner.go @@ -15,7 +15,7 @@ var BlackStatus = []int{400, 404, 410} var FuzzyStatus = []int{403, 500, 501, 502, 503} type Runner struct { - URLList []string + URLList chan string Wordlist []string Headers http.Header Fns []func(string) string @@ -30,9 +30,10 @@ type Runner struct { Progress *uiprogress.Progress Offset int Limit int + Deadline int } -func (r *Runner) Prepare() error { +func (r *Runner) Prepare(ctx context.Context) error { var err error CheckStatusCode = func(status int) bool { for _, black := range BlackStatus { @@ -44,7 +45,6 @@ func (r *Runner) Prepare() error { } r.OutputCh = make(chan *baseline, 100) - ctx := context.Background() r.Pools, err = ants.NewPoolWithFunc(r.PoolSize, func(i interface{}) { u := i.(string) @@ -75,7 +75,7 @@ func (r *Runner) Prepare() error { logs.Log.Error(err.Error()) return } - // todo pool 总超时时间 + pool.Run(ctx, r.Offset, r.Limit) r.poolwg.Done() }) @@ -87,12 +87,22 @@ func (r *Runner) Prepare() error { return nil } -func (r *Runner) Run() { - // todo pool 结束与并发控制 - for _, u := range r.URLList { - r.poolwg.Add(1) - r.Pools.Invoke(u) +func (r *Runner) Run(ctx context.Context) { +Loop: + for { + select { + case <-ctx.Done(): + logs.Log.Error("cancel with deadline") + break Loop + case u, ok := <-r.URLList: + if !ok { + break Loop + } + r.poolwg.Add(1) + r.Pools.Invoke(u) + } } + r.poolwg.Wait() for { if len(r.OutputCh) == 0 {