Cloudreve/pkg/filemanager/workflows/remote_download.go

647 lines
19 KiB
Go

package workflows
import (
"context"
"encoding/json"
"errors"
"fmt"
"os"
"path"
"path/filepath"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/cloudreve/Cloudreve/v4/application/dependency"
"github.com/cloudreve/Cloudreve/v4/ent"
"github.com/cloudreve/Cloudreve/v4/ent/task"
"github.com/cloudreve/Cloudreve/v4/inventory"
"github.com/cloudreve/Cloudreve/v4/inventory/types"
"github.com/cloudreve/Cloudreve/v4/pkg/cluster"
"github.com/cloudreve/Cloudreve/v4/pkg/downloader"
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs"
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/manager"
"github.com/cloudreve/Cloudreve/v4/pkg/hashid"
"github.com/cloudreve/Cloudreve/v4/pkg/logging"
"github.com/cloudreve/Cloudreve/v4/pkg/queue"
"github.com/cloudreve/Cloudreve/v4/pkg/serializer"
"github.com/samber/lo"
)
type (
RemoteDownloadTask struct {
*queue.DBTask
l logging.Logger
state *RemoteDownloadTaskState
node cluster.Node
d downloader.Downloader
progress queue.Progresses
}
RemoteDownloadTaskPhase string
RemoteDownloadTaskState struct {
SrcFileUri string `json:"src_file_uri,omitempty"`
SrcUri string `json:"src_uri,omitempty"`
Dst string `json:"dst,omitempty"`
Handle *downloader.TaskHandle `json:"handle,omitempty"`
Status *downloader.TaskStatus `json:"status,omitempty"`
NodeState `json:",inline"`
Phase RemoteDownloadTaskPhase `json:"phase,omitempty"`
SlaveUploadTaskID int `json:"slave__upload_task_id,omitempty"`
SlaveUploadState *SlaveUploadTaskState `json:"slave_upload_state,omitempty"`
GetTaskStatusTried int `json:"get_task_status_tried,omitempty"`
Transferred map[int]interface{} `json:"transferred,omitempty"`
Failed int `json:"failed,omitempty"`
}
)
const (
RemoteDownloadTaskPhaseNotStarted RemoteDownloadTaskPhase = ""
RemoteDownloadTaskPhaseMonitor = "monitor"
RemoteDownloadTaskPhaseTransfer = "transfer"
RemoteDownloadTaskPhaseAwaitSeeding = "seeding"
GetTaskStatusMaxTries = 5
SummaryKeyDownloadStatus = "download"
SummaryKeySrcStr = "src_str"
ProgressTypeRelocateTransferCount = "relocate"
ProgressTypeUploadSinglePrefix = "upload_single_"
SummaryKeySrcMultiple = "src_multiple"
SummaryKeySrcDstPolicyID = "dst_policy_id"
SummaryKeyFailed = "failed"
)
func init() {
queue.RegisterResumableTaskFactory(queue.RemoteDownloadTaskType, NewRemoteDownloadTaskFromModel)
}
// NewRemoteDownloadTask creates a new RemoteDownloadTask
func NewRemoteDownloadTask(ctx context.Context, src string, srcFile, dst string) (queue.Task, error) {
state := &RemoteDownloadTaskState{
SrcUri: src,
SrcFileUri: srcFile,
Dst: dst,
NodeState: NodeState{},
}
stateBytes, err := json.Marshal(state)
if err != nil {
return nil, fmt.Errorf("failed to marshal state: %w", err)
}
t := &RemoteDownloadTask{
DBTask: &queue.DBTask{
Task: &ent.Task{
Type: queue.RemoteDownloadTaskType,
CorrelationID: logging.CorrelationID(ctx),
PrivateState: string(stateBytes),
PublicState: &types.TaskPublicState{},
},
DirectOwner: inventory.UserFromContext(ctx),
},
}
return t, nil
}
func NewRemoteDownloadTaskFromModel(task *ent.Task) queue.Task {
return &RemoteDownloadTask{
DBTask: &queue.DBTask{
Task: task,
},
}
}
func (m *RemoteDownloadTask) Do(ctx context.Context) (task.Status, error) {
dep := dependency.FromContext(ctx)
m.l = dep.Logger()
// unmarshal state
state := &RemoteDownloadTaskState{}
if err := json.Unmarshal([]byte(m.State()), state); err != nil {
return task.StatusError, fmt.Errorf("failed to unmarshal state: %w", err)
}
m.state = state
// select node
node, err := allocateNode(ctx, dep, &m.state.NodeState, types.NodeCapabilityRemoteDownload)
if err != nil {
return task.StatusError, fmt.Errorf("failed to allocate node: %w", err)
}
m.node = node
// create downloader instance
if m.d == nil {
d, err := node.CreateDownloader(ctx, dep.RequestClient(), dep.SettingProvider())
if err != nil {
return task.StatusError, fmt.Errorf("failed to create downloader: %w", err)
}
m.d = d
}
next := task.StatusCompleted
switch m.state.Phase {
case RemoteDownloadTaskPhaseNotStarted:
next, err = m.createDownloadTask(ctx, dep)
case RemoteDownloadTaskPhaseMonitor, RemoteDownloadTaskPhaseAwaitSeeding:
next, err = m.monitor(ctx, dep)
case RemoteDownloadTaskPhaseTransfer:
if m.node.IsMaster() {
next, err = m.masterTransfer(ctx, dep)
} else {
next, err = m.slaveTransfer(ctx, dep)
}
}
newStateStr, marshalErr := json.Marshal(m.state)
if marshalErr != nil {
return task.StatusError, fmt.Errorf("failed to marshal state: %w", marshalErr)
}
m.Lock()
m.Task.PrivateState = string(newStateStr)
m.Unlock()
return next, err
}
func (m *RemoteDownloadTask) createDownloadTask(ctx context.Context, dep dependency.Dep) (task.Status, error) {
if m.state.Handle != nil {
m.state.Phase = RemoteDownloadTaskPhaseMonitor
return task.StatusSuspending, nil
}
user := inventory.UserFromContext(ctx)
torrentUrl := m.state.SrcUri
if m.state.SrcFileUri != "" {
// Target is a torrent file
uri, err := fs.NewUriFromString(m.state.SrcFileUri)
if err != nil {
return task.StatusError, fmt.Errorf("failed to parse src file uri: %s (%w)", err, queue.CriticalErr)
}
fm := manager.NewFileManager(dep, user)
expire := time.Now().Add(dep.SettingProvider().EntityUrlValidDuration(ctx))
torrentUrls, _, err := fm.GetEntityUrls(ctx, []manager.GetEntityUrlArgs{
{URI: uri},
}, fs.WithUrlExpire(&expire))
if err != nil {
return task.StatusError, fmt.Errorf("failed to get torrent entity urls: %w", err)
}
if len(torrentUrls) == 0 {
return task.StatusError, fmt.Errorf("no torrent urls found")
}
torrentUrl = torrentUrls[0].Url
}
// Create download task
handle, err := m.d.CreateTask(ctx, torrentUrl, user.Edges.Group.Settings.RemoteDownloadOptions)
if err != nil {
return task.StatusError, fmt.Errorf("failed to create download task: %w", err)
}
m.state.Handle = handle
m.state.Phase = RemoteDownloadTaskPhaseMonitor
return task.StatusSuspending, nil
}
func (m *RemoteDownloadTask) monitor(ctx context.Context, dep dependency.Dep) (task.Status, error) {
resumeAfter := time.Duration(m.node.Settings(ctx).Interval) * time.Second
// Update task status
status, err := m.d.Info(ctx, m.state.Handle)
if err != nil {
if errors.Is(err, downloader.ErrTaskNotFount) && m.state.Status != nil {
// If task is not found, but it previously existed, consider it as canceled
m.l.Warning("task not found, consider it as canceled")
return task.StatusCanceled, nil
}
m.state.GetTaskStatusTried++
if m.state.GetTaskStatusTried >= GetTaskStatusMaxTries {
return task.StatusError, fmt.Errorf("failed to get task status after %d retry: %w", m.state.GetTaskStatusTried, err)
}
m.l.Warning("failed to get task info: %s, will retry.", err)
m.ResumeAfter(resumeAfter)
return task.StatusSuspending, nil
}
// Follow to new handle if needed
if status.FollowedBy != nil {
m.l.Info("Task handle updated to %v", status.FollowedBy)
m.state.Handle = status.FollowedBy
m.ResumeAfter(0)
return task.StatusSuspending, nil
}
if m.state.Status == nil || m.state.Status.Total != status.Total {
m.l.Info("download size changed, re-validate files.")
// First time to get status / total size changed, check user capacity
if err := m.validateFiles(ctx, dep, status); err != nil {
m.state.Status = status
return task.StatusError, fmt.Errorf("failed to validate files: %s (%w)", err, queue.CriticalErr)
}
}
m.state.Status = status
m.state.GetTaskStatusTried = 0
m.l.Debug("Monitor %q task state: %s", status.Name, status.State)
switch status.State {
case downloader.StatusSeeding:
m.l.Info("Download task seeding")
if m.state.Phase == RemoteDownloadTaskPhaseMonitor {
// Not transferred
m.state.Phase = RemoteDownloadTaskPhaseTransfer
return task.StatusSuspending, nil
} else if !m.node.Settings(ctx).WaitForSeeding {
// Skip seeding
m.l.Info("Download task seeding skipped.")
return task.StatusCompleted, nil
} else {
// Still seeding
m.ResumeAfter(resumeAfter)
return task.StatusSuspending, nil
}
case downloader.StatusCompleted:
m.l.Info("Download task completed")
if m.state.Phase == RemoteDownloadTaskPhaseMonitor {
// Not transferred
m.state.Phase = RemoteDownloadTaskPhaseTransfer
return task.StatusSuspending, nil
}
// Seeding complete
m.l.Info("Download task seeding completed")
return task.StatusCompleted, nil
case downloader.StatusDownloading:
m.ResumeAfter(resumeAfter)
return task.StatusSuspending, nil
case downloader.StatusUnknown, downloader.StatusError:
return task.StatusError, fmt.Errorf("download task failed with state %q (%w), errorMsg: %s", status.State, queue.CriticalErr, status.ErrorMessage)
}
m.ResumeAfter(resumeAfter)
return task.StatusSuspending, nil
}
func (m *RemoteDownloadTask) slaveTransfer(ctx context.Context, dep dependency.Dep) (task.Status, error) {
u := inventory.UserFromContext(ctx)
if m.state.Transferred == nil {
m.state.Transferred = make(map[int]interface{})
}
if m.state.SlaveUploadTaskID == 0 {
dstUri, err := fs.NewUriFromString(m.state.Dst)
if err != nil {
return task.StatusError, fmt.Errorf("failed to parse dst uri %q: %s (%w)", m.state.Dst, err, queue.CriticalErr)
}
// Create slave upload task
payload := &SlaveUploadTaskState{
Files: []SlaveUploadEntity{},
MaxParallel: dep.SettingProvider().MaxParallelTransfer(ctx),
UserID: u.ID,
}
// Construct files to be transferred
for _, f := range m.state.Status.Files {
if !f.Selected {
continue
}
// Skip already transferred
if _, ok := m.state.Transferred[f.Index]; ok {
continue
}
dst := dstUri.JoinRaw(sanitizeFileName(f.Name))
src := path.Join(m.state.Status.SavePath, f.Name)
payload.Files = append(payload.Files, SlaveUploadEntity{
Src: src,
Uri: dst,
Size: f.Size,
Index: f.Index,
})
}
payloadStr, err := json.Marshal(payload)
if err != nil {
return task.StatusError, fmt.Errorf("failed to marshal payload: %w", err)
}
taskId, err := m.node.CreateTask(ctx, queue.SlaveUploadTaskType, string(payloadStr))
if err != nil {
return task.StatusError, fmt.Errorf("failed to create slave task: %w", err)
}
m.state.NodeState.progress = nil
m.state.SlaveUploadTaskID = taskId
m.ResumeAfter(0)
return task.StatusSuspending, nil
}
m.l.Info("Checking slave upload task %d...", m.state.SlaveUploadTaskID)
t, err := m.node.GetTask(ctx, m.state.SlaveUploadTaskID, true)
if err != nil {
return task.StatusError, fmt.Errorf("failed to get slave task: %w", err)
}
m.Lock()
m.state.NodeState.progress = t.Progress
m.Unlock()
m.state.SlaveUploadState = &SlaveUploadTaskState{}
if err := json.Unmarshal([]byte(t.PrivateState), m.state.SlaveUploadState); err != nil {
return task.StatusError, fmt.Errorf("failed to unmarshal slave compress state: %s (%w)", err, queue.CriticalErr)
}
if t.Status == task.StatusError || t.Status == task.StatusCompleted {
if len(m.state.SlaveUploadState.Transferred) < len(m.state.SlaveUploadState.Files) {
// Not all files transferred, retry
slaveTaskId := m.state.SlaveUploadTaskID
m.state.SlaveUploadTaskID = 0
for i, _ := range m.state.SlaveUploadState.Transferred {
m.state.Transferred[m.state.SlaveUploadState.Files[i].Index] = struct{}{}
}
m.l.Warning("Slave task %d failed to transfer %d files, retrying...", slaveTaskId, len(m.state.SlaveUploadState.Files)-len(m.state.SlaveUploadState.Transferred))
return task.StatusError, fmt.Errorf(
"slave task failed to transfer %d files, first 5 errors: %s",
len(m.state.SlaveUploadState.Files)-len(m.state.SlaveUploadState.Transferred),
m.state.SlaveUploadState.First5TransferErrors,
)
} else {
m.state.Phase = RemoteDownloadTaskPhaseAwaitSeeding
m.ResumeAfter(0)
return task.StatusSuspending, nil
}
}
if t.Status == task.StatusCanceled {
return task.StatusError, fmt.Errorf("slave task canceled (%w)", queue.CriticalErr)
}
m.l.Info("Slave task %d is still uploading, resume after 30s.", m.state.SlaveUploadTaskID)
m.ResumeAfter(time.Second * 30)
return task.StatusSuspending, nil
}
func (m *RemoteDownloadTask) masterTransfer(ctx context.Context, dep dependency.Dep) (task.Status, error) {
if m.state.Transferred == nil {
m.state.Transferred = make(map[int]interface{})
}
maxParallel := dep.SettingProvider().MaxParallelTransfer(ctx)
wg := sync.WaitGroup{}
worker := make(chan int, maxParallel)
for i := 0; i < maxParallel; i++ {
worker <- i
}
// Sum up total count and select files
totalCount := 0
totalSize := int64(0)
allFiles := make([]downloader.TaskFile, 0, len(m.state.Status.Files))
for _, f := range m.state.Status.Files {
if f.Selected {
allFiles = append(allFiles, f)
totalSize += f.Size
totalCount++
}
}
m.Lock()
m.progress = make(queue.Progresses)
m.progress[ProgressTypeUploadCount] = &queue.Progress{Total: int64(totalCount)}
m.progress[ProgressTypeUpload] = &queue.Progress{Total: totalSize}
m.Unlock()
dstUri, err := fs.NewUriFromString(m.state.Dst)
if err != nil {
return task.StatusError, fmt.Errorf("failed to parse dst uri: %s (%w)", err, queue.CriticalErr)
}
user := inventory.UserFromContext(ctx)
fm := manager.NewFileManager(dep, user)
failed := int64(0)
ae := serializer.NewAggregateError()
transferFunc := func(workerId int, file downloader.TaskFile) {
defer func() {
atomic.AddInt64(&m.progress[ProgressTypeUploadCount].Current, 1)
worker <- workerId
wg.Done()
}()
sanitizedName := sanitizeFileName(file.Name)
dst := dstUri.JoinRaw(sanitizedName)
src := filepath.FromSlash(path.Join(m.state.Status.SavePath, file.Name))
m.l.Info("Uploading file %s to %s...", src, sanitizedName, dst)
progressKey := fmt.Sprintf("%s%d", ProgressTypeUploadSinglePrefix, workerId)
m.Lock()
m.progress[progressKey] = &queue.Progress{Identifier: dst.String(), Total: file.Size}
m.Unlock()
fileStream, err := os.Open(src)
if err != nil {
m.l.Warning("Failed to open file %s: %s", src, err.Error())
atomic.AddInt64(&m.progress[ProgressTypeUpload].Current, file.Size)
atomic.AddInt64(&failed, 1)
ae.Add(file.Name, fmt.Errorf("failed to open file: %w", err))
return
}
defer fileStream.Close()
fileData := &fs.UploadRequest{
Props: &fs.UploadProps{
Uri: dst,
Size: file.Size,
},
ProgressFunc: func(current, diff int64, total int64) {
atomic.AddInt64(&m.progress[progressKey].Current, diff)
atomic.AddInt64(&m.progress[ProgressTypeUpload].Current, diff)
},
File: fileStream,
}
_, err = fm.Update(ctx, fileData, fs.WithNoEntityType())
if err != nil {
m.l.Warning("Failed to upload file %s: %s", src, err.Error())
atomic.AddInt64(&failed, 1)
atomic.AddInt64(&m.progress[ProgressTypeUpload].Current, file.Size)
ae.Add(file.Name, fmt.Errorf("failed to upload file: %w", err))
return
}
m.Lock()
m.state.Transferred[file.Index] = nil
m.Unlock()
}
// Start upload files
for _, file := range allFiles {
// Check if file is already transferred
if _, ok := m.state.Transferred[file.Index]; ok {
m.l.Info("File %s already transferred, skipping...", file.Name)
atomic.AddInt64(&m.progress[ProgressTypeUpload].Current, file.Size)
atomic.AddInt64(&m.progress[ProgressTypeUploadCount].Current, 1)
continue
}
select {
case <-ctx.Done():
return task.StatusError, ctx.Err()
case workerId := <-worker:
wg.Add(1)
go transferFunc(workerId, file)
}
}
wg.Wait()
if failed > 0 {
m.state.Failed = int(failed)
m.l.Error("Failed to transfer %d file(s).", failed)
return task.StatusError, fmt.Errorf("failed to transfer %d file(s), first 5 errors: %s", failed, ae.FormatFirstN(5))
}
m.l.Info("All files transferred.")
m.state.Phase = RemoteDownloadTaskPhaseAwaitSeeding
return task.StatusSuspending, nil
}
func (m *RemoteDownloadTask) awaitSeeding(ctx context.Context, dep dependency.Dep) (task.Status, error) {
return task.StatusSuspending, nil
}
func (m *RemoteDownloadTask) validateFiles(ctx context.Context, dep dependency.Dep, status *downloader.TaskStatus) error {
// Validate files
user := inventory.UserFromContext(ctx)
fm := manager.NewFileManager(dep, user)
dstUri, err := fs.NewUriFromString(m.state.Dst)
if err != nil {
return fmt.Errorf("failed to parse dst uri: %w", err)
}
selectedFiles := lo.Filter(status.Files, func(f downloader.TaskFile, _ int) bool {
return f.Selected
})
if len(selectedFiles) == 0 {
return fmt.Errorf("no selected file found in download task")
}
validateArgs := lo.Map(selectedFiles, func(f downloader.TaskFile, _ int) fs.PreValidateFile {
return fs.PreValidateFile{
Name: sanitizeFileName(f.Name),
Size: f.Size,
OmitName: f.Name == "",
}
})
if err := fm.PreValidateUpload(ctx, dstUri, validateArgs...); err != nil {
return fmt.Errorf("failed to pre-validate files: %w", err)
}
return nil
}
func (m *RemoteDownloadTask) Cleanup(ctx context.Context) error {
if m.state.Handle != nil {
if err := m.d.Cancel(ctx, m.state.Handle); err != nil {
m.l.Warning("failed to cancel download task: %s", err)
}
}
if m.state.Status != nil && m.node.IsMaster() && m.state.Status.SavePath != "" {
if err := os.RemoveAll(m.state.Status.SavePath); err != nil {
m.l.Warning("failed to remove download temp folder: %s", err)
}
}
return nil
}
// SetDownloadTarget sets the files to download for the task
func (m *RemoteDownloadTask) SetDownloadTarget(ctx context.Context, args ...*downloader.SetFileToDownloadArgs) error {
if m.state.Handle == nil {
return fmt.Errorf("download task not created")
}
return m.d.SetFilesToDownload(ctx, m.state.Handle, args...)
}
// CancelDownload cancels the download task
func (m *RemoteDownloadTask) CancelDownload(ctx context.Context) error {
if m.state.Handle == nil {
return nil
}
return m.d.Cancel(ctx, m.state.Handle)
}
func (m *RemoteDownloadTask) Summarize(hasher hashid.Encoder) *queue.Summary {
// unmarshal state
if m.state == nil {
if err := json.Unmarshal([]byte(m.State()), &m.state); err != nil {
return nil
}
}
var status *downloader.TaskStatus
if m.state.Status != nil {
status = &*m.state.Status
// Redact save path
status.SavePath = ""
}
failed := m.state.Failed
if m.state.SlaveUploadState != nil && m.state.Phase != RemoteDownloadTaskPhaseTransfer {
failed = len(m.state.SlaveUploadState.Files) - len(m.state.SlaveUploadState.Transferred)
}
return &queue.Summary{
Phase: string(m.state.Phase),
NodeID: m.state.NodeID,
Props: map[string]any{
SummaryKeySrcStr: m.state.SrcUri,
SummaryKeySrc: m.state.SrcFileUri,
SummaryKeyDst: m.state.Dst,
SummaryKeyFailed: failed,
SummaryKeyDownloadStatus: status,
},
}
}
func (m *RemoteDownloadTask) Progress(ctx context.Context) queue.Progresses {
m.Lock()
defer m.Unlock()
if m.state.NodeState.progress != nil {
merged := make(queue.Progresses)
for k, v := range m.progress {
merged[k] = v
}
for k, v := range m.state.NodeState.progress {
merged[k] = v
}
return merged
}
return m.progress
}
func sanitizeFileName(name string) string {
r := strings.NewReplacer("\\", "_", ":", "_", "*", "_", "?", "_", "\"", "_", "<", "_", ">", "_", "|", "_")
return r.Replace(name)
}