mirror of https://github.com/Xhofe/alist
refactor(aria2): extract monitor
parent
72208e052a
commit
a6df492fff
|
@ -6,22 +6,16 @@ import (
|
|||
"github.com/alist-org/alist/v3/conf"
|
||||
"github.com/alist-org/alist/v3/internal/driver"
|
||||
"github.com/alist-org/alist/v3/internal/fs"
|
||||
"github.com/alist-org/alist/v3/internal/model"
|
||||
"github.com/alist-org/alist/v3/internal/operations"
|
||||
"github.com/alist-org/alist/v3/pkg/task"
|
||||
"github.com/google/uuid"
|
||||
"github.com/pkg/errors"
|
||||
"mime"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
func AddURI(ctx context.Context, uri string, dstPath string, parentPath string) error {
|
||||
func AddURI(ctx context.Context, uri string, dstDirPath string) error {
|
||||
// check account
|
||||
account, actualParentPath, err := operations.GetAccountAndActualPath(parentPath)
|
||||
account, dstDirActualPath, err := operations.GetAccountAndActualPath(dstDirPath)
|
||||
if err != nil {
|
||||
return errors.WithMessage(err, "failed get account")
|
||||
}
|
||||
|
@ -30,7 +24,7 @@ func AddURI(ctx context.Context, uri string, dstPath string, parentPath string)
|
|||
return errors.WithStack(fs.ErrUploadNotSupported)
|
||||
}
|
||||
// check path is valid
|
||||
obj, err := operations.Get(ctx, account, actualParentPath)
|
||||
obj, err := operations.Get(ctx, account, dstDirActualPath)
|
||||
if err != nil {
|
||||
if !errors.Is(errors.Cause(err), driver.ErrorObjectNotFound) {
|
||||
return errors.WithMessage(err, "failed get object")
|
||||
|
@ -51,99 +45,17 @@ func AddURI(ctx context.Context, uri string, dstPath string, parentPath string)
|
|||
return errors.Wrapf(err, "failed to add uri %s", uri)
|
||||
}
|
||||
// TODO add to task manager
|
||||
Aria2TaskManager.Submit(task.WithCancelCtx(&task.Task[string, OfflineDownload]{
|
||||
TaskManager.Submit(task.WithCancelCtx(&task.Task[string, interface{}]{
|
||||
ID: gid,
|
||||
Name: fmt.Sprintf("download %s to [%s](%s)", uri, account.GetAccount().VirtualPath, actualParentPath),
|
||||
Func: func(tsk *task.Task[string, OfflineDownload]) error {
|
||||
defer func() {
|
||||
notify.Signals.Delete(gid)
|
||||
// clear temp dir
|
||||
_ = os.RemoveAll(tempDir)
|
||||
}()
|
||||
c := make(chan int)
|
||||
notify.Signals.Store(gid, c)
|
||||
retried := 0
|
||||
for {
|
||||
select {
|
||||
case <-tsk.Ctx.Done():
|
||||
_, err := client.Remove(gid)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
case status := <-c:
|
||||
switch status {
|
||||
case Completed:
|
||||
return nil
|
||||
default:
|
||||
info, err := client.TellStatus(gid)
|
||||
if err != nil {
|
||||
retried++
|
||||
}
|
||||
if retried > 5 {
|
||||
return errors.Errorf("failed to get status of %s, retried %d times", gid, retried)
|
||||
}
|
||||
retried = 0
|
||||
if len(info.FollowedBy) != 0 {
|
||||
gid = info.FollowedBy[0]
|
||||
|
||||
}
|
||||
// update download status
|
||||
total, err := strconv.ParseUint(info.TotalLength, 10, 64)
|
||||
if err != nil {
|
||||
total = 0
|
||||
}
|
||||
downloaded, err := strconv.ParseUint(info.CompletedLength, 10, 64)
|
||||
if err != nil {
|
||||
downloaded = 0
|
||||
}
|
||||
tsk.SetProgress(int(float64(downloaded) / float64(total)))
|
||||
switch info.Status {
|
||||
case "complete":
|
||||
// get files
|
||||
files, err := client.GetFiles(gid)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "failed to get files of %s", gid)
|
||||
}
|
||||
// upload files
|
||||
for _, file := range files {
|
||||
size, _ := strconv.ParseUint(file.Length, 10, 64)
|
||||
f, err := os.Open(file.Path)
|
||||
mimetype := mime.TypeByExtension(path.Ext(file.Path))
|
||||
if mimetype == "" {
|
||||
mimetype = "application/octet-stream"
|
||||
}
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "failed to open file %s", file.Path)
|
||||
}
|
||||
stream := model.FileStream{
|
||||
Obj: model.Object{
|
||||
Name: path.Base(file.Path),
|
||||
Size: size,
|
||||
Modified: time.Now(),
|
||||
IsFolder: false,
|
||||
},
|
||||
ReadCloser: f,
|
||||
Mimetype: "",
|
||||
}
|
||||
return operations.Put(tsk.Ctx, account, actualParentPath, stream, tsk.SetProgress)
|
||||
}
|
||||
case "error":
|
||||
return errors.Errorf("failed to download %s, error: %s", gid, info.ErrorMessage)
|
||||
case "active", "waiting", "paused":
|
||||
// do nothing
|
||||
case "removed":
|
||||
return errors.Errorf("failed to download %s, removed", gid)
|
||||
default:
|
||||
return errors.Errorf("failed to download %s, unknown status %s", gid, info.Status)
|
||||
}
|
||||
}
|
||||
}
|
||||
Name: fmt.Sprintf("download %s to [%s](%s)", uri, account.GetAccount().VirtualPath, dstDirActualPath),
|
||||
Func: func(tsk *task.Task[string, interface{}]) error {
|
||||
m := &Monitor{
|
||||
tsk: tsk,
|
||||
tempDir: tempDir,
|
||||
retried: 0,
|
||||
dstDirPath: dstDirPath,
|
||||
}
|
||||
},
|
||||
Data: OfflineDownload{
|
||||
Gid: gid,
|
||||
URI: uri,
|
||||
DstPath: dstPath,
|
||||
return m.Loop()
|
||||
},
|
||||
}))
|
||||
return nil
|
||||
|
|
|
@ -8,7 +8,7 @@ import (
|
|||
"time"
|
||||
)
|
||||
|
||||
var Aria2TaskManager = task.NewTaskManager[string, OfflineDownload](3)
|
||||
var TaskManager = task.NewTaskManager[string, interface{}](3)
|
||||
var notify = NewNotify()
|
||||
var client rpc.Client
|
||||
|
||||
|
|
|
@ -0,0 +1,149 @@
|
|||
package aria2
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/alist-org/alist/v3/internal/model"
|
||||
"github.com/alist-org/alist/v3/internal/operations"
|
||||
"github.com/alist-org/alist/v3/pkg/task"
|
||||
"github.com/pkg/errors"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"mime"
|
||||
"os"
|
||||
"path"
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Monitor struct {
|
||||
tsk *task.Task[string, interface{}]
|
||||
tempDir string
|
||||
retried int
|
||||
c chan int
|
||||
dstDirPath string
|
||||
}
|
||||
|
||||
func (m *Monitor) Loop() error {
|
||||
defer func() {
|
||||
notify.Signals.Delete(m.tsk.ID)
|
||||
// clear temp dir, should do while complete
|
||||
//_ = os.RemoveAll(m.tempDir)
|
||||
}()
|
||||
m.c = make(chan int)
|
||||
notify.Signals.Store(m.tsk.ID, m.c)
|
||||
for {
|
||||
select {
|
||||
case <-m.tsk.Ctx.Done():
|
||||
_, err := client.Remove(m.tsk.ID)
|
||||
return err
|
||||
case <-m.c:
|
||||
ok, err := m.Update()
|
||||
if ok {
|
||||
return err
|
||||
}
|
||||
case <-time.After(time.Second * 5):
|
||||
ok, err := m.Update()
|
||||
if ok {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Monitor) Update() (bool, error) {
|
||||
info, err := client.TellStatus(m.tsk.ID)
|
||||
if err != nil {
|
||||
m.retried++
|
||||
}
|
||||
if m.retried > 5 {
|
||||
return true, errors.Errorf("failed to get status of %s, retried %d times", m.tsk.ID, m.retried)
|
||||
}
|
||||
m.retried = 0
|
||||
if len(info.FollowedBy) != 0 {
|
||||
gid := info.FollowedBy[0]
|
||||
notify.Signals.Delete(m.tsk.ID)
|
||||
m.tsk.ID = gid
|
||||
notify.Signals.Store(gid, m.c)
|
||||
}
|
||||
// update download status
|
||||
total, err := strconv.ParseUint(info.TotalLength, 10, 64)
|
||||
if err != nil {
|
||||
total = 0
|
||||
}
|
||||
downloaded, err := strconv.ParseUint(info.CompletedLength, 10, 64)
|
||||
if err != nil {
|
||||
downloaded = 0
|
||||
}
|
||||
m.tsk.SetProgress(int(float64(downloaded) / float64(total)))
|
||||
switch info.Status {
|
||||
case "complete":
|
||||
err := m.Complete()
|
||||
return true, errors.WithMessage(err, "failed to transfer file")
|
||||
case "error":
|
||||
return true, errors.Errorf("failed to download %s, error: %s", m.tsk.ID, info.ErrorMessage)
|
||||
case "active", "waiting", "paused":
|
||||
m.tsk.SetStatus("aria2: " + info.Status)
|
||||
return false, nil
|
||||
case "removed":
|
||||
return true, errors.Errorf("failed to download %s, removed", m.tsk.ID)
|
||||
default:
|
||||
return true, errors.Errorf("failed to download %s, unknown status %s", m.tsk.ID, info.Status)
|
||||
}
|
||||
}
|
||||
|
||||
var transferTaskManager = task.NewTaskManager[uint64, interface{}](3, func(k *uint64) {
|
||||
atomic.AddUint64(k, 1)
|
||||
})
|
||||
|
||||
func (m *Monitor) Complete() error {
|
||||
// check dstDir again
|
||||
account, dstDirActualPath, err := operations.GetAccountAndActualPath(m.dstDirPath)
|
||||
if err != nil {
|
||||
return errors.WithMessage(err, "failed get account")
|
||||
}
|
||||
// get files
|
||||
files, err := client.GetFiles(m.tsk.ID)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "failed to get files of %s", m.tsk.ID)
|
||||
}
|
||||
// upload files
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(len(files))
|
||||
go func() {
|
||||
wg.Wait()
|
||||
err := os.RemoveAll(m.tempDir)
|
||||
if err != nil {
|
||||
log.Errorf("failed to remove aria2 temp dir: %+v", err.Error())
|
||||
}
|
||||
}()
|
||||
for _, file := range files {
|
||||
transferTaskManager.Submit(task.WithCancelCtx[uint64](&task.Task[uint64, interface{}]{
|
||||
Name: fmt.Sprintf("transfer %s to %s", file.Path, m.dstDirPath),
|
||||
Func: func(tsk *task.Task[uint64, interface{}]) error {
|
||||
defer wg.Done()
|
||||
size, _ := strconv.ParseUint(file.Length, 10, 64)
|
||||
mimetype := mime.TypeByExtension(path.Ext(file.Path))
|
||||
if mimetype == "" {
|
||||
mimetype = "application/octet-stream"
|
||||
}
|
||||
f, err := os.Open(file.Path)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "failed to open file %s", file.Path)
|
||||
}
|
||||
stream := model.FileStream{
|
||||
Obj: model.Object{
|
||||
Name: path.Base(file.Path),
|
||||
Size: size,
|
||||
Modified: time.Now(),
|
||||
IsFolder: false,
|
||||
},
|
||||
ReadCloser: f,
|
||||
Mimetype: "",
|
||||
}
|
||||
return operations.Put(tsk.Ctx, account, dstDirActualPath, stream, tsk.SetProgress)
|
||||
},
|
||||
}))
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -1,7 +0,0 @@
|
|||
package aria2
|
||||
|
||||
type OfflineDownload struct {
|
||||
Gid string
|
||||
DstPath string
|
||||
URI string
|
||||
}
|
Loading…
Reference in New Issue