diff --git a/drivers/local/driver.go b/drivers/local/driver.go index 10d7baa0..0358c9d1 100644 --- a/drivers/local/driver.go +++ b/drivers/local/driver.go @@ -21,8 +21,7 @@ func (d Driver) Config() driver.Config { func (d *Driver) Init(ctx context.Context, account model.Account) error { d.Account = account - addition := d.Account.Addition - err := utils.Json.UnmarshalFromString(addition, &d.Addition) + err := utils.Json.UnmarshalFromString(d.Account.Addition, &d.Addition) if err != nil { return errors.Wrap(err, "error while unmarshal addition") } @@ -32,7 +31,7 @@ func (d *Driver) Init(ctx context.Context, account model.Account) error { } else { d.SetStatus("OK") } - operations.SaveDriverAccount(d) + operations.MustSaveDriverAccount(d) return err } @@ -79,7 +78,7 @@ func (d *Driver) Remove(ctx context.Context, obj model.Obj) error { panic("implement me") } -func (d *Driver) Put(ctx context.Context, parentDir model.Obj, stream model.FileStreamer) error { +func (d *Driver) Put(ctx context.Context, parentDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error { //TODO implement me panic("implement me") } diff --git a/internal/driver/driver.go b/internal/driver/driver.go index 1962ccba..dc69bb09 100644 --- a/internal/driver/driver.go +++ b/internal/driver/driver.go @@ -50,5 +50,7 @@ type Writer interface { // Remove remove `object` Remove(ctx context.Context, obj model.Obj) error // Put upload `stream` to `parentDir` - Put(ctx context.Context, parentDir model.Obj, stream model.FileStreamer) error + Put(ctx context.Context, parentDir model.Obj, stream model.FileStreamer, up UpdateProgress) error } + +type UpdateProgress func(percentage float64) diff --git a/internal/fs/copy.go b/internal/fs/copy.go index c569e377..bd7c3fa3 100644 --- a/internal/fs/copy.go +++ b/internal/fs/copy.go @@ -2,34 +2,55 @@ package fs import ( "context" + "fmt" + "github.com/alist-org/alist/v3/pkg/task" + "github.com/alist-org/alist/v3/pkg/utils" stdpath "path" "github.com/alist-org/alist/v3/internal/driver" "github.com/alist-org/alist/v3/internal/model" "github.com/alist-org/alist/v3/internal/operations" - "github.com/alist-org/alist/v3/internal/task" "github.com/pkg/errors" ) var copyTaskManager = task.NewTaskManager() -func CopyBetween2Accounts(ctx context.Context, srcAccount, dstAccount driver.Driver, srcPath, dstPath string) error { - srcFile, err := operations.Get(ctx, srcAccount, srcPath) +func CopyBetween2Accounts(ctx context.Context, srcAccount, dstAccount driver.Driver, srcPath, dstPath string, setStatus func(status string)) error { + setStatus("getting src object") + srcObj, err := operations.Get(ctx, srcAccount, srcPath) if err != nil { return errors.WithMessagef(err, "failed get src [%s] file", srcPath) } - if srcFile.IsDir() { + if srcObj.IsDir() { + setStatus("src object is dir, listing files") files, err := operations.List(ctx, srcAccount, srcPath) if err != nil { return errors.WithMessagef(err, "failed list src [%s] files", srcPath) } for _, file := range files { + if utils.IsCanceled(ctx) { + return nil + } srcFilePath := stdpath.Join(srcPath, file.GetName()) dstFilePath := stdpath.Join(dstPath, file.GetName()) - if err := CopyBetween2Accounts(ctx, srcAccount, dstAccount, srcFilePath, dstFilePath); err != nil { - return errors.WithMessagef(err, "failed copy file [%s] to [%s]", srcFilePath, dstFilePath) - } + copyTaskManager.Add(fmt.Sprintf("copy %s to %s", srcFilePath, dstFilePath), func(task *task.Task) error { + return CopyBetween2Accounts(ctx, srcAccount, dstAccount, srcFilePath, dstFilePath, task.SetStatus) + }) } + } else { + copyTaskManager.Add(fmt.Sprintf("copy %s to %s", srcPath, dstPath), func(task *task.Task) error { + return CopyFileBetween2Accounts(task.Ctx, srcAccount, dstAccount, srcPath, dstPath, func(percentage float64) { + task.SetStatus(fmt.Sprintf("uploading: %2.f%", percentage)) + }) + }) + } + return nil +} + +func CopyFileBetween2Accounts(ctx context.Context, srcAccount, dstAccount driver.Driver, srcPath, dstPath string, up driver.UpdateProgress) error { + srcFile, err := operations.Get(ctx, srcAccount, srcPath) + if err != nil { + return errors.WithMessagef(err, "failed get src [%s] file", srcPath) } link, err := operations.Link(ctx, srcAccount, srcPath, model.LinkArgs{}) if err != nil { @@ -39,6 +60,5 @@ func CopyBetween2Accounts(ctx context.Context, srcAccount, dstAccount driver.Dri if err != nil { return errors.WithMessagef(err, "failed get [%s] stream", srcPath) } - // TODO add as task - return operations.Put(ctx, dstAccount, dstPath, stream) + return operations.Put(ctx, dstAccount, dstPath, stream, up) } diff --git a/internal/fs/write.go b/internal/fs/write.go index f476f759..412f5736 100644 --- a/internal/fs/write.go +++ b/internal/fs/write.go @@ -3,11 +3,10 @@ package fs import ( "context" "fmt" - "github.com/alist-org/alist/v3/internal/driver" "github.com/alist-org/alist/v3/internal/model" "github.com/alist-org/alist/v3/internal/operations" - "github.com/alist-org/alist/v3/internal/task" + "github.com/alist-org/alist/v3/pkg/task" "github.com/pkg/errors" ) @@ -49,7 +48,7 @@ func Copy(ctx context.Context, account driver.Driver, srcPath, dstPath string) ( if err != nil { return false, errors.WithMessage(err, "failed get src account") } - dstAccount, dstActualPath, err := operations.GetAccountAndActualPath(srcPath) + dstAccount, dstActualPath, err := operations.GetAccountAndActualPath(dstPath) if err != nil { return false, errors.WithMessage(err, "failed get dst account") } @@ -60,7 +59,7 @@ func Copy(ctx context.Context, account driver.Driver, srcPath, dstPath string) ( // not in an account // TODO add status set callback to put copyTaskManager.Add(fmt.Sprintf("copy %s to %s", srcActualPath, dstActualPath), func(task *task.Task) error { - return CopyBetween2Accounts(context.TODO(), srcAccount, dstAccount, srcActualPath, dstActualPath) + return CopyBetween2Accounts(task.Ctx, srcAccount, dstAccount, srcActualPath, dstActualPath, task.SetStatus) }) return true, nil } @@ -73,10 +72,11 @@ func Remove(ctx context.Context, account driver.Driver, path string) error { return operations.Remove(ctx, account, actualPath) } +// Put add as a put task func Put(ctx context.Context, account driver.Driver, parentPath string, file model.FileStreamer) error { account, actualParentPath, err := operations.GetAccountAndActualPath(parentPath) if err != nil { return errors.WithMessage(err, "failed get account") } - return operations.Put(ctx, account, actualParentPath, file) + return operations.Put(ctx, account, actualParentPath, file, nil) } diff --git a/internal/operations/account.go b/internal/operations/account.go index e897530d..f73f998e 100644 --- a/internal/operations/account.go +++ b/internal/operations/account.go @@ -2,6 +2,7 @@ package operations import ( "context" + log "github.com/sirupsen/logrus" "sort" "strings" "time" @@ -85,8 +86,15 @@ func UpdateAccount(ctx context.Context, account model.Account) error { return nil } -// SaveDriverAccount call from specific driver -func SaveDriverAccount(driver driver.Driver) error { +// MustSaveDriverAccount call from specific driver +func MustSaveDriverAccount(driver driver.Driver) { + err := saveDriverAccount(driver) + if err != nil { + log.Errorf("failed save driver account: %s", err) + } +} + +func saveDriverAccount(driver driver.Driver) error { account := driver.GetAccount() addition := driver.GetAddition() bytes, err := utils.Json.Marshal(addition) diff --git a/internal/operations/fs.go b/internal/operations/fs.go index d3ccf01e..ac6803aa 100644 --- a/internal/operations/fs.go +++ b/internal/operations/fs.go @@ -182,7 +182,7 @@ func Remove(ctx context.Context, account driver.Driver, path string) error { return account.Remove(ctx, obj) } -func Put(ctx context.Context, account driver.Driver, parentPath string, file model.FileStreamer) error { +func Put(ctx context.Context, account driver.Driver, parentPath string, file model.FileStreamer, up driver.UpdateProgress) error { err := MakeDir(ctx, account, parentPath) if err != nil { return errors.WithMessagef(err, "failed to make dir [%s]", parentPath) @@ -192,5 +192,9 @@ func Put(ctx context.Context, account driver.Driver, parentPath string, file mod if err != nil { return errors.WithMessagef(err, "failed to get dir [%s]", parentPath) } - return account.Put(ctx, parentDir, file) + // if up is nil, set a default to prevent panic + if up == nil { + up = func(p float64) {} + } + return account.Put(ctx, parentDir, file, up) } diff --git a/internal/task/task.go b/internal/task/task.go deleted file mode 100644 index e90fbfe8..00000000 --- a/internal/task/task.go +++ /dev/null @@ -1,36 +0,0 @@ -// manage task, such as file upload, file copy between accounts, offline download, etc. -package task - -type Func func(task *Task) error - -var ( - PENDING = "pending" - RUNNING = "running" - FINISHED = "finished" -) - -type Task struct { - ID int64 - Name string - Status string - Error error - Func Func -} - -func NewTask(name string, func_ Func) *Task { - return &Task{ - Name: name, - Status: PENDING, - Func: func_, - } -} - -func (t *Task) SetStatus(status string) { - t.Status = status -} - -func (t *Task) Run() { - t.Status = RUNNING - t.Error = t.Func(t) - t.Status = FINISHED -} diff --git a/internal/task/manager.go b/pkg/task/manager.go similarity index 60% rename from internal/task/manager.go rename to pkg/task/manager.go index 834fcd13..64179687 100644 --- a/internal/task/manager.go +++ b/pkg/task/manager.go @@ -6,37 +6,37 @@ import ( "github.com/alist-org/alist/v3/pkg/generic_sync" ) -func NewTaskManager() *TaskManager { - return &TaskManager{ +func NewTaskManager() *Manager { + return &Manager{ tasks: generic_sync.MapOf[int64, *Task]{}, curID: 0, } } -type TaskManager struct { +type Manager struct { curID int64 tasks generic_sync.MapOf[int64, *Task] } -func (tm *TaskManager) AddTask(task *Task) { +func (tm *Manager) AddTask(task *Task) { task.ID = tm.curID atomic.AddInt64(&tm.curID, 1) tm.tasks.Store(task.ID, task) } -func (tm *TaskManager) GetAll() []*Task { +func (tm *Manager) GetAll() []*Task { return tm.tasks.Values() } -func (tm *TaskManager) Get(id int64) (*Task, bool) { +func (tm *Manager) Get(id int64) (*Task, bool) { return tm.tasks.Load(id) } -func (tm *TaskManager) Remove(id int64) { +func (tm *Manager) Remove(id int64) { tm.tasks.Delete(id) } -func (tm *TaskManager) RemoveFinished() { +func (tm *Manager) RemoveFinished() { tasks := tm.GetAll() for _, task := range tasks { if task.Status == FINISHED { @@ -45,7 +45,7 @@ func (tm *TaskManager) RemoveFinished() { } } -func (tm *TaskManager) RemoveError() { +func (tm *Manager) RemoveError() { tasks := tm.GetAll() for _, task := range tasks { if task.Error != nil { @@ -54,8 +54,8 @@ func (tm *TaskManager) RemoveError() { } } -func (tm *TaskManager) Add(name string, f Func) { - task := NewTask(name, f) +func (tm *Manager) Add(name string, f Func) { + task := newTask(name, f) tm.AddTask(task) go task.Run() } diff --git a/pkg/task/task.go b/pkg/task/task.go new file mode 100644 index 00000000..7084283f --- /dev/null +++ b/pkg/task/task.go @@ -0,0 +1,64 @@ +// Package task manage task, such as file upload, file copy between accounts, offline download, etc. +package task + +import ( + "context" + "github.com/pkg/errors" +) + +var ( + PENDING = "pending" + RUNNING = "running" + FINISHED = "finished" + CANCELING = "canceling" + CANCELED = "canceled" +) + +type Func func(task *Task) error + +type Task struct { + ID int64 + Name string + Status string + Error error + Func Func + Ctx context.Context + cancel context.CancelFunc +} + +func newTask(name string, func_ Func) *Task { + ctx, cancel := context.WithCancel(context.Background()) + return &Task{ + Name: name, + Status: PENDING, + Func: func_, + Ctx: ctx, + cancel: cancel, + } +} + +func (t *Task) SetStatus(status string) { + t.Status = status +} + +func (t *Task) Run() { + t.Status = RUNNING + t.Error = t.Func(t) + if errors.Is(t.Ctx.Err(), context.Canceled) { + t.Status = CANCELED + } else { + t.Status = FINISHED + } +} + +func (t *Task) Retry() { + t.Run() +} + +func (t *Task) Cancel() { + if t.cancel != nil { + t.cancel() + } + // maybe can't cancel + t.Status = CANCELING +} diff --git a/pkg/utils/ctx.go b/pkg/utils/ctx.go new file mode 100644 index 00000000..d2a67f04 --- /dev/null +++ b/pkg/utils/ctx.go @@ -0,0 +1,12 @@ +package utils + +import "context" + +func IsCanceled(ctx context.Context) bool { + select { + case <-ctx.Done(): + return true + default: + return false + } +}