diff --git a/internal/aria2/add.go b/internal/aria2/add.go index 5db998f0..6f16674d 100644 --- a/internal/aria2/add.go +++ b/internal/aria2/add.go @@ -45,10 +45,10 @@ func AddURI(ctx context.Context, uri string, dstDirPath string) error { return errors.Wrapf(err, "failed to add uri %s", uri) } // TODO add to task manager - TaskManager.Submit(task.WithCancelCtx(&task.Task[string, interface{}]{ + TaskManager.Submit(task.WithCancelCtx(&task.Task[string]{ ID: gid, Name: fmt.Sprintf("download %s to [%s](%s)", uri, account.GetAccount().VirtualPath, dstDirActualPath), - Func: func(tsk *task.Task[string, interface{}]) error { + Func: func(tsk *task.Task[string]) error { m := &Monitor{ tsk: tsk, tempDir: tempDir, diff --git a/internal/aria2/aria2.go b/internal/aria2/aria2.go index 1aeaecf6..7a271431 100644 --- a/internal/aria2/aria2.go +++ b/internal/aria2/aria2.go @@ -8,7 +8,7 @@ import ( "time" ) -var TaskManager = task.NewTaskManager[string, interface{}](3) +var TaskManager = task.NewTaskManager[string](3) var notify = NewNotify() var client rpc.Client diff --git a/internal/aria2/monitor.go b/internal/aria2/monitor.go index 9e8ca9d3..fb55058d 100644 --- a/internal/aria2/monitor.go +++ b/internal/aria2/monitor.go @@ -17,7 +17,7 @@ import ( ) type Monitor struct { - tsk *task.Task[string, interface{}] + tsk *task.Task[string] tempDir string retried int c chan int @@ -92,7 +92,7 @@ func (m *Monitor) Update() (bool, error) { } } -var transferTaskManager = task.NewTaskManager[uint64, interface{}](3, func(k *uint64) { +var transferTaskManager = task.NewTaskManager[uint64](3, func(k *uint64) { atomic.AddUint64(k, 1) }) @@ -118,9 +118,9 @@ func (m *Monitor) Complete() error { } }() for _, file := range files { - transferTaskManager.Submit(task.WithCancelCtx[uint64](&task.Task[uint64, interface{}]{ + transferTaskManager.Submit(task.WithCancelCtx[uint64](&task.Task[uint64]{ Name: fmt.Sprintf("transfer %s to %s", file.Path, m.dstDirPath), - Func: func(tsk *task.Task[uint64, interface{}]) error { + Func: func(tsk *task.Task[uint64]) error { defer wg.Done() size, _ := strconv.ParseUint(file.Length, 10, 64) mimetype := mime.TypeByExtension(path.Ext(file.Path)) diff --git a/internal/fs/copy.go b/internal/fs/copy.go index 0f9f0966..413596e8 100644 --- a/internal/fs/copy.go +++ b/internal/fs/copy.go @@ -15,7 +15,7 @@ import ( "github.com/pkg/errors" ) -var CopyTaskManager = task.NewTaskManager[uint64, struct{}](3, func(tid *uint64) { +var CopyTaskManager = task.NewTaskManager[uint64](3, func(tid *uint64) { atomic.AddUint64(tid, 1) }) @@ -35,16 +35,16 @@ func Copy(ctx context.Context, account driver.Driver, srcObjPath, dstDirPath str return false, operations.Copy(ctx, account, srcObjActualPath, dstDirActualPath) } // not in an account - CopyTaskManager.Submit(task.WithCancelCtx(&task.Task[uint64, struct{}]{ + CopyTaskManager.Submit(task.WithCancelCtx(&task.Task[uint64]{ Name: fmt.Sprintf("copy [%s](%s) to [%s](%s)", srcAccount.GetAccount().VirtualPath, srcObjActualPath, dstAccount.GetAccount().VirtualPath, dstDirActualPath), - Func: func(task *task.Task[uint64, struct{}]) error { + Func: func(task *task.Task[uint64]) error { return CopyBetween2Accounts(task, srcAccount, dstAccount, srcObjActualPath, dstDirActualPath) }, })) return true, nil } -func CopyBetween2Accounts(t *task.Task[uint64, struct{}], srcAccount, dstAccount driver.Driver, srcObjPath, dstDirPath string) error { +func CopyBetween2Accounts(t *task.Task[uint64], srcAccount, dstAccount driver.Driver, srcObjPath, dstDirPath string) error { t.SetStatus("getting src object") srcObj, err := operations.Get(t.Ctx, srcAccount, srcObjPath) if err != nil { @@ -62,17 +62,17 @@ func CopyBetween2Accounts(t *task.Task[uint64, struct{}], srcAccount, dstAccount } srcObjPath := stdpath.Join(srcObjPath, obj.GetName()) dstObjPath := stdpath.Join(dstDirPath, obj.GetName()) - CopyTaskManager.Submit(task.WithCancelCtx(&task.Task[uint64, struct{}]{ + CopyTaskManager.Submit(task.WithCancelCtx(&task.Task[uint64]{ Name: fmt.Sprintf("copy [%s](%s) to [%s](%s)", srcAccount.GetAccount().VirtualPath, srcObjPath, dstAccount.GetAccount().VirtualPath, dstObjPath), - Func: func(t *task.Task[uint64, struct{}]) error { + Func: func(t *task.Task[uint64]) error { return CopyBetween2Accounts(t, srcAccount, dstAccount, srcObjPath, dstObjPath) }, })) } } else { - CopyTaskManager.Submit(task.WithCancelCtx(&task.Task[uint64, struct{}]{ + CopyTaskManager.Submit(task.WithCancelCtx(&task.Task[uint64]{ Name: fmt.Sprintf("copy [%s](%s) to [%s](%s)", srcAccount.GetAccount().VirtualPath, srcObjPath, dstAccount.GetAccount().VirtualPath, dstDirPath), - Func: func(t *task.Task[uint64, struct{}]) error { + Func: func(t *task.Task[uint64]) error { return CopyFileBetween2Accounts(t, srcAccount, dstAccount, srcObjPath, dstDirPath) }, })) @@ -80,7 +80,7 @@ func CopyBetween2Accounts(t *task.Task[uint64, struct{}], srcAccount, dstAccount return nil } -func CopyFileBetween2Accounts(tsk *task.Task[uint64, struct{}], srcAccount, dstAccount driver.Driver, srcFilePath, dstDirPath string) error { +func CopyFileBetween2Accounts(tsk *task.Task[uint64], srcAccount, dstAccount driver.Driver, srcFilePath, dstDirPath string) error { srcFile, err := operations.Get(tsk.Ctx, srcAccount, srcFilePath) if err != nil { return errors.WithMessagef(err, "failed get src [%s] file", srcFilePath) diff --git a/internal/fs/put.go b/internal/fs/put.go index 0581723e..38e6634b 100644 --- a/internal/fs/put.go +++ b/internal/fs/put.go @@ -11,7 +11,7 @@ import ( "sync/atomic" ) -var UploadTaskManager = task.NewTaskManager[uint64, struct{}](3, func(tid *uint64) { +var UploadTaskManager = task.NewTaskManager[uint64](3, func(tid *uint64) { atomic.AddUint64(tid, 1) }) @@ -24,9 +24,9 @@ func Put(ctx context.Context, account driver.Driver, dstDirPath string, file mod if err != nil { return errors.WithMessage(err, "failed get account") } - UploadTaskManager.Submit(task.WithCancelCtx(&task.Task[uint64, struct{}]{ + UploadTaskManager.Submit(task.WithCancelCtx(&task.Task[uint64]{ Name: fmt.Sprintf("upload %s to [%s](%s)", file.GetName(), account.GetAccount().VirtualPath, dstDirActualPath), - Func: func(task *task.Task[uint64, struct{}]) error { + Func: func(task *task.Task[uint64]) error { return operations.Put(task.Ctx, account, dstDirActualPath, file, nil) }, })) diff --git a/pkg/task/manager.go b/pkg/task/manager.go index 27ab9fd3..96aa259c 100644 --- a/pkg/task/manager.go +++ b/pkg/task/manager.go @@ -5,14 +5,14 @@ import ( log "github.com/sirupsen/logrus" ) -type Manager[K comparable, V any] struct { +type Manager[K comparable] struct { workerC chan struct{} curID K updateID func(*K) - tasks generic_sync.MapOf[K, *Task[K, V]] + tasks generic_sync.MapOf[K, *Task[K]] } -func (tm *Manager[K, V]) Submit(task *Task[K, V]) K { +func (tm *Manager[K]) Submit(task *Task[K]) K { if tm.updateID != nil { task.ID = tm.curID tm.updateID(&task.ID) @@ -22,7 +22,7 @@ func (tm *Manager[K, V]) Submit(task *Task[K, V]) K { return task.ID } -func (tm *Manager[K, V]) do(task *Task[K, V]) { +func (tm *Manager[K]) do(task *Task[K]) { go func() { log.Debugf("task [%s] waiting for worker", task.Name) select { @@ -36,20 +36,20 @@ func (tm *Manager[K, V]) do(task *Task[K, V]) { }() } -func (tm *Manager[K, V]) GetAll() []*Task[K, V] { +func (tm *Manager[K]) GetAll() []*Task[K] { return tm.tasks.Values() } -func (tm *Manager[K, V]) Get(tid K) (*Task[K, V], bool) { +func (tm *Manager[K]) Get(tid K) (*Task[K], bool) { return tm.tasks.Load(tid) } -func (tm *Manager[K, V]) MustGet(tid K) *Task[K, V] { +func (tm *Manager[K]) MustGet(tid K) *Task[K] { task, _ := tm.Get(tid) return task } -func (tm *Manager[K, V]) Retry(tid K) error { +func (tm *Manager[K]) Retry(tid K) error { t, ok := tm.Get(tid) if !ok { return ErrTaskNotFound @@ -58,7 +58,7 @@ func (tm *Manager[K, V]) Retry(tid K) error { return nil } -func (tm *Manager[K, V]) Cancel(tid K) error { +func (tm *Manager[K]) Cancel(tid K) error { t, ok := tm.Get(tid) if !ok { return ErrTaskNotFound @@ -67,17 +67,17 @@ func (tm *Manager[K, V]) Cancel(tid K) error { return nil } -func (tm *Manager[K, V]) Remove(tid K) { +func (tm *Manager[K]) Remove(tid K) { tm.tasks.Delete(tid) } // RemoveAll removes all tasks from the manager, this maybe shouldn't be used // because the task maybe still running. -func (tm *Manager[K, V]) RemoveAll() { +func (tm *Manager[K]) RemoveAll() { tm.tasks.Clear() } -func (tm *Manager[K, V]) RemoveFinished() { +func (tm *Manager[K]) RemoveFinished() { tasks := tm.GetAll() for _, task := range tasks { if task.Status == FINISHED { @@ -86,7 +86,7 @@ func (tm *Manager[K, V]) RemoveFinished() { } } -func (tm *Manager[K, V]) RemoveError() { +func (tm *Manager[K]) RemoveError() { tasks := tm.GetAll() for _, task := range tasks { if task.Error != nil { @@ -95,9 +95,9 @@ func (tm *Manager[K, V]) RemoveError() { } } -func NewTaskManager[K comparable, V any](maxWorker int, updateID ...func(*K)) *Manager[K, V] { - tm := &Manager[K, V]{ - tasks: generic_sync.MapOf[K, *Task[K, V]]{}, +func NewTaskManager[K comparable](maxWorker int, updateID ...func(*K)) *Manager[K] { + tm := &Manager[K]{ + tasks: generic_sync.MapOf[K, *Task[K]]{}, workerC: make(chan struct{}, maxWorker), } for i := 0; i < maxWorker; i++ { diff --git a/pkg/task/task.go b/pkg/task/task.go index 1a532c19..9a37a7af 100644 --- a/pkg/task/task.go +++ b/pkg/task/task.go @@ -16,34 +16,32 @@ var ( ERRORED = "errored" ) -type Func[K comparable, V any] func(task *Task[K, V]) error -type Callback[K comparable, V any] func(task *Task[K, V]) +type Func[K comparable] func(task *Task[K]) error +type Callback[K comparable] func(task *Task[K]) -type Task[K comparable, V any] struct { +type Task[K comparable] struct { ID K Name string Status string Error error - Data V - - Func Func[K, V] - callback Callback[K, V] + Func Func[K] + callback Callback[K] Ctx context.Context progress int cancel context.CancelFunc } -func (t *Task[K, V]) SetStatus(status string) { +func (t *Task[K]) SetStatus(status string) { t.Status = status } -func (t *Task[K, V]) SetProgress(percentage int) { +func (t *Task[K]) SetProgress(percentage int) { t.progress = percentage } -func (t *Task[K, V]) run() { +func (t *Task[K]) run() { t.Status = RUNNING defer func() { if err := recover(); err != nil { @@ -68,11 +66,11 @@ func (t *Task[K, V]) run() { } } -func (t *Task[K, V]) retry() { +func (t *Task[K]) retry() { t.run() } -func (t *Task[K, V]) Cancel() { +func (t *Task[K]) Cancel() { if t.Status == FINISHED || t.Status == CANCELED { return } @@ -83,7 +81,7 @@ func (t *Task[K, V]) Cancel() { t.Status = CANCELING } -func WithCancelCtx[K comparable, V any](task *Task[K, V]) *Task[K, V] { +func WithCancelCtx[K comparable](task *Task[K]) *Task[K] { ctx, cancel := context.WithCancel(context.Background()) task.Ctx = ctx task.cancel = cancel diff --git a/pkg/task/task_test.go b/pkg/task/task_test.go index 1719e412..b72c0868 100644 --- a/pkg/task/task_test.go +++ b/pkg/task/task_test.go @@ -9,12 +9,12 @@ import ( ) func TestTask_Manager(t *testing.T) { - tm := NewTaskManager[uint64, struct{}](3, func(id *uint64) { + tm := NewTaskManager[uint64](3, func(id *uint64) { atomic.AddUint64(id, 1) }) - id := tm.Submit(WithCancelCtx(&Task[uint64, struct{}]{ + id := tm.Submit(WithCancelCtx(&Task[uint64]{ Name: "test", - Func: func(task *Task[uint64, struct{}]) error { + Func: func(task *Task[uint64]) error { time.Sleep(time.Millisecond * 500) return nil }, @@ -34,12 +34,12 @@ func TestTask_Manager(t *testing.T) { } func TestTask_Cancel(t *testing.T) { - tm := NewTaskManager[uint64, struct{}](3, func(id *uint64) { + tm := NewTaskManager[uint64](3, func(id *uint64) { atomic.AddUint64(id, 1) }) - id := tm.Submit(WithCancelCtx(&Task[uint64, struct{}]{ + id := tm.Submit(WithCancelCtx(&Task[uint64]{ Name: "test", - Func: func(task *Task[uint64, struct{}]) error { + Func: func(task *Task[uint64]) error { for { if utils.IsCanceled(task.Ctx) { return nil @@ -62,13 +62,13 @@ func TestTask_Cancel(t *testing.T) { } func TestTask_Retry(t *testing.T) { - tm := NewTaskManager[uint64, struct{}](3, func(id *uint64) { + tm := NewTaskManager[uint64](3, func(id *uint64) { atomic.AddUint64(id, 1) }) num := 0 - id := tm.Submit(WithCancelCtx(&Task[uint64, struct{}]{ + id := tm.Submit(WithCancelCtx(&Task[uint64]{ Name: "test", - Func: func(task *Task[uint64, struct{}]) error { + Func: func(task *Task[uint64]) error { num++ if num&1 == 1 { return errors.New("test error")