实现断点续传

This commit is contained in:
M09Ic 2022-12-02 19:59:15 +08:00
parent 023e316518
commit f9c5a71258
5 changed files with 126 additions and 72 deletions

View File

@ -19,8 +19,13 @@ func Spray() {
} }
return return
} }
var runner *internal.Runner
if option.ResumeFrom != "" {
runner, err = option.PrepareRunner()
} else {
runner, err = option.PrepareRunner()
}
runner, err := option.PrepareRunner()
if err != nil { if err != nil {
logs.Log.Errorf(err.Error()) logs.Log.Errorf(err.Error())
return return

View File

@ -25,6 +25,7 @@ type Option struct {
} }
type InputOptions struct { type InputOptions struct {
ResumeFrom string `short:"r" long:"resume-from"`
URL string `short:"u" long:"url" description:"String, input baseurl (separated by commas), e.g.: http://google.com, http://baidu.com"` URL string `short:"u" long:"url" description:"String, input baseurl (separated by commas), e.g.: http://google.com, http://baidu.com"`
URLFile string `short:"l" long:"list" description:"File, input filename"` URLFile string `short:"l" long:"list" description:"File, input filename"`
Offset int `long:"offset" description:"Int, wordlist offset"` Offset int `long:"offset" description:"Int, wordlist offset"`
@ -76,6 +77,7 @@ type MiscOptions struct {
Threads int `short:"t" long:"thread" default:"20" description:"Int, number of threads per pool (seconds)"` Threads int `short:"t" long:"thread" default:"20" description:"Int, number of threads per pool (seconds)"`
Debug bool `long:"debug" description:"Bool, output debug info"` Debug bool `long:"debug" description:"Bool, output debug info"`
Quiet bool `short:"q" long:"quiet" description:"Bool, Quiet"` Quiet bool `short:"q" long:"quiet" description:"Bool, Quiet"`
NoBar bool `long:"no-bar"`
Mod string `short:"m" long:"mod" default:"path" choice:"path" choice:"host" description:"String, path/host spray"` Mod string `short:"m" long:"mod" default:"path" choice:"path" choice:"host" description:"String, path/host spray"`
Client string `short:"c" long:"client" default:"auto" choice:"fast" choice:"standard" choice:"auto" description:"String, Client type"` Client string `short:"c" long:"client" default:"auto" choice:"fast" choice:"standard" choice:"auto" description:"String, Client type"`
} }
@ -94,8 +96,8 @@ func (opt *Option) PrepareRunner() (*Runner, error) {
Timeout: opt.Timeout, Timeout: opt.Timeout,
Deadline: opt.Deadline, Deadline: opt.Deadline,
Offset: opt.Offset, Offset: opt.Offset,
Limit: opt.Limit, Total: opt.Limit,
urlCh: make(chan string), taskCh: make(chan *Task),
OutputCh: make(chan *pkg.Baseline, 100), OutputCh: make(chan *pkg.Baseline, 100),
FuzzyCh: make(chan *pkg.Baseline, 100), FuzzyCh: make(chan *pkg.Baseline, 100),
Fuzzy: opt.Fuzzy, Fuzzy: opt.Fuzzy,
@ -124,13 +126,13 @@ func (opt *Option) PrepareRunner() (*Runner, error) {
if opt.Debug { if opt.Debug {
logs.Log.Level = logs.Debug logs.Log.Level = logs.Debug
} }
if !opt.Quiet { if opt.Quiet {
r.Progress.Start()
logs.Log.Writer = r.Progress.Bypass()
} else {
logs.Log.Quiet = true logs.Log.Quiet = true
} }
if opt.Quiet || opt.NoBar {
r.Progress.Start()
logs.Log.Writer = r.Progress.Bypass()
}
if opt.SimhashDistance != 0 { if opt.SimhashDistance != 0 {
pkg.Distance = uint8(opt.SimhashDistance) pkg.Distance = uint8(opt.SimhashDistance)
} }
@ -172,34 +174,6 @@ func (opt *Option) PrepareRunner() (*Runner, error) {
} }
} }
// prepare url
var urls []string
var file *os.File
urlfrom := opt.URLFile
if opt.URL != "" {
urls = append(urls, opt.URL)
urlfrom = "cmd"
} else if opt.URLFile != "" {
file, err = os.Open(opt.URLFile)
if err != nil {
return nil, err
}
} else if pkg.HasStdin() {
file = os.Stdin
urlfrom = "stdin"
}
if file != nil {
content, err := ioutil.ReadAll(file)
if err != nil {
return nil, err
}
urls = strings.Split(strings.TrimSpace(string(content)), "\n")
}
r.URLList = urls
logs.Log.Importantf("Loaded %d urls from %s", len(urls), urlfrom)
// prepare word // prepare word
dicts := make([][]string, len(opt.Dictionaries)) dicts := make([][]string, len(opt.Dictionaries))
for i, f := range opt.Dictionaries { for i, f := range opt.Dictionaries {
@ -244,19 +218,68 @@ func (opt *Option) PrepareRunner() (*Runner, error) {
return nil, err return nil, err
} }
logs.Log.Importantf("Parsed %d words by %s", len(r.Wordlist), opt.Word) logs.Log.Importantf("Parsed %d words by %s", len(r.Wordlist), opt.Word)
pkg.DefaultStatistor.Total = len(r.Wordlist) pkg.DefaultStatistor = pkg.Statistor{
pkg.DefaultStatistor.Word = opt.Word Word: opt.Word,
pkg.DefaultStatistor.Dictionaries = opt.Dictionaries WordCount: len(r.Wordlist),
Dictionaries: opt.Dictionaries,
Offset: opt.Offset,
}
if r.Limit == 0 { r.Total = len(r.Wordlist)
if r.CheckOnly { if opt.Limit != 0 {
r.Limit = len(r.URLList) if total := r.Offset + opt.Limit; total < r.Total {
} else { r.Total = total
r.Limit = len(r.Wordlist) }
}
// prepare task
var tasks []*Task
var taskfrom string
if opt.ResumeFrom != "" {
stats, err := pkg.ReadStatistors(opt.ResumeFrom)
if err != nil {
return nil, err
}
taskfrom = "resume " + opt.ResumeFrom
for _, stat := range stats {
tasks = append(tasks, &Task{baseUrl: stat.BaseUrl, offset: stat.Offset + stat.ReqNumber, total: r.Total})
} }
} else { } else {
r.Limit = r.Offset + opt.Limit var file *os.File
var urls []string
if opt.URL != "" {
urls = append(urls, opt.URL)
tasks = append(tasks, &Task{baseUrl: opt.URL, offset: opt.Offset, total: r.Total})
taskfrom = "cmd"
} else if opt.URLFile != "" {
file, err = os.Open(opt.URLFile)
if err != nil {
return nil, err
} }
taskfrom = opt.URLFile
} else if pkg.HasStdin() {
file = os.Stdin
taskfrom = "stdin"
}
if file != nil {
content, err := ioutil.ReadAll(file)
if err != nil {
return nil, err
}
urls := strings.Split(strings.TrimSpace(string(content)), "\n")
for _, u := range urls {
tasks = append(tasks, &Task{baseUrl: strings.TrimSpace(u), offset: opt.Offset, total: r.Total})
}
}
if opt.CheckOnly {
r.URLList = urls
r.Total = len(r.URLList)
}
}
r.Tasks = tasks
logs.Log.Importantf("Loaded %d urls from %s", len(tasks), taskfrom)
if opt.Uppercase { if opt.Uppercase {
r.Fns = append(r.Fns, strings.ToUpper) r.Fns = append(r.Fns, strings.ToUpper)
@ -361,7 +384,7 @@ func loadFileToSlice(filename string) ([]string, error) {
return nil, err return nil, err
} }
ss = strings.Split(string(content), "\n") ss = strings.Split(strings.TrimSpace(string(content)), "\n")
// 统一windows与linux的回车换行差异 // 统一windows与linux的回车换行差异
for i, word := range ss { for i, word := range ss {
@ -395,3 +418,9 @@ func IntsContains(s []int, e int) bool {
} }
return false return false
} }
type Task struct {
baseUrl string
offset int
total int
}

View File

@ -73,7 +73,7 @@ func NewPool(ctx context.Context, config *pkg.Config) (*Pool, error) {
} }
p, _ := ants.NewPoolWithFunc(config.Thread, func(i interface{}) { p, _ := ants.NewPoolWithFunc(config.Thread, func(i interface{}) {
pool.Statistor.ReqNumber++ pool.Statistor.Total++
unit := i.(*Unit) unit := i.(*Unit)
req, err := pool.genReq(unit.path) req, err := pool.genReq(unit.path)
if err != nil { if err != nil {
@ -287,7 +287,6 @@ func (p *Pool) addRedirect(bl *pkg.Baseline, reCount int) {
} }
func (p *Pool) Run(ctx context.Context, offset, limit int) { func (p *Pool) Run(ctx context.Context, offset, limit int) {
p.Statistor.Offset = offset
Loop: Loop:
for { for {
select { select {

View File

@ -11,7 +11,6 @@ import (
"github.com/gosuri/uiprogress" "github.com/gosuri/uiprogress"
"github.com/panjf2000/ants/v2" "github.com/panjf2000/ants/v2"
"net/http" "net/http"
"strings"
"sync" "sync"
"time" "time"
) )
@ -24,11 +23,12 @@ var (
) )
type Runner struct { type Runner struct {
urlCh chan string taskCh chan *Task
poolwg sync.WaitGroup poolwg sync.WaitGroup
bar *uiprogress.Bar bar *uiprogress.Bar
finished int finished int
Tasks []*Task
URLList []string URLList []string
Wordlist []string Wordlist []string
Headers http.Header Headers http.Header
@ -50,7 +50,7 @@ type Runner struct {
Force bool Force bool
Progress *uiprogress.Progress Progress *uiprogress.Progress
Offset int Offset int
Limit int Total int
Deadline int Deadline int
CheckPeriod int CheckPeriod int
ErrPeriod int ErrPeriod int
@ -84,6 +84,7 @@ func (r *Runner) PrepareConfig() *pkg.Config {
func (r *Runner) Prepare(ctx context.Context) error { func (r *Runner) Prepare(ctx context.Context) error {
var err error var err error
if r.CheckOnly { if r.CheckOnly {
// 仅check, 类似httpx
r.Pools, err = ants.NewPoolWithFunc(1, func(i interface{}) { r.Pools, err = ants.NewPoolWithFunc(1, func(i interface{}) {
config := r.PrepareConfig() config := r.PrepareConfig()
config.Wordlist = r.URLList config.Wordlist = r.URLList
@ -94,31 +95,31 @@ func (r *Runner) Prepare(ctx context.Context) error {
r.poolwg.Done() r.poolwg.Done()
return return
} }
pool.bar = pkg.NewBar("check", r.Limit-r.Offset, r.Progress) pool.bar = pkg.NewBar("check", r.Total-r.Offset, r.Progress)
pool.Run(ctx, r.Offset, r.Limit) pool.Run(ctx, r.Offset, r.Total)
r.poolwg.Done() r.poolwg.Done()
}) })
} else { } else {
go func() { go func() {
for _, u := range r.URLList { for _, t := range r.Tasks {
r.urlCh <- strings.TrimSpace(u) r.taskCh <- t
} }
close(r.urlCh) close(r.taskCh)
}() }()
if len(r.URLList) > 0 { if len(r.Tasks) > 0 {
r.bar = r.Progress.AddBar(len(r.URLList)) r.bar = r.Progress.AddBar(len(r.Tasks))
r.bar.PrependCompleted() r.bar.PrependCompleted()
r.bar.PrependFunc(func(b *uiprogress.Bar) string { r.bar.PrependFunc(func(b *uiprogress.Bar) string {
return fmt.Sprintf("total progressive: %d/%d ", r.finished, len(r.URLList)) return fmt.Sprintf("total progressive: %d/%d ", r.finished, len(r.Tasks))
}) })
r.bar.AppendElapsed() r.bar.AppendElapsed()
} }
r.Pools, err = ants.NewPoolWithFunc(r.PoolSize, func(i interface{}) { r.Pools, err = ants.NewPoolWithFunc(r.PoolSize, func(i interface{}) {
u := i.(string) t := i.(*Task)
config := r.PrepareConfig() config := r.PrepareConfig()
config.BaseURL = u config.BaseURL = t.baseUrl
config.Wordlist = r.Wordlist config.Wordlist = r.Wordlist
pool, err := NewPool(ctx, config) pool, err := NewPool(ctx, config)
if err != nil { if err != nil {
@ -128,7 +129,7 @@ func (r *Runner) Prepare(ctx context.Context) error {
return return
} }
pool.bar = pkg.NewBar(u, r.Limit-r.Offset, r.Progress) pool.bar = pkg.NewBar(config.BaseURL, t.total-t.offset, r.Progress)
err = pool.Init() err = pool.Init()
if err != nil { if err != nil {
logs.Log.Error(err.Error()) logs.Log.Error(err.Error())
@ -140,12 +141,11 @@ func (r *Runner) Prepare(ctx context.Context) error {
} }
} }
pool.Run(ctx, r.Offset, r.Limit) pool.Run(ctx, t.offset, t.total)
logs.Log.Important(pool.Statistor.String()) logs.Log.Important(pool.Statistor.String())
logs.Log.Important(pool.Statistor.Detail()) logs.Log.Important(pool.Statistor.Detail())
if r.StatFile != nil { if r.StatFile != nil {
r.StatFile.SafeWrite(pool.Statistor.Json() + "\n") r.StatFile.SafeWrite(pool.Statistor.Json())
r.StatFile.SafeSync() r.StatFile.SafeSync()
} }
r.Done() r.Done()
@ -167,18 +167,17 @@ Loop:
case <-ctx.Done(): case <-ctx.Done():
logs.Log.Error("cancel with deadline") logs.Log.Error("cancel with deadline")
break Loop break Loop
case u, ok := <-r.urlCh: case t, ok := <-r.taskCh:
if !ok { if !ok {
break Loop break Loop
} }
r.poolwg.Add(1) r.poolwg.Add(1)
r.Pools.Invoke(u) r.Pools.Invoke(t)
} }
} }
r.poolwg.Wait() r.poolwg.Wait()
//time.Sleep(100 * time.Millisecond) // 延迟100ms, 等所有数据处理完毕
time.Sleep(100) // 延迟100ms, 等所有数据处理完毕
for { for {
if len(r.OutputCh) == 0 { if len(r.OutputCh) == 0 {
close(r.OutputCh) close(r.OutputCh)
@ -192,7 +191,7 @@ Loop:
break break
} }
} }
time.Sleep(100) // 延迟100ms, 等所有数据处理完毕 time.Sleep(100 * time.Millisecond) // 延迟100ms, 等所有数据处理完毕
} }
func (r *Runner) RunWithCheck(ctx context.Context) { func (r *Runner) RunWithCheck(ctx context.Context) {
@ -225,7 +224,7 @@ Loop:
} }
} }
time.Sleep(100) // 延迟100ms, 等所有数据处理完毕 time.Sleep(100 * time.Millisecond) // 延迟100ms, 等所有数据处理完毕
} }
func (r *Runner) Done() { func (r *Runner) Done() {

View File

@ -1,8 +1,10 @@
package pkg package pkg
import ( import (
"bytes"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io/ioutil"
"strconv" "strconv"
"strings" "strings"
"time" "time"
@ -33,6 +35,7 @@ type Statistor struct {
Total int `json:"total"` Total int `json:"total"`
StartTime int64 `json:"start_time"` StartTime int64 `json:"start_time"`
EndTime int64 `json:"end_time"` EndTime int64 `json:"end_time"`
WordCount int `json:"word_count"`
Word string `json:"word"` Word string `json:"word"`
Dictionaries []string `json:"dictionaries"` Dictionaries []string `json:"dictionaries"`
} }
@ -71,5 +74,24 @@ func (stat *Statistor) Json() string {
if err != nil { if err != nil {
return err.Error() return err.Error()
} }
return string(content) return string(content) + "\n"
} }
func ReadStatistors(filename string) (Statistors, error) {
content, err := ioutil.ReadFile(filename)
if err != nil {
return nil, err
}
var stats Statistors
for _, line := range bytes.Split(content, []byte("\n")) {
var stat Statistor
err := json.Unmarshal(line, &stat)
if err != nil {
return nil, err
}
stats = append(stats, stat)
}
return stats, nil
}
type Statistors []Statistor