diff --git a/drivers/all.go b/drivers/all.go index 224fb8dd..7fd2d567 100644 --- a/drivers/all.go +++ b/drivers/all.go @@ -58,6 +58,7 @@ import ( _ "github.com/alist-org/alist/v3/drivers/sftp" _ "github.com/alist-org/alist/v3/drivers/smb" _ "github.com/alist-org/alist/v3/drivers/teambition" + _ "github.com/alist-org/alist/v3/drivers/teldrive" _ "github.com/alist-org/alist/v3/drivers/terabox" _ "github.com/alist-org/alist/v3/drivers/thunder" _ "github.com/alist-org/alist/v3/drivers/thunder_browser" diff --git a/drivers/teldrive/driver.go b/drivers/teldrive/driver.go new file mode 100644 index 00000000..f7d2f5ca --- /dev/null +++ b/drivers/teldrive/driver.go @@ -0,0 +1,226 @@ +package teldrive + +import ( + "context" + "fmt" + "github.com/alist-org/alist/v3/drivers/base" + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/errs" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/internal/op" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/go-resty/resty/v2" + "github.com/google/uuid" + "math" + "net/http" + "net/url" + "strings" +) + +type Teldrive struct { + model.Storage + Addition +} + +func (d *Teldrive) Config() driver.Config { + return config +} + +func (d *Teldrive) GetAddition() driver.Additional { + return &d.Addition +} + +func (d *Teldrive) Init(ctx context.Context) error { + // TODO login / refresh token + // op.MustSaveDriverStorage(d) + if d.Cookie == "" || !strings.HasPrefix(d.Cookie, "access_token=") { + return fmt.Errorf("cookie must start with 'access_token='") + } + if d.UploadConcurrency == 0 { + d.UploadConcurrency = 4 + } + if d.ChunkSize == 0 { + d.ChunkSize = 10 + } + if d.WebdavNative() { + d.WebProxy = true + } else { + d.WebProxy = false + } + + op.MustSaveDriverStorage(d) + return nil +} + +func (d *Teldrive) Drop(ctx context.Context) error { + return nil +} + +func (d *Teldrive) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) { + // TODO return the files list, required + // endpoint /api/files, params ->page order sort path + var listResp ListResp + params := url.Values{} + params.Set("path", dir.GetPath()) + //log.Info(dir.GetPath()) + pathname, err := utils.InjectQuery("/api/files", params) + if err != nil { + return nil, err + } + + err = d.request(http.MethodGet, pathname, nil, &listResp) + if err != nil { + return nil, err + } + + return utils.SliceConvert(listResp.Items, func(src Object) (model.Obj, error) { + return &model.Object{ + ID: src.ID, + Name: src.Name, + Size: func() int64 { + if src.Type == "folder" { + return 0 + } + return src.Size + }(), + IsFolder: src.Type == "folder", + Modified: src.UpdatedAt, + }, nil + }) +} + +func (d *Teldrive) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) { + if !d.WebdavNative() { + if shareObj, err := d.getShareFileById(file.GetID()); err == nil && shareObj != nil { + return &model.Link{ + URL: d.Address + fmt.Sprintf("/api/shares/%s/files/%s/%s", shareObj.Id, file.GetID(), file.GetName()), + }, nil + } + if err := d.createShareFile(file.GetID()); err != nil { + return nil, err + } + shareObj, err := d.getShareFileById(file.GetID()) + if err != nil { + return nil, err + } + return &model.Link{ + URL: d.Address + fmt.Sprintf("/api/shares/%s/files/%s/%s", shareObj.Id, file.GetID(), file.GetName()), + }, nil + } + return &model.Link{ + URL: d.Address + "/api/files/" + file.GetID() + "/" + file.GetName(), + Header: http.Header{ + "Cookie": {d.Cookie}, + }, + }, nil +} + +func (d *Teldrive) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) error { + return d.request(http.MethodPost, "/api/files/mkdir", func(req *resty.Request) { + req.SetBody(map[string]interface{}{ + "path": parentDir.GetPath() + "/" + dirName, + }) + }, nil) +} + +func (d *Teldrive) Move(ctx context.Context, srcObj, dstDir model.Obj) error { + body := base.Json{ + "ids": []string{srcObj.GetID()}, + "destinationParent": dstDir.GetID(), + } + return d.request(http.MethodPost, "/api/files/move", func(req *resty.Request) { + req.SetBody(body) + }, nil) +} + +func (d *Teldrive) Rename(ctx context.Context, srcObj model.Obj, newName string) error { + body := base.Json{ + "name": newName, + } + return d.request(http.MethodPatch, "/api/files/"+srcObj.GetID(), func(req *resty.Request) { + req.SetBody(body) + }, nil) +} + +func (d *Teldrive) Copy(ctx context.Context, srcObj, dstDir model.Obj) error { + copyConcurrentLimit := 4 + copyManager := NewCopyManager(ctx, copyConcurrentLimit, d) + copyManager.startWorkers() + copyManager.G.Go(func() error { + defer close(copyManager.TaskChan) + return copyManager.generateTasks(ctx, srcObj, dstDir) + }) + return copyManager.G.Wait() +} + +func (d *Teldrive) Remove(ctx context.Context, obj model.Obj) error { + body := base.Json{ + "ids": []string{obj.GetID()}, + } + return d.request(http.MethodPost, "/api/files/delete", func(req *resty.Request) { + req.SetBody(body) + }, nil) +} + +func (d *Teldrive) Put(ctx context.Context, dstDir model.Obj, file model.FileStreamer, up driver.UpdateProgress) error { + fileId := uuid.New().String() + chunkSizeInMB := d.ChunkSize + chunkSize := chunkSizeInMB * 1024 * 1024 // Convert MB to bytes + totalSize := file.GetSize() + totalParts := int(math.Ceil(float64(totalSize) / float64(chunkSize))) + retryCount := 0 + maxRetried := 3 + p := driver.NewProgress(totalSize, up) + + // delete the upload task when finished or failed + defer func() { + _ = d.request(http.MethodDelete, "/api/uploads/"+fileId, nil, nil) + }() + + if obj, err := d.getFile(dstDir.GetPath(), file.GetName(), file.IsDir()); err == nil { + if err = d.Remove(ctx, obj); err != nil { + return err + } + } + // start the upload process + if err := d.request(http.MethodGet, "/api/uploads/"+fileId, nil, nil); err != nil { + return err + } + if totalSize == 0 { + return d.touch(file.GetName(), dstDir.GetPath()) + } + + if totalParts <= 1 { + return d.doSingleUpload(ctx, dstDir, file, p, retryCount, maxRetried, totalParts, fileId) + } + + return d.doMultiUpload(ctx, dstDir, file, p, maxRetried, totalParts, chunkSize, fileId) +} + +func (d *Teldrive) GetArchiveMeta(ctx context.Context, obj model.Obj, args model.ArchiveArgs) (model.ArchiveMeta, error) { + // TODO get archive file meta-info, return errs.NotImplement to use an internal archive tool, optional + return nil, errs.NotImplement +} + +func (d *Teldrive) ListArchive(ctx context.Context, obj model.Obj, args model.ArchiveInnerArgs) ([]model.Obj, error) { + // TODO list args.InnerPath in the archive obj, return errs.NotImplement to use an internal archive tool, optional + return nil, errs.NotImplement +} + +func (d *Teldrive) Extract(ctx context.Context, obj model.Obj, args model.ArchiveInnerArgs) (*model.Link, error) { + // TODO return link of file args.InnerPath in the archive obj, return errs.NotImplement to use an internal archive tool, optional + return nil, errs.NotImplement +} + +func (d *Teldrive) ArchiveDecompress(ctx context.Context, srcObj, dstDir model.Obj, args model.ArchiveDecompressArgs) ([]model.Obj, error) { + // TODO extract args.InnerPath path in the archive srcObj to the dstDir location, optional + // a folder with the same name as the archive file needs to be created to store the extracted results if args.PutIntoNewDir + // return errs.NotImplement to use an internal archive tool + return nil, errs.NotImplement +} + +//func (d *Teldrive) Other(ctx context.Context, args model.OtherArgs) (interface{}, error) { +// return nil, errs.NotSupport +//} + +var _ driver.Driver = (*Teldrive)(nil) diff --git a/drivers/teldrive/meta.go b/drivers/teldrive/meta.go new file mode 100644 index 00000000..7f7f77b7 --- /dev/null +++ b/drivers/teldrive/meta.go @@ -0,0 +1,27 @@ +package teldrive + +import ( + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/op" +) + +type Addition struct { + // Usually one of two + driver.RootPath + // define other + Address string `json:"url" required:"true"` + ChunkSize int64 `json:"chunk_size" type:"number" default:"4" help:"Chunk size in MiB"` + Cookie string `json:"cookie" type:"string" required:"true" help:"access_token=xxx"` + UploadConcurrency int64 `json:"upload_concurrency" type:"number" default:"4" help:"Concurrency upload requests"` +} + +var config = driver.Config{ + Name: "Teldrive", + DefaultRoot: "/", +} + +func init() { + op.RegisterDriver(func() driver.Driver { + return &Teldrive{} + }) +} diff --git a/drivers/teldrive/types.go b/drivers/teldrive/types.go new file mode 100644 index 00000000..0d6b0246 --- /dev/null +++ b/drivers/teldrive/types.go @@ -0,0 +1,77 @@ +package teldrive + +import ( + "context" + "github.com/alist-org/alist/v3/internal/model" + "golang.org/x/sync/errgroup" + "golang.org/x/sync/semaphore" + "time" +) + +type ErrResp struct { + Code int `json:"code"` + Message string `json:"message"` +} + +type Object struct { + ID string `json:"id"` + Name string `json:"name"` + Type string `json:"type"` + MimeType string `json:"mimeType"` + Category string `json:"category,omitempty"` + ParentId string `json:"parentId"` + Size int64 `json:"size"` + Encrypted bool `json:"encrypted"` + UpdatedAt time.Time `json:"updatedAt"` +} + +type ListResp struct { + Items []Object `json:"items"` + Meta struct { + Count int `json:"count"` + TotalPages int `json:"totalPages"` + CurrentPage int `json:"currentPage"` + } `json:"meta"` +} + +type FilePart struct { + Name string `json:"name"` + PartId int `json:"partId"` + PartNo int `json:"partNo"` + ChannelId int `json:"channelId"` + Size int `json:"size"` + Encrypted bool `json:"encrypted"` + Salt string `json:"salt"` +} + +type chunkTask struct { + data []byte + chunkIdx int + fileName string +} + +type CopyManager struct { + TaskChan chan CopyTask + Sem *semaphore.Weighted + G *errgroup.Group + Ctx context.Context + d *Teldrive +} + +type CopyTask struct { + SrcObj model.Obj + DstDir model.Obj +} + +type CustomProxy struct { + model.Proxy +} + +type ShareObj struct { + Id string `json:"id"` + Protected bool `json:"protected"` + UserId int `json:"userId"` + Type string `json:"type"` + Name string `json:"name"` + ExpiresAt time.Time `json:"expiresAt"` +} diff --git a/drivers/teldrive/util.go b/drivers/teldrive/util.go new file mode 100644 index 00000000..d757e75f --- /dev/null +++ b/drivers/teldrive/util.go @@ -0,0 +1,590 @@ +package teldrive + +import ( + "bytes" + "fmt" + "github.com/alist-org/alist/v3/drivers/base" + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/go-resty/resty/v2" + "github.com/pkg/errors" + "golang.org/x/net/context" + "golang.org/x/sync/errgroup" + "golang.org/x/sync/semaphore" + "io" + "net/http" + "sort" + "strconv" + "time" +) + +// do others that not defined in Driver interface + +func (d *Teldrive) request(method string, pathname string, callback base.ReqCallback, resp interface{}) error { + url := d.Address + pathname + req := base.RestyClient.R() + req.SetHeader("Cookie", d.Cookie) + if callback != nil { + callback(req) + } + if resp != nil { + req.SetResult(resp) + } + var e ErrResp + req.SetError(&e) + _req, err := req.Execute(method, url) + if err != nil { + return err + } + + if _req.IsError() { + return &e + } + return nil +} + +func (d *Teldrive) getFile(path, name string, isFolder bool) (model.Obj, error) { + resp := &ListResp{} + err := d.request(http.MethodGet, "/api/files", func(req *resty.Request) { + req.SetQueryParams(map[string]string{ + "path": path, + "name": name, + "type": func() string { + if isFolder { + return "folder" + } + return "file" + }(), + "operation": "find", + }) + }, resp) + if err != nil { + return nil, err + } + if len(resp.Items) == 0 { + return nil, fmt.Errorf("file not found: %s/%s", path, name) + } + obj := resp.Items[0] + return &model.Object{ + ID: obj.ID, + Name: obj.Name, + Size: obj.Size, + IsFolder: obj.Type == "folder", + }, err +} + +func (err *ErrResp) Error() string { + if err == nil { + return "" + } + + return fmt.Sprintf("[Teldrive] message:%s Error code:%d", err.Message, err.Code) +} + +// create empty file +func (d *Teldrive) touch(name, path string) error { + uploadBody := base.Json{ + "name": name, + "type": "file", + "path": path, + } + if err := d.request(http.MethodPost, "/api/files", func(req *resty.Request) { + req.SetBody(uploadBody) + }, nil); err != nil { + return err + } + + return nil +} + +func (d *Teldrive) createFileOnUploadSuccess(name, id, path string, uploadedFileParts []FilePart, totalSize int64) error { + remoteFileParts, err := d.getFilePart(id) + if err != nil { + return err + } + // check if the uploaded file parts match the remote file parts + if len(remoteFileParts) != len(uploadedFileParts) { + return fmt.Errorf("[Teldrive] file parts count mismatch: expected %d, got %d", len(uploadedFileParts), len(remoteFileParts)) + } + formatParts := make([]base.Json, 0) + for _, p := range remoteFileParts { + formatParts = append(formatParts, base.Json{ + "id": p.PartId, + "salt": p.Salt, + }) + } + uploadBody := base.Json{ + "name": name, + "type": "file", + "path": path, + "parts": formatParts, + "size": totalSize, + } + // create file here + if err := d.request(http.MethodPost, "/api/files", func(req *resty.Request) { + req.SetBody(uploadBody) + }, nil); err != nil { + return err + } + + return nil +} + +func (d *Teldrive) checkFilePartExist(fileId string, partId int) (FilePart, error) { + var uploadedParts []FilePart + var filePart FilePart + + if err := d.request(http.MethodGet, "/api/uploads/"+fileId, nil, &uploadedParts); err != nil { + return filePart, err + } + + for _, part := range uploadedParts { + if part.PartId == partId { + return part, nil + } + } + + return filePart, nil +} + +func (d *Teldrive) getFilePart(fileId string) ([]FilePart, error) { + var uploadedParts []FilePart + if err := d.request(http.MethodGet, "/api/uploads/"+fileId, nil, &uploadedParts); err != nil { + return nil, err + } + + return uploadedParts, nil +} + +func (d *Teldrive) singleUploadRequest(fileId string, callback base.ReqCallback, resp interface{}) error { + url := d.Address + "/api/uploads/" + fileId + client := resty.New().SetTimeout(0) + + ctx := context.Background() + + req := client.R(). + SetContext(ctx) + req.SetHeader("Cookie", d.Cookie) + req.SetHeader("Content-Type", "application/octet-stream") + req.SetContentLength(true) + req.AddRetryCondition(func(r *resty.Response, err error) bool { + return false + }) + if callback != nil { + callback(req) + } + if resp != nil { + req.SetResult(resp) + } + var e ErrResp + req.SetError(&e) + _req, err := req.Execute(http.MethodPost, url) + if err != nil { + return err + } + + if _req.IsError() { + return &e + } + return nil +} + +func (d *Teldrive) doSingleUpload(ctx context.Context, dstDir model.Obj, file model.FileStreamer, p *driver.Progress, + retryCount, maxRetried, totalParts int, fileId string) error { + + chunkIdx := 1 + totalSize := file.GetSize() + var fileParts []FilePart + for p.Done < p.Total { + if utils.IsCanceled(ctx) { + return ctx.Err() + } + // only one chunk, so we can use the whole file + byteData := make([]byte, totalSize) + _, err := io.ReadFull(file, byteData) + if err != io.EOF && err != nil { + return err + } + filePart := &FilePart{} + // be sure the file is uploaded, and break loop if success + for { + if err := d.singleUploadRequest(fileId, func(req *resty.Request) { + uploadParams := map[string]string{ + "partName": func() string { + digits := len(fmt.Sprintf("%d", totalParts)) + return file.GetName() + fmt.Sprintf("%0*d", digits, chunkIdx) + }(), + "partNo": strconv.Itoa(chunkIdx), + "fileName": file.GetName(), + } + req.SetQueryParams(uploadParams) + req.SetBody(driver.NewLimitedUploadStream(ctx, bytes.NewReader(byteData))) + req.SetHeader("Content-Length", strconv.Itoa(len(byteData))) + }, filePart); err != nil { + if retryCount >= maxRetried { + utils.Log.Errorf("[Teldrive] upload failed after %d retries: %s", maxRetried, err.Error()) + return err + } + if errors.Is(err, context.DeadlineExceeded) { + continue + } + retryCount++ + errorStr := fmt.Sprintf("[Teldrive] upload error: %v, retrying %d times", err, retryCount) + utils.Log.Errorf(errorStr) + time.Sleep(time.Duration(retryCount<<1) * time.Second) // Exponential backoff: 2, 4, 8, 16, ... + continue + } + break + } + if filePart.Name != "" { + fileParts = append(fileParts, *filePart) + retryCount = 0 + _, _ = p.Write(byteData) + chunkIdx++ + } + + } + + return d.createFileOnUploadSuccess(file.GetName(), fileId, dstDir.GetPath(), fileParts, totalSize) +} + +func (d *Teldrive) doMultiUpload(ctx context.Context, dstDir model.Obj, file model.FileStreamer, p *driver.Progress, + maxRetried, totalParts int, chunkSize int64, fileId string) error { + concurrent := d.UploadConcurrency + g, ctx := errgroup.WithContext(ctx) + sem := semaphore.NewWeighted(int64(concurrent)) + chunkChan := make(chan chunkTask, concurrent*2) + resultChan := make(chan FilePart, concurrent) + totalSize := file.GetSize() + g.Go(func() error { + defer close(chunkChan) + + chunkIdx := 1 + for { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + + if p.Done >= p.Total { + break + } + + byteData := make([]byte, chunkSize) + n, err := io.ReadFull(file, byteData) + if err != nil { + if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) { + if n > 0 { + // handle the last chuck + byteData = byteData[:n] + task := chunkTask{ + data: byteData, + chunkIdx: chunkIdx, + fileName: file.GetName(), + } + select { + case chunkChan <- task: + case <-ctx.Done(): + return ctx.Err() + } + } + break + } + return fmt.Errorf("read file error: %w", err) + } + + if _, err := p.Write(byteData); err != nil { + return fmt.Errorf("progress update error: %w", err) + } + + task := chunkTask{ + data: byteData, + chunkIdx: chunkIdx, + fileName: file.GetName(), + } + + select { + case chunkChan <- task: + chunkIdx++ + case <-ctx.Done(): + return ctx.Err() + } + } + return nil + }) + for i := 0; i < int(concurrent); i++ { + g.Go(func() error { + for task := range chunkChan { + if err := sem.Acquire(ctx, 1); err != nil { + return err + } + + filePart, err := d.uploadSingleChunk(ctx, fileId, task, totalParts, maxRetried) + sem.Release(1) + + if err != nil { + return fmt.Errorf("upload chunk %d failed: %w", task.chunkIdx, err) + } + + select { + case resultChan <- *filePart: + case <-ctx.Done(): + return ctx.Err() + } + } + return nil + }) + } + var fileParts []FilePart + var collectErr error + collectDone := make(chan struct{}) + + go func() { + defer close(collectDone) + fileParts = make([]FilePart, 0, totalParts) + + done := make(chan error, 1) + go func() { + done <- g.Wait() + close(resultChan) + }() + + for { + select { + case filePart, ok := <-resultChan: + if !ok { + collectErr = <-done + return + } + fileParts = append(fileParts, filePart) + case err := <-done: + collectErr = err + return + } + } + }() + + <-collectDone + + if collectErr != nil { + return fmt.Errorf("multi-upload failed: %w", collectErr) + } + sort.Slice(fileParts, func(i, j int) bool { + return fileParts[i].PartNo < fileParts[j].PartNo + }) + + return d.createFileOnUploadSuccess(file.GetName(), fileId, dstDir.GetPath(), fileParts, totalSize) +} + +func (d *Teldrive) uploadSingleChunk(ctx context.Context, fileId string, task chunkTask, totalParts, maxRetried int) (*FilePart, error) { + filePart := &FilePart{} + retryCount := 0 + + for { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + + if existingPart, err := d.checkFilePartExist(fileId, task.chunkIdx); err == nil && existingPart.Name != "" { + return &existingPart, nil + } + + err := d.singleUploadRequest(fileId, func(req *resty.Request) { + uploadParams := map[string]string{ + "partName": func() string { + digits := len(fmt.Sprintf("%d", totalParts)) + return task.fileName + fmt.Sprintf("%0*d", digits, task.chunkIdx) + }(), + "partNo": strconv.Itoa(task.chunkIdx), + "fileName": task.fileName, + } + req.SetQueryParams(uploadParams) + req.SetBody(driver.NewLimitedUploadStream(ctx, bytes.NewReader(task.data))) + req.SetHeader("Content-Length", strconv.Itoa(len(task.data))) + }, filePart) + + if err == nil { + return filePart, nil + } + + if retryCount >= maxRetried { + return nil, fmt.Errorf("upload failed after %d retries: %w", maxRetried, err) + } + + if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) { + continue + } + + retryCount++ + utils.Log.Errorf("[Teldrive] upload error: %v, retrying %d times", err, retryCount) + + backoffDuration := time.Duration(retryCount*retryCount) * time.Second + if backoffDuration > 30*time.Second { + backoffDuration = 30 * time.Second + } + + select { + case <-time.After(backoffDuration): + case <-ctx.Done(): + return nil, ctx.Err() + } + } +} + +func (d *Teldrive) createShareFile(fileId string) error { + var errResp ErrResp + if err := d.request(http.MethodPost, "/api/files/"+fileId+"/share", func(req *resty.Request) { + req.SetBody(base.Json{ + "expiresAt": getDateTime(), + }) + }, &errResp); err != nil { + return err + } + + if errResp.Message != "" { + return &errResp + } + + return nil +} + +func (d *Teldrive) getShareFileById(fileId string) (*ShareObj, error) { + var shareObj ShareObj + if err := d.request(http.MethodGet, "/api/files/"+fileId+"/share", nil, &shareObj); err != nil { + return nil, err + } + + return &shareObj, nil +} + +func getDateTime() string { + now := time.Now().UTC() + formattedWithMs := now.Add(time.Hour * 1).Format("2006-01-02T15:04:05.000Z") + return formattedWithMs +} + +func NewCopyManager(ctx context.Context, concurrent int, d *Teldrive) *CopyManager { + g, ctx := errgroup.WithContext(ctx) + + return &CopyManager{ + TaskChan: make(chan CopyTask, concurrent*2), + Sem: semaphore.NewWeighted(int64(concurrent)), + G: g, + Ctx: ctx, + d: d, + } +} + +func (cm *CopyManager) startWorkers() { + workerCount := cap(cm.TaskChan) / 2 + for i := 0; i < workerCount; i++ { + cm.G.Go(func() error { + return cm.worker() + }) + } +} + +func (cm *CopyManager) worker() error { + for { + select { + case task, ok := <-cm.TaskChan: + if !ok { + return nil + } + + if err := cm.Sem.Acquire(cm.Ctx, 1); err != nil { + return err + } + + var err error + + err = cm.processFile(task) + + cm.Sem.Release(1) + + if err != nil { + return fmt.Errorf("task processing failed: %w", err) + } + + case <-cm.Ctx.Done(): + return cm.Ctx.Err() + } + } +} + +func (cm *CopyManager) generateTasks(ctx context.Context, srcObj, dstDir model.Obj) error { + if srcObj.IsDir() { + return cm.generateFolderTasks(ctx, srcObj, dstDir) + } else { + // add single file task directly + select { + case cm.TaskChan <- CopyTask{SrcObj: srcObj, DstDir: dstDir}: + return nil + case <-ctx.Done(): + return ctx.Err() + } + } +} + +func (cm *CopyManager) generateFolderTasks(ctx context.Context, srcDir, dstDir model.Obj) error { + objs, err := cm.d.List(ctx, srcDir, model.ListArgs{}) + if err != nil { + return fmt.Errorf("failed to list directory %s: %w", srcDir.GetPath(), err) + } + + err = cm.d.MakeDir(cm.Ctx, dstDir, srcDir.GetName()) + if err != nil || len(objs) == 0 { + return err + } + newDstDir := &model.Object{ + ID: dstDir.GetID(), + Path: dstDir.GetPath() + "/" + srcDir.GetName(), + Name: srcDir.GetName(), + IsFolder: true, + } + + for _, file := range objs { + if utils.IsCanceled(ctx) { + return ctx.Err() + } + + srcFile := &model.Object{ + ID: file.GetID(), + Path: srcDir.GetPath() + "/" + file.GetName(), + Name: file.GetName(), + IsFolder: file.IsDir(), + } + + // 递归生成任务 + if err := cm.generateTasks(ctx, srcFile, newDstDir); err != nil { + return err + } + } + + return nil +} + +func (cm *CopyManager) processFile(task CopyTask) error { + return cm.copySingleFile(cm.Ctx, task.SrcObj, task.DstDir) +} + +func (cm *CopyManager) copySingleFile(ctx context.Context, srcObj, dstDir model.Obj) error { + // `override copy mode` should delete the existing file + if obj, err := cm.d.getFile(dstDir.GetPath(), srcObj.GetName(), srcObj.IsDir()); err == nil { + if err := cm.d.Remove(ctx, obj); err != nil { + return fmt.Errorf("failed to remove existing file: %w", err) + } + } + + // Do copy + return cm.d.request(http.MethodPost, "/api/files/"+srcObj.GetID()+"/copy", func(req *resty.Request) { + req.SetBody(base.Json{ + "newName": srcObj.GetName(), + "destination": dstDir.GetPath(), + }) + }, nil) +}