diff --git a/internal/pool.go b/internal/pool.go index c06a966..d447d2c 100644 --- a/internal/pool.go +++ b/internal/pool.go @@ -31,9 +31,15 @@ var ( ) func NewPool(ctx context.Context, config *pkg.Config) (*Pool, error) { + var u *url.URL + var err error + if u, err = url.Parse(config.BaseURL); err != nil { + return nil, err + } pctx, cancel := context.WithCancel(ctx) pool := &Pool{ Config: config, + url: u, ctx: pctx, cancel: cancel, client: ihttp.NewClient(config.Thread, 2, config.ClientType), @@ -124,6 +130,7 @@ func NewPool(ctx context.Context, config *pkg.Config) (*Pool, error) { type Pool struct { *pkg.Config + url *url.URL Statistor *pkg.Statistor client *ihttp.Client reqPool *ants.PoolWithFunc @@ -164,13 +171,16 @@ func (pool *Pool) Init() error { } logs.Log.Info("[baseline.random] " + pool.random.Format([]string{"status", "length", "spend", "title", "frame", "redirect"})) - if pool.random.RedirectURL != "" { - // 自定协议升级 - // 某些网站http会重定向到https, 如果发现随机目录出现这种情况, 则自定将baseurl升级为https - rurl, err := url.Parse(pool.random.RedirectURL) - if err == nil && rurl.Hostname() == pool.random.Url.Hostname() && pool.random.Url.Scheme == "http" && rurl.Scheme == "https" { - logs.Log.Infof("baseurl %s upgrade http to https", pool.BaseURL) - pool.BaseURL = strings.Replace(pool.BaseURL, "http", "https", 1) + // 某些网站http会重定向到https, 如果发现随机目录出现这种情况, 则自定将baseurl升级为https + if pool.url.Scheme == "http" { + if pool.index.RedirectURL != "" { + if err := pool.Upgrade(pool.index); err != nil { + return err + } + } else if pool.random.RedirectURL != "" { + if err := pool.Upgrade(pool.random); err != nil { + return err + } } } @@ -471,6 +481,22 @@ func CompareWithExpr(exp *vm.Program, params map[string]interface{}) bool { } } +func (pool *Pool) Upgrade(bl *pkg.Baseline) error { + rurl, err := url.Parse(bl.RedirectURL) + if err == nil && rurl.Hostname() == bl.Url.Hostname() && bl.Url.Scheme == "http" && rurl.Scheme == "https" { + logs.Log.Infof("baseurl %s upgrade http to https, reinit", pool.BaseURL) + pool.BaseURL = strings.Replace(pool.BaseURL, "http", "https", 1) + pool.url.Scheme = "https" + // 重新初始化 + err = pool.Init() + if err != nil { + return err + } + } + + return nil +} + func (pool *Pool) doRedirect(bl *pkg.Baseline, depth int) { defer pool.wg.Done() if depth >= MaxRedirect { @@ -596,11 +622,7 @@ func (pool *Pool) doActive() { func (pool *Pool) doBak() { defer pool.wg.Done() - u, err := url.Parse(pool.BaseURL) - if err != nil { - return - } - worder, err := words.NewWorderWithDsl("{?0}.{@bak_ext}", [][]string{pkg.BakGenerator(u.Host)}, nil) + worder, err := words.NewWorderWithDsl("{?0}.{@bak_ext}", [][]string{pkg.BakGenerator(pool.url.Host)}, nil) if err != nil { return }