feat: add copy to task manager

refactor/fs
Noah Hsu 2022-06-17 21:23:44 +08:00
parent 53e969e894
commit fa6e918fc7
10 changed files with 143 additions and 70 deletions

View File

@ -21,8 +21,7 @@ func (d Driver) Config() driver.Config {
func (d *Driver) Init(ctx context.Context, account model.Account) error { func (d *Driver) Init(ctx context.Context, account model.Account) error {
d.Account = account d.Account = account
addition := d.Account.Addition err := utils.Json.UnmarshalFromString(d.Account.Addition, &d.Addition)
err := utils.Json.UnmarshalFromString(addition, &d.Addition)
if err != nil { if err != nil {
return errors.Wrap(err, "error while unmarshal addition") return errors.Wrap(err, "error while unmarshal addition")
} }
@ -32,7 +31,7 @@ func (d *Driver) Init(ctx context.Context, account model.Account) error {
} else { } else {
d.SetStatus("OK") d.SetStatus("OK")
} }
operations.SaveDriverAccount(d) operations.MustSaveDriverAccount(d)
return err return err
} }
@ -79,7 +78,7 @@ func (d *Driver) Remove(ctx context.Context, obj model.Obj) error {
panic("implement me") panic("implement me")
} }
func (d *Driver) Put(ctx context.Context, parentDir model.Obj, stream model.FileStreamer) error { func (d *Driver) Put(ctx context.Context, parentDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error {
//TODO implement me //TODO implement me
panic("implement me") panic("implement me")
} }

View File

@ -50,5 +50,7 @@ type Writer interface {
// Remove remove `object` // Remove remove `object`
Remove(ctx context.Context, obj model.Obj) error Remove(ctx context.Context, obj model.Obj) error
// Put upload `stream` to `parentDir` // Put upload `stream` to `parentDir`
Put(ctx context.Context, parentDir model.Obj, stream model.FileStreamer) error Put(ctx context.Context, parentDir model.Obj, stream model.FileStreamer, up UpdateProgress) error
} }
type UpdateProgress func(percentage float64)

View File

@ -2,34 +2,55 @@ package fs
import ( import (
"context" "context"
"fmt"
"github.com/alist-org/alist/v3/pkg/task"
"github.com/alist-org/alist/v3/pkg/utils"
stdpath "path" stdpath "path"
"github.com/alist-org/alist/v3/internal/driver" "github.com/alist-org/alist/v3/internal/driver"
"github.com/alist-org/alist/v3/internal/model" "github.com/alist-org/alist/v3/internal/model"
"github.com/alist-org/alist/v3/internal/operations" "github.com/alist-org/alist/v3/internal/operations"
"github.com/alist-org/alist/v3/internal/task"
"github.com/pkg/errors" "github.com/pkg/errors"
) )
var copyTaskManager = task.NewTaskManager() var copyTaskManager = task.NewTaskManager()
func CopyBetween2Accounts(ctx context.Context, srcAccount, dstAccount driver.Driver, srcPath, dstPath string) error { func CopyBetween2Accounts(ctx context.Context, srcAccount, dstAccount driver.Driver, srcPath, dstPath string, setStatus func(status string)) error {
srcFile, err := operations.Get(ctx, srcAccount, srcPath) setStatus("getting src object")
srcObj, err := operations.Get(ctx, srcAccount, srcPath)
if err != nil { if err != nil {
return errors.WithMessagef(err, "failed get src [%s] file", srcPath) return errors.WithMessagef(err, "failed get src [%s] file", srcPath)
} }
if srcFile.IsDir() { if srcObj.IsDir() {
setStatus("src object is dir, listing files")
files, err := operations.List(ctx, srcAccount, srcPath) files, err := operations.List(ctx, srcAccount, srcPath)
if err != nil { if err != nil {
return errors.WithMessagef(err, "failed list src [%s] files", srcPath) return errors.WithMessagef(err, "failed list src [%s] files", srcPath)
} }
for _, file := range files { for _, file := range files {
if utils.IsCanceled(ctx) {
return nil
}
srcFilePath := stdpath.Join(srcPath, file.GetName()) srcFilePath := stdpath.Join(srcPath, file.GetName())
dstFilePath := stdpath.Join(dstPath, file.GetName()) dstFilePath := stdpath.Join(dstPath, file.GetName())
if err := CopyBetween2Accounts(ctx, srcAccount, dstAccount, srcFilePath, dstFilePath); err != nil { copyTaskManager.Add(fmt.Sprintf("copy %s to %s", srcFilePath, dstFilePath), func(task *task.Task) error {
return errors.WithMessagef(err, "failed copy file [%s] to [%s]", srcFilePath, dstFilePath) return CopyBetween2Accounts(ctx, srcAccount, dstAccount, srcFilePath, dstFilePath, task.SetStatus)
})
} }
} else {
copyTaskManager.Add(fmt.Sprintf("copy %s to %s", srcPath, dstPath), func(task *task.Task) error {
return CopyFileBetween2Accounts(task.Ctx, srcAccount, dstAccount, srcPath, dstPath, func(percentage float64) {
task.SetStatus(fmt.Sprintf("uploading: %2.f%", percentage))
})
})
} }
return nil
}
func CopyFileBetween2Accounts(ctx context.Context, srcAccount, dstAccount driver.Driver, srcPath, dstPath string, up driver.UpdateProgress) error {
srcFile, err := operations.Get(ctx, srcAccount, srcPath)
if err != nil {
return errors.WithMessagef(err, "failed get src [%s] file", srcPath)
} }
link, err := operations.Link(ctx, srcAccount, srcPath, model.LinkArgs{}) link, err := operations.Link(ctx, srcAccount, srcPath, model.LinkArgs{})
if err != nil { if err != nil {
@ -39,6 +60,5 @@ func CopyBetween2Accounts(ctx context.Context, srcAccount, dstAccount driver.Dri
if err != nil { if err != nil {
return errors.WithMessagef(err, "failed get [%s] stream", srcPath) return errors.WithMessagef(err, "failed get [%s] stream", srcPath)
} }
// TODO add as task return operations.Put(ctx, dstAccount, dstPath, stream, up)
return operations.Put(ctx, dstAccount, dstPath, stream)
} }

View File

@ -3,11 +3,10 @@ package fs
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/alist-org/alist/v3/internal/driver" "github.com/alist-org/alist/v3/internal/driver"
"github.com/alist-org/alist/v3/internal/model" "github.com/alist-org/alist/v3/internal/model"
"github.com/alist-org/alist/v3/internal/operations" "github.com/alist-org/alist/v3/internal/operations"
"github.com/alist-org/alist/v3/internal/task" "github.com/alist-org/alist/v3/pkg/task"
"github.com/pkg/errors" "github.com/pkg/errors"
) )
@ -49,7 +48,7 @@ func Copy(ctx context.Context, account driver.Driver, srcPath, dstPath string) (
if err != nil { if err != nil {
return false, errors.WithMessage(err, "failed get src account") return false, errors.WithMessage(err, "failed get src account")
} }
dstAccount, dstActualPath, err := operations.GetAccountAndActualPath(srcPath) dstAccount, dstActualPath, err := operations.GetAccountAndActualPath(dstPath)
if err != nil { if err != nil {
return false, errors.WithMessage(err, "failed get dst account") return false, errors.WithMessage(err, "failed get dst account")
} }
@ -60,7 +59,7 @@ func Copy(ctx context.Context, account driver.Driver, srcPath, dstPath string) (
// not in an account // not in an account
// TODO add status set callback to put // TODO add status set callback to put
copyTaskManager.Add(fmt.Sprintf("copy %s to %s", srcActualPath, dstActualPath), func(task *task.Task) error { copyTaskManager.Add(fmt.Sprintf("copy %s to %s", srcActualPath, dstActualPath), func(task *task.Task) error {
return CopyBetween2Accounts(context.TODO(), srcAccount, dstAccount, srcActualPath, dstActualPath) return CopyBetween2Accounts(task.Ctx, srcAccount, dstAccount, srcActualPath, dstActualPath, task.SetStatus)
}) })
return true, nil return true, nil
} }
@ -73,10 +72,11 @@ func Remove(ctx context.Context, account driver.Driver, path string) error {
return operations.Remove(ctx, account, actualPath) return operations.Remove(ctx, account, actualPath)
} }
// Put add as a put task
func Put(ctx context.Context, account driver.Driver, parentPath string, file model.FileStreamer) error { func Put(ctx context.Context, account driver.Driver, parentPath string, file model.FileStreamer) error {
account, actualParentPath, err := operations.GetAccountAndActualPath(parentPath) account, actualParentPath, err := operations.GetAccountAndActualPath(parentPath)
if err != nil { if err != nil {
return errors.WithMessage(err, "failed get account") return errors.WithMessage(err, "failed get account")
} }
return operations.Put(ctx, account, actualParentPath, file) return operations.Put(ctx, account, actualParentPath, file, nil)
} }

View File

@ -2,6 +2,7 @@ package operations
import ( import (
"context" "context"
log "github.com/sirupsen/logrus"
"sort" "sort"
"strings" "strings"
"time" "time"
@ -85,8 +86,15 @@ func UpdateAccount(ctx context.Context, account model.Account) error {
return nil return nil
} }
// SaveDriverAccount call from specific driver // MustSaveDriverAccount call from specific driver
func SaveDriverAccount(driver driver.Driver) error { func MustSaveDriverAccount(driver driver.Driver) {
err := saveDriverAccount(driver)
if err != nil {
log.Errorf("failed save driver account: %s", err)
}
}
func saveDriverAccount(driver driver.Driver) error {
account := driver.GetAccount() account := driver.GetAccount()
addition := driver.GetAddition() addition := driver.GetAddition()
bytes, err := utils.Json.Marshal(addition) bytes, err := utils.Json.Marshal(addition)

View File

@ -182,7 +182,7 @@ func Remove(ctx context.Context, account driver.Driver, path string) error {
return account.Remove(ctx, obj) return account.Remove(ctx, obj)
} }
func Put(ctx context.Context, account driver.Driver, parentPath string, file model.FileStreamer) error { func Put(ctx context.Context, account driver.Driver, parentPath string, file model.FileStreamer, up driver.UpdateProgress) error {
err := MakeDir(ctx, account, parentPath) err := MakeDir(ctx, account, parentPath)
if err != nil { if err != nil {
return errors.WithMessagef(err, "failed to make dir [%s]", parentPath) return errors.WithMessagef(err, "failed to make dir [%s]", parentPath)
@ -192,5 +192,9 @@ func Put(ctx context.Context, account driver.Driver, parentPath string, file mod
if err != nil { if err != nil {
return errors.WithMessagef(err, "failed to get dir [%s]", parentPath) return errors.WithMessagef(err, "failed to get dir [%s]", parentPath)
} }
return account.Put(ctx, parentDir, file) // if up is nil, set a default to prevent panic
if up == nil {
up = func(p float64) {}
}
return account.Put(ctx, parentDir, file, up)
} }

View File

@ -1,36 +0,0 @@
// manage task, such as file upload, file copy between accounts, offline download, etc.
package task
type Func func(task *Task) error
var (
PENDING = "pending"
RUNNING = "running"
FINISHED = "finished"
)
type Task struct {
ID int64
Name string
Status string
Error error
Func Func
}
func NewTask(name string, func_ Func) *Task {
return &Task{
Name: name,
Status: PENDING,
Func: func_,
}
}
func (t *Task) SetStatus(status string) {
t.Status = status
}
func (t *Task) Run() {
t.Status = RUNNING
t.Error = t.Func(t)
t.Status = FINISHED
}

View File

@ -6,37 +6,37 @@ import (
"github.com/alist-org/alist/v3/pkg/generic_sync" "github.com/alist-org/alist/v3/pkg/generic_sync"
) )
func NewTaskManager() *TaskManager { func NewTaskManager() *Manager {
return &TaskManager{ return &Manager{
tasks: generic_sync.MapOf[int64, *Task]{}, tasks: generic_sync.MapOf[int64, *Task]{},
curID: 0, curID: 0,
} }
} }
type TaskManager struct { type Manager struct {
curID int64 curID int64
tasks generic_sync.MapOf[int64, *Task] tasks generic_sync.MapOf[int64, *Task]
} }
func (tm *TaskManager) AddTask(task *Task) { func (tm *Manager) AddTask(task *Task) {
task.ID = tm.curID task.ID = tm.curID
atomic.AddInt64(&tm.curID, 1) atomic.AddInt64(&tm.curID, 1)
tm.tasks.Store(task.ID, task) tm.tasks.Store(task.ID, task)
} }
func (tm *TaskManager) GetAll() []*Task { func (tm *Manager) GetAll() []*Task {
return tm.tasks.Values() return tm.tasks.Values()
} }
func (tm *TaskManager) Get(id int64) (*Task, bool) { func (tm *Manager) Get(id int64) (*Task, bool) {
return tm.tasks.Load(id) return tm.tasks.Load(id)
} }
func (tm *TaskManager) Remove(id int64) { func (tm *Manager) Remove(id int64) {
tm.tasks.Delete(id) tm.tasks.Delete(id)
} }
func (tm *TaskManager) RemoveFinished() { func (tm *Manager) RemoveFinished() {
tasks := tm.GetAll() tasks := tm.GetAll()
for _, task := range tasks { for _, task := range tasks {
if task.Status == FINISHED { if task.Status == FINISHED {
@ -45,7 +45,7 @@ func (tm *TaskManager) RemoveFinished() {
} }
} }
func (tm *TaskManager) RemoveError() { func (tm *Manager) RemoveError() {
tasks := tm.GetAll() tasks := tm.GetAll()
for _, task := range tasks { for _, task := range tasks {
if task.Error != nil { if task.Error != nil {
@ -54,8 +54,8 @@ func (tm *TaskManager) RemoveError() {
} }
} }
func (tm *TaskManager) Add(name string, f Func) { func (tm *Manager) Add(name string, f Func) {
task := NewTask(name, f) task := newTask(name, f)
tm.AddTask(task) tm.AddTask(task)
go task.Run() go task.Run()
} }

64
pkg/task/task.go Normal file
View File

@ -0,0 +1,64 @@
// Package task manage task, such as file upload, file copy between accounts, offline download, etc.
package task
import (
"context"
"github.com/pkg/errors"
)
var (
PENDING = "pending"
RUNNING = "running"
FINISHED = "finished"
CANCELING = "canceling"
CANCELED = "canceled"
)
type Func func(task *Task) error
type Task struct {
ID int64
Name string
Status string
Error error
Func Func
Ctx context.Context
cancel context.CancelFunc
}
func newTask(name string, func_ Func) *Task {
ctx, cancel := context.WithCancel(context.Background())
return &Task{
Name: name,
Status: PENDING,
Func: func_,
Ctx: ctx,
cancel: cancel,
}
}
func (t *Task) SetStatus(status string) {
t.Status = status
}
func (t *Task) Run() {
t.Status = RUNNING
t.Error = t.Func(t)
if errors.Is(t.Ctx.Err(), context.Canceled) {
t.Status = CANCELED
} else {
t.Status = FINISHED
}
}
func (t *Task) Retry() {
t.Run()
}
func (t *Task) Cancel() {
if t.cancel != nil {
t.cancel()
}
// maybe can't cancel
t.Status = CANCELING
}

12
pkg/utils/ctx.go Normal file
View File

@ -0,0 +1,12 @@
package utils
import "context"
func IsCanceled(ctx context.Context) bool {
select {
case <-ctx.Done():
return true
default:
return false
}
}