mirror of https://github.com/Xhofe/alist
				
				
				
			refactor(task): generic task manager
							parent
							
								
									1b3387ca1a
								
							
						
					
					
						commit
						55d6434daa
					
				|  | @ -6,7 +6,8 @@ import ( | |||
| ) | ||||
| 
 | ||||
| const ( | ||||
| 	Downloading = iota | ||||
| 	Ready = iota | ||||
| 	Downloading | ||||
| 	Paused | ||||
| 	Stopped | ||||
| 	Completed | ||||
|  |  | |||
|  | @ -4,6 +4,7 @@ import ( | |||
| 	"context" | ||||
| 	"fmt" | ||||
| 	stdpath "path" | ||||
| 	"sync/atomic" | ||||
| 
 | ||||
| 	"github.com/alist-org/alist/v3/pkg/task" | ||||
| 	"github.com/alist-org/alist/v3/pkg/utils" | ||||
|  | @ -14,7 +15,9 @@ import ( | |||
| 	"github.com/pkg/errors" | ||||
| ) | ||||
| 
 | ||||
| var CopyTaskManager = task.NewTaskManager() | ||||
| var CopyTaskManager = task.NewTaskManager[uint64, struct{}](3, func(tid *uint64) { | ||||
| 	atomic.AddUint64(tid, 1) | ||||
| }) | ||||
| 
 | ||||
| // Copy if in an account, call move method
 | ||||
| // if not, add copy task
 | ||||
|  | @ -32,15 +35,16 @@ func Copy(ctx context.Context, account driver.Driver, srcPath, dstPath string) ( | |||
| 		return false, operations.Copy(ctx, account, srcActualPath, dstActualPath) | ||||
| 	} | ||||
| 	// not in an account
 | ||||
| 	CopyTaskManager.Submit( | ||||
| 		fmt.Sprintf("copy [%s](%s) to [%s](%s)", srcAccount.GetAccount().VirtualPath, srcActualPath, dstAccount.GetAccount().VirtualPath, dstActualPath), | ||||
| 		func(task *task.Task) error { | ||||
| 	CopyTaskManager.Submit(task.WithCancelCtx(&task.Task[uint64, struct{}]{ | ||||
| 		Name: fmt.Sprintf("copy [%s](%s) to [%s](%s)", srcAccount.GetAccount().VirtualPath, srcActualPath, dstAccount.GetAccount().VirtualPath, dstActualPath), | ||||
| 		Func: func(task *task.Task[uint64, struct{}]) error { | ||||
| 			return CopyBetween2Accounts(task, srcAccount, dstAccount, srcActualPath, dstActualPath) | ||||
| 		}) | ||||
| 		}, | ||||
| 	})) | ||||
| 	return true, nil | ||||
| } | ||||
| 
 | ||||
| func CopyBetween2Accounts(t *task.Task, srcAccount, dstAccount driver.Driver, srcPath, dstPath string) error { | ||||
| func CopyBetween2Accounts(t *task.Task[uint64, struct{}], srcAccount, dstAccount driver.Driver, srcPath, dstPath string) error { | ||||
| 	t.SetStatus("getting src object") | ||||
| 	srcObj, err := operations.Get(t.Ctx, srcAccount, srcPath) | ||||
| 	if err != nil { | ||||
|  | @ -58,28 +62,30 @@ func CopyBetween2Accounts(t *task.Task, srcAccount, dstAccount driver.Driver, sr | |||
| 			} | ||||
| 			srcObjPath := stdpath.Join(srcPath, obj.GetName()) | ||||
| 			dstObjPath := stdpath.Join(dstPath, obj.GetName()) | ||||
| 			CopyTaskManager.Submit( | ||||
| 				fmt.Sprintf("copy [%s](%s) to [%s](%s)", srcAccount.GetAccount().VirtualPath, srcObjPath, dstAccount.GetAccount().VirtualPath, dstObjPath), | ||||
| 				func(t *task.Task) error { | ||||
| 			CopyTaskManager.Submit(task.WithCancelCtx(&task.Task[uint64, struct{}]{ | ||||
| 				Name: fmt.Sprintf("copy [%s](%s) to [%s](%s)", srcAccount.GetAccount().VirtualPath, srcObjPath, dstAccount.GetAccount().VirtualPath, dstObjPath), | ||||
| 				Func: func(t *task.Task[uint64, struct{}]) error { | ||||
| 					return CopyBetween2Accounts(t, srcAccount, dstAccount, srcObjPath, dstObjPath) | ||||
| 				}) | ||||
| 				}, | ||||
| 			})) | ||||
| 		} | ||||
| 	} else { | ||||
| 		CopyTaskManager.Submit( | ||||
| 			fmt.Sprintf("copy [%s](%s) to [%s](%s)", srcAccount.GetAccount().VirtualPath, srcPath, dstAccount.GetAccount().VirtualPath, dstPath), | ||||
| 			func(t *task.Task) error { | ||||
| 		CopyTaskManager.Submit(task.WithCancelCtx(&task.Task[uint64, struct{}]{ | ||||
| 			Name: fmt.Sprintf("copy [%s](%s) to [%s](%s)", srcAccount.GetAccount().VirtualPath, srcPath, dstAccount.GetAccount().VirtualPath, dstPath), | ||||
| 			Func: func(t *task.Task[uint64, struct{}]) error { | ||||
| 				return CopyFileBetween2Accounts(t, srcAccount, dstAccount, srcPath, dstPath) | ||||
| 			}) | ||||
| 			}, | ||||
| 		})) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func CopyFileBetween2Accounts(t *task.Task, srcAccount, dstAccount driver.Driver, srcPath, dstPath string) error { | ||||
| 	srcFile, err := operations.Get(t.Ctx, srcAccount, srcPath) | ||||
| func CopyFileBetween2Accounts(tsk *task.Task[uint64, struct{}], srcAccount, dstAccount driver.Driver, srcPath, dstPath string) error { | ||||
| 	srcFile, err := operations.Get(tsk.Ctx, srcAccount, srcPath) | ||||
| 	if err != nil { | ||||
| 		return errors.WithMessagef(err, "failed get src [%s] file", srcPath) | ||||
| 	} | ||||
| 	link, err := operations.Link(t.Ctx, srcAccount, srcPath, model.LinkArgs{}) | ||||
| 	link, err := operations.Link(tsk.Ctx, srcAccount, srcPath, model.LinkArgs{}) | ||||
| 	if err != nil { | ||||
| 		return errors.WithMessagef(err, "failed get [%s] link", srcPath) | ||||
| 	} | ||||
|  | @ -87,5 +93,5 @@ func CopyFileBetween2Accounts(t *task.Task, srcAccount, dstAccount driver.Driver | |||
| 	if err != nil { | ||||
| 		return errors.WithMessagef(err, "failed get [%s] stream", srcPath) | ||||
| 	} | ||||
| 	return operations.Put(t.Ctx, dstAccount, dstPath, stream, t.SetProgress) | ||||
| 	return operations.Put(tsk.Ctx, dstAccount, dstPath, stream, tsk.SetProgress) | ||||
| } | ||||
|  |  | |||
|  | @ -8,9 +8,12 @@ import ( | |||
| 	"github.com/alist-org/alist/v3/internal/operations" | ||||
| 	"github.com/alist-org/alist/v3/pkg/task" | ||||
| 	"github.com/pkg/errors" | ||||
| 	"sync/atomic" | ||||
| ) | ||||
| 
 | ||||
| var UploadTaskManager = task.NewTaskManager() | ||||
| var UploadTaskManager = task.NewTaskManager[uint64, struct{}](3, func(tid *uint64) { | ||||
| 	atomic.AddUint64(tid, 1) | ||||
| }) | ||||
| 
 | ||||
| // Put add as a put task
 | ||||
| func Put(ctx context.Context, account driver.Driver, dstDir string, file model.FileStreamer) error { | ||||
|  | @ -21,8 +24,11 @@ func Put(ctx context.Context, account driver.Driver, dstDir string, file model.F | |||
| 	if err != nil { | ||||
| 		return errors.WithMessage(err, "failed get account") | ||||
| 	} | ||||
| 	UploadTaskManager.Submit(fmt.Sprintf("upload %s to [%s](%s)", file.GetName(), account.GetAccount().VirtualPath, actualParentPath), func(task *task.Task) error { | ||||
| 		return operations.Put(task.Ctx, account, actualParentPath, file, nil) | ||||
| 	}) | ||||
| 	UploadTaskManager.Submit(task.WithCancelCtx(&task.Task[uint64, struct{}]{ | ||||
| 		Name: fmt.Sprintf("upload %s to [%s](%s)", file.GetName(), account.GetAccount().VirtualPath, actualParentPath), | ||||
| 		Func: func(task *task.Task[uint64, struct{}]) error { | ||||
| 			return operations.Put(task.Ctx, account, actualParentPath, file, nil) | ||||
| 		}, | ||||
| 	})) | ||||
| 	return nil | ||||
| } | ||||
|  |  | |||
|  | @ -1,27 +1,28 @@ | |||
| package task | ||||
| 
 | ||||
| import ( | ||||
| 	log "github.com/sirupsen/logrus" | ||||
| 	"sync/atomic" | ||||
| 
 | ||||
| 	"github.com/alist-org/alist/v3/pkg/generic_sync" | ||||
| 	log "github.com/sirupsen/logrus" | ||||
| ) | ||||
| 
 | ||||
| type Manager struct { | ||||
| 	workerC chan struct{} | ||||
| 	curID   uint64 | ||||
| 	tasks   generic_sync.MapOf[uint64, *Task] | ||||
| type Manager[K comparable, V any] struct { | ||||
| 	workerC  chan struct{} | ||||
| 	curID    K | ||||
| 	updateID func(*K) | ||||
| 	tasks    generic_sync.MapOf[K, *Task[K, V]] | ||||
| } | ||||
| 
 | ||||
| func (tm *Manager) Submit(name string, f Func, callbacks ...Callback) uint64 { | ||||
| 	task := newTask(name, f, callbacks...) | ||||
| 	tm.addTask(task) | ||||
| 	tm.do(task.ID) | ||||
| func (tm *Manager[K, V]) Submit(task *Task[K, V]) K { | ||||
| 	if tm.updateID != nil { | ||||
| 		task.ID = tm.curID | ||||
| 		tm.updateID(&task.ID) | ||||
| 	} | ||||
| 	tm.tasks.Store(task.ID, task) | ||||
| 	tm.do(task) | ||||
| 	return task.ID | ||||
| } | ||||
| 
 | ||||
| func (tm *Manager) do(tid uint64) { | ||||
| 	task := tm.MustGet(tid) | ||||
| func (tm *Manager[K, V]) do(task *Task[K, V]) { | ||||
| 	go func() { | ||||
| 		log.Debugf("task [%s] waiting for worker", task.Name) | ||||
| 		select { | ||||
|  | @ -30,39 +31,34 @@ func (tm *Manager) do(tid uint64) { | |||
| 			task.run() | ||||
| 			log.Debugf("task [%s] ended", task.Name) | ||||
| 		} | ||||
| 		// return worker
 | ||||
| 		tm.workerC <- struct{}{} | ||||
| 	}() | ||||
| } | ||||
| 
 | ||||
| func (tm *Manager) addTask(task *Task) { | ||||
| 	task.ID = tm.curID | ||||
| 	atomic.AddUint64(&tm.curID, 1) | ||||
| 	tm.tasks.Store(task.ID, task) | ||||
| } | ||||
| 
 | ||||
| func (tm *Manager) GetAll() []*Task { | ||||
| func (tm *Manager[K, V]) GetAll() []*Task[K, V] { | ||||
| 	return tm.tasks.Values() | ||||
| } | ||||
| 
 | ||||
| func (tm *Manager) Get(tid uint64) (*Task, bool) { | ||||
| func (tm *Manager[K, V]) Get(tid K) (*Task[K, V], bool) { | ||||
| 	return tm.tasks.Load(tid) | ||||
| } | ||||
| 
 | ||||
| func (tm *Manager) MustGet(tid uint64) *Task { | ||||
| func (tm *Manager[K, V]) MustGet(tid K) *Task[K, V] { | ||||
| 	task, _ := tm.Get(tid) | ||||
| 	return task | ||||
| } | ||||
| 
 | ||||
| func (tm *Manager) Retry(tid uint64) error { | ||||
| func (tm *Manager[K, V]) Retry(tid K) error { | ||||
| 	t, ok := tm.Get(tid) | ||||
| 	if !ok { | ||||
| 		return ErrTaskNotFound | ||||
| 	} | ||||
| 	tm.do(t.ID) | ||||
| 	tm.do(t) | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (tm *Manager) Cancel(tid uint64) error { | ||||
| func (tm *Manager[K, V]) Cancel(tid K) error { | ||||
| 	t, ok := tm.Get(tid) | ||||
| 	if !ok { | ||||
| 		return ErrTaskNotFound | ||||
|  | @ -71,17 +67,17 @@ func (tm *Manager) Cancel(tid uint64) error { | |||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (tm *Manager) Remove(tid uint64) { | ||||
| func (tm *Manager[K, V]) Remove(tid K) { | ||||
| 	tm.tasks.Delete(tid) | ||||
| } | ||||
| 
 | ||||
| // RemoveAll removes all tasks from the manager, this maybe shouldn't be used
 | ||||
| // because the task maybe still running.
 | ||||
| func (tm *Manager) RemoveAll() { | ||||
| func (tm *Manager[K, V]) RemoveAll() { | ||||
| 	tm.tasks.Clear() | ||||
| } | ||||
| 
 | ||||
| func (tm *Manager) RemoveFinished() { | ||||
| func (tm *Manager[K, V]) RemoveFinished() { | ||||
| 	tasks := tm.GetAll() | ||||
| 	for _, task := range tasks { | ||||
| 		if task.Status == FINISHED { | ||||
|  | @ -90,7 +86,7 @@ func (tm *Manager) RemoveFinished() { | |||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (tm *Manager) RemoveError() { | ||||
| func (tm *Manager[K, V]) RemoveError() { | ||||
| 	tasks := tm.GetAll() | ||||
| 	for _, task := range tasks { | ||||
| 		if task.Error != nil { | ||||
|  | @ -99,9 +95,16 @@ func (tm *Manager) RemoveError() { | |||
| 	} | ||||
| } | ||||
| 
 | ||||
| func NewTaskManager() *Manager { | ||||
| 	return &Manager{ | ||||
| 		tasks: generic_sync.MapOf[uint64, *Task]{}, | ||||
| 		curID: 0, | ||||
| func NewTaskManager[K comparable, V any](maxWorker int, updateID ...func(*K)) *Manager[K, V] { | ||||
| 	tm := &Manager[K, V]{ | ||||
| 		tasks:   generic_sync.MapOf[K, *Task[K, V]]{}, | ||||
| 		workerC: make(chan struct{}, maxWorker), | ||||
| 	} | ||||
| 	for i := 0; i < maxWorker; i++ { | ||||
| 		tm.workerC <- struct{}{} | ||||
| 	} | ||||
| 	if len(updateID) > 0 { | ||||
| 		tm.updateID = updateID[0] | ||||
| 	} | ||||
| 	return tm | ||||
| } | ||||
|  |  | |||
|  | @ -16,45 +16,34 @@ var ( | |||
| 	ERRORED   = "errored" | ||||
| ) | ||||
| 
 | ||||
| type Func func(task *Task) error | ||||
| type Callback func(task *Task) | ||||
| type Func[K comparable, V any] func(task *Task[K, V]) error | ||||
| type Callback[K comparable, V any] func(task *Task[K, V]) | ||||
| 
 | ||||
| type Task[K comparable, V any] struct { | ||||
| 	ID     K | ||||
| 	Name   string | ||||
| 	Status string | ||||
| 	Error  error | ||||
| 
 | ||||
| 	Data V | ||||
| 
 | ||||
| 	Func     Func[K, V] | ||||
| 	callback Callback[K, V] | ||||
| 
 | ||||
| type Task struct { | ||||
| 	ID       uint64 | ||||
| 	Name     string | ||||
| 	Status   string | ||||
| 	Error    error | ||||
| 	Func     Func | ||||
| 	Ctx      context.Context | ||||
| 	progress int | ||||
| 	callback Callback | ||||
| 	cancel   context.CancelFunc | ||||
| } | ||||
| 
 | ||||
| func newTask(name string, func_ Func, callbacks ...Callback) *Task { | ||||
| 	ctx, cancel := context.WithCancel(context.Background()) | ||||
| 	t := &Task{ | ||||
| 		Name:   name, | ||||
| 		Status: PENDING, | ||||
| 		Func:   func_, | ||||
| 		Ctx:    ctx, | ||||
| 		cancel: cancel, | ||||
| 	} | ||||
| 	if len(callbacks) > 0 { | ||||
| 		t.callback = callbacks[0] | ||||
| 	} | ||||
| 	return t | ||||
| } | ||||
| 
 | ||||
| func (t *Task) SetStatus(status string) { | ||||
| func (t *Task[K, V]) SetStatus(status string) { | ||||
| 	t.Status = status | ||||
| } | ||||
| 
 | ||||
| func (t *Task) SetProgress(percentage int) { | ||||
| func (t *Task[K, V]) SetProgress(percentage int) { | ||||
| 	t.progress = percentage | ||||
| } | ||||
| 
 | ||||
| func (t *Task) run() { | ||||
| func (t *Task[K, V]) run() { | ||||
| 	t.Status = RUNNING | ||||
| 	defer func() { | ||||
| 		if err := recover(); err != nil { | ||||
|  | @ -76,11 +65,11 @@ func (t *Task) run() { | |||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (t *Task) retry() { | ||||
| func (t *Task[K, V]) retry() { | ||||
| 	t.run() | ||||
| } | ||||
| 
 | ||||
| func (t *Task) Cancel() { | ||||
| func (t *Task[K, V]) Cancel() { | ||||
| 	if t.Status == FINISHED || t.Status == CANCELED { | ||||
| 		return | ||||
| 	} | ||||
|  | @ -90,3 +79,11 @@ func (t *Task) Cancel() { | |||
| 	// maybe can't cancel
 | ||||
| 	t.Status = CANCELING | ||||
| } | ||||
| 
 | ||||
| func WithCancelCtx[K comparable, V any](task *Task[K, V]) *Task[K, V] { | ||||
| 	ctx, cancel := context.WithCancel(context.Background()) | ||||
| 	task.Ctx = ctx | ||||
| 	task.cancel = cancel | ||||
| 	task.Status = PENDING | ||||
| 	return task | ||||
| } | ||||
|  |  | |||
|  | @ -3,16 +3,22 @@ package task | |||
| import ( | ||||
| 	"github.com/alist-org/alist/v3/pkg/utils" | ||||
| 	"github.com/pkg/errors" | ||||
| 	"sync/atomic" | ||||
| 	"testing" | ||||
| 	"time" | ||||
| ) | ||||
| 
 | ||||
| func TestTask_Manager(t *testing.T) { | ||||
| 	tm := NewTaskManager() | ||||
| 	id := tm.Submit("test", func(task *Task) error { | ||||
| 		time.Sleep(time.Millisecond * 500) | ||||
| 		return nil | ||||
| 	tm := NewTaskManager[uint64, struct{}](3, func(id *uint64) { | ||||
| 		atomic.AddUint64(id, 1) | ||||
| 	}) | ||||
| 	id := tm.Submit(WithCancelCtx(&Task[uint64, struct{}]{ | ||||
| 		Name: "test", | ||||
| 		Func: func(task *Task[uint64, struct{}]) error { | ||||
| 			time.Sleep(time.Millisecond * 500) | ||||
| 			return nil | ||||
| 		}, | ||||
| 	})) | ||||
| 	task, ok := tm.Get(id) | ||||
| 	if !ok { | ||||
| 		t.Fatal("task not found") | ||||
|  | @ -28,16 +34,21 @@ func TestTask_Manager(t *testing.T) { | |||
| } | ||||
| 
 | ||||
| func TestTask_Cancel(t *testing.T) { | ||||
| 	tm := NewTaskManager() | ||||
| 	id := tm.Submit("test", func(task *Task) error { | ||||
| 		for { | ||||
| 			if utils.IsCanceled(task.Ctx) { | ||||
| 				return nil | ||||
| 			} else { | ||||
| 				t.Logf("task is running") | ||||
| 			} | ||||
| 		} | ||||
| 	tm := NewTaskManager[uint64, struct{}](3, func(id *uint64) { | ||||
| 		atomic.AddUint64(id, 1) | ||||
| 	}) | ||||
| 	id := tm.Submit(WithCancelCtx(&Task[uint64, struct{}]{ | ||||
| 		Name: "test", | ||||
| 		Func: func(task *Task[uint64, struct{}]) error { | ||||
| 			for { | ||||
| 				if utils.IsCanceled(task.Ctx) { | ||||
| 					return nil | ||||
| 				} else { | ||||
| 					t.Logf("task is running") | ||||
| 				} | ||||
| 			} | ||||
| 		}, | ||||
| 	})) | ||||
| 	task, ok := tm.Get(id) | ||||
| 	if !ok { | ||||
| 		t.Fatal("task not found") | ||||
|  | @ -51,15 +62,20 @@ func TestTask_Cancel(t *testing.T) { | |||
| } | ||||
| 
 | ||||
| func TestTask_Retry(t *testing.T) { | ||||
| 	tm := NewTaskManager() | ||||
| 	num := 0 | ||||
| 	id := tm.Submit("test", func(task *Task) error { | ||||
| 		num++ | ||||
| 		if num&1 == 1 { | ||||
| 			return errors.New("test error") | ||||
| 		} | ||||
| 		return nil | ||||
| 	tm := NewTaskManager[uint64, struct{}](3, func(id *uint64) { | ||||
| 		atomic.AddUint64(id, 1) | ||||
| 	}) | ||||
| 	num := 0 | ||||
| 	id := tm.Submit(WithCancelCtx(&Task[uint64, struct{}]{ | ||||
| 		Name: "test", | ||||
| 		Func: func(task *Task[uint64, struct{}]) error { | ||||
| 			num++ | ||||
| 			if num&1 == 1 { | ||||
| 				return errors.New("test error") | ||||
| 			} | ||||
| 			return nil | ||||
| 		}, | ||||
| 	})) | ||||
| 	task, ok := tm.Get(id) | ||||
| 	if !ok { | ||||
| 		t.Fatal("task not found") | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	 Noah Hsu
						Noah Hsu