From 55d6434daab818aead38f3773d8fd5ca96fb056a Mon Sep 17 00:00:00 2001 From: Noah Hsu Date: Tue, 21 Jun 2022 16:14:37 +0800 Subject: [PATCH] refactor(task): generic task manager --- internal/aria2/notify.go | 3 +- internal/fs/copy.go | 42 +++++++++++++----------- internal/fs/put.go | 14 +++++--- pkg/task/manager.go | 69 +++++++++++++++++++++------------------- pkg/task/task.go | 55 +++++++++++++++----------------- pkg/task/task_test.go | 58 +++++++++++++++++++++------------ 6 files changed, 135 insertions(+), 106 deletions(-) diff --git a/internal/aria2/notify.go b/internal/aria2/notify.go index 056fe514..a4241894 100644 --- a/internal/aria2/notify.go +++ b/internal/aria2/notify.go @@ -6,7 +6,8 @@ import ( ) const ( - Downloading = iota + Ready = iota + Downloading Paused Stopped Completed diff --git a/internal/fs/copy.go b/internal/fs/copy.go index b2bff07f..c5244133 100644 --- a/internal/fs/copy.go +++ b/internal/fs/copy.go @@ -4,6 +4,7 @@ import ( "context" "fmt" stdpath "path" + "sync/atomic" "github.com/alist-org/alist/v3/pkg/task" "github.com/alist-org/alist/v3/pkg/utils" @@ -14,7 +15,9 @@ import ( "github.com/pkg/errors" ) -var CopyTaskManager = task.NewTaskManager() +var CopyTaskManager = task.NewTaskManager[uint64, struct{}](3, func(tid *uint64) { + atomic.AddUint64(tid, 1) +}) // Copy if in an account, call move method // if not, add copy task @@ -32,15 +35,16 @@ func Copy(ctx context.Context, account driver.Driver, srcPath, dstPath string) ( return false, operations.Copy(ctx, account, srcActualPath, dstActualPath) } // not in an account - CopyTaskManager.Submit( - fmt.Sprintf("copy [%s](%s) to [%s](%s)", srcAccount.GetAccount().VirtualPath, srcActualPath, dstAccount.GetAccount().VirtualPath, dstActualPath), - func(task *task.Task) error { + CopyTaskManager.Submit(task.WithCancelCtx(&task.Task[uint64, struct{}]{ + Name: fmt.Sprintf("copy [%s](%s) to [%s](%s)", srcAccount.GetAccount().VirtualPath, srcActualPath, dstAccount.GetAccount().VirtualPath, dstActualPath), + Func: func(task *task.Task[uint64, struct{}]) error { return CopyBetween2Accounts(task, srcAccount, dstAccount, srcActualPath, dstActualPath) - }) + }, + })) return true, nil } -func CopyBetween2Accounts(t *task.Task, srcAccount, dstAccount driver.Driver, srcPath, dstPath string) error { +func CopyBetween2Accounts(t *task.Task[uint64, struct{}], srcAccount, dstAccount driver.Driver, srcPath, dstPath string) error { t.SetStatus("getting src object") srcObj, err := operations.Get(t.Ctx, srcAccount, srcPath) if err != nil { @@ -58,28 +62,30 @@ func CopyBetween2Accounts(t *task.Task, srcAccount, dstAccount driver.Driver, sr } srcObjPath := stdpath.Join(srcPath, obj.GetName()) dstObjPath := stdpath.Join(dstPath, obj.GetName()) - CopyTaskManager.Submit( - fmt.Sprintf("copy [%s](%s) to [%s](%s)", srcAccount.GetAccount().VirtualPath, srcObjPath, dstAccount.GetAccount().VirtualPath, dstObjPath), - func(t *task.Task) error { + CopyTaskManager.Submit(task.WithCancelCtx(&task.Task[uint64, struct{}]{ + Name: fmt.Sprintf("copy [%s](%s) to [%s](%s)", srcAccount.GetAccount().VirtualPath, srcObjPath, dstAccount.GetAccount().VirtualPath, dstObjPath), + Func: func(t *task.Task[uint64, struct{}]) error { return CopyBetween2Accounts(t, srcAccount, dstAccount, srcObjPath, dstObjPath) - }) + }, + })) } } else { - CopyTaskManager.Submit( - fmt.Sprintf("copy [%s](%s) to [%s](%s)", srcAccount.GetAccount().VirtualPath, srcPath, dstAccount.GetAccount().VirtualPath, dstPath), - func(t *task.Task) error { + CopyTaskManager.Submit(task.WithCancelCtx(&task.Task[uint64, struct{}]{ + Name: fmt.Sprintf("copy [%s](%s) to [%s](%s)", srcAccount.GetAccount().VirtualPath, srcPath, dstAccount.GetAccount().VirtualPath, dstPath), + Func: func(t *task.Task[uint64, struct{}]) error { return CopyFileBetween2Accounts(t, srcAccount, dstAccount, srcPath, dstPath) - }) + }, + })) } return nil } -func CopyFileBetween2Accounts(t *task.Task, srcAccount, dstAccount driver.Driver, srcPath, dstPath string) error { - srcFile, err := operations.Get(t.Ctx, srcAccount, srcPath) +func CopyFileBetween2Accounts(tsk *task.Task[uint64, struct{}], srcAccount, dstAccount driver.Driver, srcPath, dstPath string) error { + srcFile, err := operations.Get(tsk.Ctx, srcAccount, srcPath) if err != nil { return errors.WithMessagef(err, "failed get src [%s] file", srcPath) } - link, err := operations.Link(t.Ctx, srcAccount, srcPath, model.LinkArgs{}) + link, err := operations.Link(tsk.Ctx, srcAccount, srcPath, model.LinkArgs{}) if err != nil { return errors.WithMessagef(err, "failed get [%s] link", srcPath) } @@ -87,5 +93,5 @@ func CopyFileBetween2Accounts(t *task.Task, srcAccount, dstAccount driver.Driver if err != nil { return errors.WithMessagef(err, "failed get [%s] stream", srcPath) } - return operations.Put(t.Ctx, dstAccount, dstPath, stream, t.SetProgress) + return operations.Put(tsk.Ctx, dstAccount, dstPath, stream, tsk.SetProgress) } diff --git a/internal/fs/put.go b/internal/fs/put.go index 12123186..65d01726 100644 --- a/internal/fs/put.go +++ b/internal/fs/put.go @@ -8,9 +8,12 @@ import ( "github.com/alist-org/alist/v3/internal/operations" "github.com/alist-org/alist/v3/pkg/task" "github.com/pkg/errors" + "sync/atomic" ) -var UploadTaskManager = task.NewTaskManager() +var UploadTaskManager = task.NewTaskManager[uint64, struct{}](3, func(tid *uint64) { + atomic.AddUint64(tid, 1) +}) // Put add as a put task func Put(ctx context.Context, account driver.Driver, dstDir string, file model.FileStreamer) error { @@ -21,8 +24,11 @@ func Put(ctx context.Context, account driver.Driver, dstDir string, file model.F if err != nil { return errors.WithMessage(err, "failed get account") } - UploadTaskManager.Submit(fmt.Sprintf("upload %s to [%s](%s)", file.GetName(), account.GetAccount().VirtualPath, actualParentPath), func(task *task.Task) error { - return operations.Put(task.Ctx, account, actualParentPath, file, nil) - }) + UploadTaskManager.Submit(task.WithCancelCtx(&task.Task[uint64, struct{}]{ + Name: fmt.Sprintf("upload %s to [%s](%s)", file.GetName(), account.GetAccount().VirtualPath, actualParentPath), + Func: func(task *task.Task[uint64, struct{}]) error { + return operations.Put(task.Ctx, account, actualParentPath, file, nil) + }, + })) return nil } diff --git a/pkg/task/manager.go b/pkg/task/manager.go index 667fa932..27ab9fd3 100644 --- a/pkg/task/manager.go +++ b/pkg/task/manager.go @@ -1,27 +1,28 @@ package task import ( - log "github.com/sirupsen/logrus" - "sync/atomic" - "github.com/alist-org/alist/v3/pkg/generic_sync" + log "github.com/sirupsen/logrus" ) -type Manager struct { - workerC chan struct{} - curID uint64 - tasks generic_sync.MapOf[uint64, *Task] +type Manager[K comparable, V any] struct { + workerC chan struct{} + curID K + updateID func(*K) + tasks generic_sync.MapOf[K, *Task[K, V]] } -func (tm *Manager) Submit(name string, f Func, callbacks ...Callback) uint64 { - task := newTask(name, f, callbacks...) - tm.addTask(task) - tm.do(task.ID) +func (tm *Manager[K, V]) Submit(task *Task[K, V]) K { + if tm.updateID != nil { + task.ID = tm.curID + tm.updateID(&task.ID) + } + tm.tasks.Store(task.ID, task) + tm.do(task) return task.ID } -func (tm *Manager) do(tid uint64) { - task := tm.MustGet(tid) +func (tm *Manager[K, V]) do(task *Task[K, V]) { go func() { log.Debugf("task [%s] waiting for worker", task.Name) select { @@ -30,39 +31,34 @@ func (tm *Manager) do(tid uint64) { task.run() log.Debugf("task [%s] ended", task.Name) } + // return worker tm.workerC <- struct{}{} }() } -func (tm *Manager) addTask(task *Task) { - task.ID = tm.curID - atomic.AddUint64(&tm.curID, 1) - tm.tasks.Store(task.ID, task) -} - -func (tm *Manager) GetAll() []*Task { +func (tm *Manager[K, V]) GetAll() []*Task[K, V] { return tm.tasks.Values() } -func (tm *Manager) Get(tid uint64) (*Task, bool) { +func (tm *Manager[K, V]) Get(tid K) (*Task[K, V], bool) { return tm.tasks.Load(tid) } -func (tm *Manager) MustGet(tid uint64) *Task { +func (tm *Manager[K, V]) MustGet(tid K) *Task[K, V] { task, _ := tm.Get(tid) return task } -func (tm *Manager) Retry(tid uint64) error { +func (tm *Manager[K, V]) Retry(tid K) error { t, ok := tm.Get(tid) if !ok { return ErrTaskNotFound } - tm.do(t.ID) + tm.do(t) return nil } -func (tm *Manager) Cancel(tid uint64) error { +func (tm *Manager[K, V]) Cancel(tid K) error { t, ok := tm.Get(tid) if !ok { return ErrTaskNotFound @@ -71,17 +67,17 @@ func (tm *Manager) Cancel(tid uint64) error { return nil } -func (tm *Manager) Remove(tid uint64) { +func (tm *Manager[K, V]) Remove(tid K) { tm.tasks.Delete(tid) } // RemoveAll removes all tasks from the manager, this maybe shouldn't be used // because the task maybe still running. -func (tm *Manager) RemoveAll() { +func (tm *Manager[K, V]) RemoveAll() { tm.tasks.Clear() } -func (tm *Manager) RemoveFinished() { +func (tm *Manager[K, V]) RemoveFinished() { tasks := tm.GetAll() for _, task := range tasks { if task.Status == FINISHED { @@ -90,7 +86,7 @@ func (tm *Manager) RemoveFinished() { } } -func (tm *Manager) RemoveError() { +func (tm *Manager[K, V]) RemoveError() { tasks := tm.GetAll() for _, task := range tasks { if task.Error != nil { @@ -99,9 +95,16 @@ func (tm *Manager) RemoveError() { } } -func NewTaskManager() *Manager { - return &Manager{ - tasks: generic_sync.MapOf[uint64, *Task]{}, - curID: 0, +func NewTaskManager[K comparable, V any](maxWorker int, updateID ...func(*K)) *Manager[K, V] { + tm := &Manager[K, V]{ + tasks: generic_sync.MapOf[K, *Task[K, V]]{}, + workerC: make(chan struct{}, maxWorker), } + for i := 0; i < maxWorker; i++ { + tm.workerC <- struct{}{} + } + if len(updateID) > 0 { + tm.updateID = updateID[0] + } + return tm } diff --git a/pkg/task/task.go b/pkg/task/task.go index bbbc24fb..5c69b381 100644 --- a/pkg/task/task.go +++ b/pkg/task/task.go @@ -16,45 +16,34 @@ var ( ERRORED = "errored" ) -type Func func(task *Task) error -type Callback func(task *Task) +type Func[K comparable, V any] func(task *Task[K, V]) error +type Callback[K comparable, V any] func(task *Task[K, V]) + +type Task[K comparable, V any] struct { + ID K + Name string + Status string + Error error + + Data V + + Func Func[K, V] + callback Callback[K, V] -type Task struct { - ID uint64 - Name string - Status string - Error error - Func Func Ctx context.Context progress int - callback Callback cancel context.CancelFunc } -func newTask(name string, func_ Func, callbacks ...Callback) *Task { - ctx, cancel := context.WithCancel(context.Background()) - t := &Task{ - Name: name, - Status: PENDING, - Func: func_, - Ctx: ctx, - cancel: cancel, - } - if len(callbacks) > 0 { - t.callback = callbacks[0] - } - return t -} - -func (t *Task) SetStatus(status string) { +func (t *Task[K, V]) SetStatus(status string) { t.Status = status } -func (t *Task) SetProgress(percentage int) { +func (t *Task[K, V]) SetProgress(percentage int) { t.progress = percentage } -func (t *Task) run() { +func (t *Task[K, V]) run() { t.Status = RUNNING defer func() { if err := recover(); err != nil { @@ -76,11 +65,11 @@ func (t *Task) run() { } } -func (t *Task) retry() { +func (t *Task[K, V]) retry() { t.run() } -func (t *Task) Cancel() { +func (t *Task[K, V]) Cancel() { if t.Status == FINISHED || t.Status == CANCELED { return } @@ -90,3 +79,11 @@ func (t *Task) Cancel() { // maybe can't cancel t.Status = CANCELING } + +func WithCancelCtx[K comparable, V any](task *Task[K, V]) *Task[K, V] { + ctx, cancel := context.WithCancel(context.Background()) + task.Ctx = ctx + task.cancel = cancel + task.Status = PENDING + return task +} diff --git a/pkg/task/task_test.go b/pkg/task/task_test.go index 84e1ca0b..1719e412 100644 --- a/pkg/task/task_test.go +++ b/pkg/task/task_test.go @@ -3,16 +3,22 @@ package task import ( "github.com/alist-org/alist/v3/pkg/utils" "github.com/pkg/errors" + "sync/atomic" "testing" "time" ) func TestTask_Manager(t *testing.T) { - tm := NewTaskManager() - id := tm.Submit("test", func(task *Task) error { - time.Sleep(time.Millisecond * 500) - return nil + tm := NewTaskManager[uint64, struct{}](3, func(id *uint64) { + atomic.AddUint64(id, 1) }) + id := tm.Submit(WithCancelCtx(&Task[uint64, struct{}]{ + Name: "test", + Func: func(task *Task[uint64, struct{}]) error { + time.Sleep(time.Millisecond * 500) + return nil + }, + })) task, ok := tm.Get(id) if !ok { t.Fatal("task not found") @@ -28,16 +34,21 @@ func TestTask_Manager(t *testing.T) { } func TestTask_Cancel(t *testing.T) { - tm := NewTaskManager() - id := tm.Submit("test", func(task *Task) error { - for { - if utils.IsCanceled(task.Ctx) { - return nil - } else { - t.Logf("task is running") - } - } + tm := NewTaskManager[uint64, struct{}](3, func(id *uint64) { + atomic.AddUint64(id, 1) }) + id := tm.Submit(WithCancelCtx(&Task[uint64, struct{}]{ + Name: "test", + Func: func(task *Task[uint64, struct{}]) error { + for { + if utils.IsCanceled(task.Ctx) { + return nil + } else { + t.Logf("task is running") + } + } + }, + })) task, ok := tm.Get(id) if !ok { t.Fatal("task not found") @@ -51,15 +62,20 @@ func TestTask_Cancel(t *testing.T) { } func TestTask_Retry(t *testing.T) { - tm := NewTaskManager() - num := 0 - id := tm.Submit("test", func(task *Task) error { - num++ - if num&1 == 1 { - return errors.New("test error") - } - return nil + tm := NewTaskManager[uint64, struct{}](3, func(id *uint64) { + atomic.AddUint64(id, 1) }) + num := 0 + id := tm.Submit(WithCancelCtx(&Task[uint64, struct{}]{ + Name: "test", + Func: func(task *Task[uint64, struct{}]) error { + num++ + if num&1 == 1 { + return errors.New("test error") + } + return nil + }, + })) task, ok := tm.Get(id) if !ok { t.Fatal("task not found")