diff --git a/internal/aria2/aria2_test.go b/internal/aria2/aria2_test.go index 610381e2..75bdc0dc 100644 --- a/internal/aria2/aria2_test.go +++ b/internal/aria2/aria2_test.go @@ -61,10 +61,10 @@ func TestDown(t *testing.T) { for { tsk := tasks[0] t.Logf("task: %+v", tsk) - if tsk.Status == task.FINISHED { + if tsk.GetState() == task.FINISHED { break } - if tsk.Status == task.ERRORED { + if tsk.GetState() == task.ERRORED { t.Fatalf("failed to download: %+v", tsk) } time.Sleep(time.Second) @@ -75,10 +75,10 @@ func TestDown(t *testing.T) { } tsk := transferTaskManager.GetAll()[0] t.Logf("task: %+v", tsk) - if tsk.Status == task.FINISHED { + if tsk.GetState() == task.FINISHED { break } - if tsk.Status == task.ERRORED { + if tsk.GetState() == task.ERRORED { t.Fatalf("failed to download: %+v", tsk) } time.Sleep(time.Second) diff --git a/pkg/task/manager.go b/pkg/task/manager.go index 96aa259c..20cc6087 100644 --- a/pkg/task/manager.go +++ b/pkg/task/manager.go @@ -2,6 +2,7 @@ package task import ( "github.com/alist-org/alist/v3/pkg/generic_sync" + "github.com/pkg/errors" log "github.com/sirupsen/logrus" ) @@ -52,7 +53,7 @@ func (tm *Manager[K]) MustGet(tid K) *Task[K] { func (tm *Manager[K]) Retry(tid K) error { t, ok := tm.Get(tid) if !ok { - return ErrTaskNotFound + return errors.WithStack(ErrTaskNotFound) } tm.do(t) return nil @@ -61,7 +62,7 @@ func (tm *Manager[K]) Retry(tid K) error { func (tm *Manager[K]) Cancel(tid K) error { t, ok := tm.Get(tid) if !ok { - return ErrTaskNotFound + return errors.WithStack(ErrTaskNotFound) } t.Cancel() return nil @@ -80,7 +81,7 @@ func (tm *Manager[K]) RemoveAll() { func (tm *Manager[K]) RemoveFinished() { tasks := tm.GetAll() for _, task := range tasks { - if task.Status == FINISHED { + if task.state == FINISHED { tm.Remove(task.ID) } } diff --git a/pkg/task/task.go b/pkg/task/task.go index 9a37a7af..a5d387b7 100644 --- a/pkg/task/task.go +++ b/pkg/task/task.go @@ -20,34 +20,40 @@ type Func[K comparable] func(task *Task[K]) error type Callback[K comparable] func(task *Task[K]) type Task[K comparable] struct { - ID K - Name string - Status string - Error error + ID K + Name string + state string // pending, running, finished, canceling, canceled, errored + status string + progress int + + Error error Func Func[K] callback Callback[K] - Ctx context.Context - progress int - cancel context.CancelFunc + Ctx context.Context + cancel context.CancelFunc } func (t *Task[K]) SetStatus(status string) { - t.Status = status + t.status = status } func (t *Task[K]) SetProgress(percentage int) { t.progress = percentage } +func (t *Task[K]) GetState() string { + return t.state +} + func (t *Task[K]) run() { - t.Status = RUNNING + t.state = RUNNING defer func() { if err := recover(); err != nil { log.Errorf("error [%+v] while run task [%s]", err, t.Name) t.Error = errors.Errorf("panic: %+v", err) - t.Status = ERRORED + t.state = ERRORED } }() t.Error = t.Func(t) @@ -55,11 +61,11 @@ func (t *Task[K]) run() { log.Errorf("error [%+v] while run task [%s]", t.Error, t.Name) } if errors.Is(t.Ctx.Err(), context.Canceled) { - t.Status = CANCELED + t.state = CANCELED } else if t.Error != nil { - t.Status = ERRORED + t.state = ERRORED } else { - t.Status = FINISHED + t.state = FINISHED if t.callback != nil { t.callback(t) } @@ -71,20 +77,20 @@ func (t *Task[K]) retry() { } func (t *Task[K]) Cancel() { - if t.Status == FINISHED || t.Status == CANCELED { + if t.state == FINISHED || t.state == CANCELED { return } if t.cancel != nil { t.cancel() } // maybe can't cancel - t.Status = CANCELING + t.state = CANCELING } func WithCancelCtx[K comparable](task *Task[K]) *Task[K] { ctx, cancel := context.WithCancel(context.Background()) task.Ctx = ctx task.cancel = cancel - task.Status = PENDING + task.state = PENDING return task } diff --git a/pkg/task/task_test.go b/pkg/task/task_test.go index b72c0868..a03841e0 100644 --- a/pkg/task/task_test.go +++ b/pkg/task/task_test.go @@ -24,12 +24,12 @@ func TestTask_Manager(t *testing.T) { t.Fatal("task not found") } time.Sleep(time.Millisecond * 100) - if task.Status != RUNNING { - t.Errorf("task status not running: %s", task.Status) + if task.state != RUNNING { + t.Errorf("task status not running: %s", task.state) } time.Sleep(time.Second) - if task.Status != FINISHED { - t.Errorf("task status not finished: %s", task.Status) + if task.state != FINISHED { + t.Errorf("task status not finished: %s", task.state) } } @@ -56,8 +56,8 @@ func TestTask_Cancel(t *testing.T) { time.Sleep(time.Microsecond * 50) task.Cancel() time.Sleep(time.Millisecond) - if task.Status != CANCELED { - t.Errorf("task status not canceled: %s", task.Status) + if task.state != CANCELED { + t.Errorf("task status not canceled: %s", task.state) } } @@ -82,7 +82,7 @@ func TestTask_Retry(t *testing.T) { } time.Sleep(time.Millisecond) if task.Error == nil { - t.Error(task.Status) + t.Error(task.state) t.Fatal("task error is nil, but expected error") } else { t.Logf("task error: %s", task.Error)