From 105f22969cf691ee7621b41a144a9f432d446258 Mon Sep 17 00:00:00 2001 From: Noah Hsu Date: Wed, 21 Dec 2022 15:03:09 +0800 Subject: [PATCH] feat: support cancel for some drivers (close #2717) --- drivers/123/driver.go | 6 ++--- drivers/139/driver.go | 5 +++++ drivers/189/driver.go | 2 +- drivers/189/util.go | 13 ++++++++--- drivers/189pc/driver.go | 26 +++++++++++----------- drivers/189pc/meta.go | 2 +- drivers/189pc/utils.go | 34 +++++++++++++---------------- drivers/alist_v3/driver.go | 2 +- drivers/aliyundrive/driver.go | 4 ++++ drivers/aliyundrive_share/driver.go | 31 -------------------------- drivers/baidu_netdisk/driver.go | 9 +++++++- drivers/baidu_photo/driver.go | 3 +++ drivers/ftp/driver.go | 1 + drivers/google_drive/driver.go | 2 +- drivers/google_drive/util.go | 6 ++++- drivers/google_photo/driver.go | 2 +- drivers/lanzou/driver.go | 2 +- drivers/mediatrack/driver.go | 2 +- drivers/mega/driver.go | 4 ++++ drivers/onedrive/driver.go | 2 +- drivers/onedrive/util.go | 8 +++++-- drivers/pikpak/driver.go | 2 +- drivers/pikpak_share/driver.go | 31 -------------------------- drivers/quark/driver.go | 5 ++++- drivers/quark/util.go | 7 +++--- drivers/s3/driver.go | 2 +- drivers/teambition/driver.go | 4 ++-- drivers/teambition/util.go | 25 ++++++++++++++------- drivers/uss/driver.go | 1 + drivers/webdav/driver.go | 1 + drivers/yandex_disk/driver.go | 3 ++- 31 files changed, 118 insertions(+), 129 deletions(-) diff --git a/drivers/123/driver.go b/drivers/123/driver.go index 38546c34..1b300c0e 100644 --- a/drivers/123/driver.go +++ b/drivers/123/driver.go @@ -221,7 +221,7 @@ func (d *Pan123) Put(ctx context.Context, dstDir model.Obj, stream model.FileStr } var resp UploadResp _, err := d.request("https://www.123pan.com/a/api/file/upload_request", http.MethodPost, func(req *resty.Request) { - req.SetBody(data) + req.SetBody(data).SetContext(ctx) }, &resp) if err != nil { return err @@ -245,14 +245,14 @@ func (d *Pan123) Put(ctx context.Context, dstDir model.Obj, stream model.FileStr Key: &resp.Data.Key, Body: uploadFile, } - _, err = uploader.Upload(input) + _, err = uploader.UploadWithContext(ctx, input) if err != nil { return err } _, err = d.request("https://www.123pan.com/api/file/upload_complete", http.MethodPost, func(req *resty.Request) { req.SetBody(base.Json{ "fileId": resp.Data.FileId, - }) + }).SetContext(ctx) }, nil) return err } diff --git a/drivers/139/driver.go b/drivers/139/driver.go index 720f6048..4f3650c8 100644 --- a/drivers/139/driver.go +++ b/drivers/139/driver.go @@ -13,6 +13,7 @@ import ( "github.com/alist-org/alist/v3/internal/driver" "github.com/alist-org/alist/v3/internal/errs" "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/pkg/utils" log "github.com/sirupsen/logrus" ) @@ -268,6 +269,9 @@ func (d *Yun139) Put(ctx context.Context, dstDir model.Obj, stream model.FileStr part := int(math.Ceil(float64(stream.GetSize()) / float64(Default))) var start int64 = 0 for i := 0; i < part; i++ { + if utils.IsCanceled(ctx) { + return ctx.Err() + } byteSize := stream.GetSize() - start if byteSize > Default { byteSize = Default @@ -281,6 +285,7 @@ func (d *Yun139) Put(ctx context.Context, dstDir model.Obj, stream model.FileStr if err != nil { return err } + req = req.WithContext(ctx) headers := map[string]string{ "Accept": "*/*", "Content-Type": "text/plain;name=" + unicode(stream.GetName()), diff --git a/drivers/189/driver.go b/drivers/189/driver.go index c5a88543..6d00870f 100644 --- a/drivers/189/driver.go +++ b/drivers/189/driver.go @@ -194,7 +194,7 @@ func (d *Cloud189) Remove(ctx context.Context, obj model.Obj) error { } func (d *Cloud189) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error { - return d.newUpload(dstDir, stream, up) + return d.newUpload(ctx, dstDir, stream, up) } var _ driver.Driver = (*Cloud189)(nil) diff --git a/drivers/189/util.go b/drivers/189/util.go index 1bc7a430..6b0ee7bd 100644 --- a/drivers/189/util.go +++ b/drivers/189/util.go @@ -2,6 +2,7 @@ package _189 import ( "bytes" + "context" "crypto/md5" "encoding/base64" "encoding/hex" @@ -306,7 +307,7 @@ func (d *Cloud189) uploadRequest(uri string, form map[string]string, resp interf return data, nil } -func (d *Cloud189) newUpload(dstDir model.Obj, file model.FileStreamer, up driver.UpdateProgress) error { +func (d *Cloud189) newUpload(ctx context.Context, dstDir model.Obj, file model.FileStreamer, up driver.UpdateProgress) error { sessionKey, err := d.getSessionKey() if err != nil { return err @@ -335,6 +336,9 @@ func (d *Cloud189) newUpload(dstDir model.Obj, file model.FileStreamer, up drive md5s := make([]string, 0) md5Sum := md5.New() for i = 1; i <= count; i++ { + if utils.IsCanceled(ctx) { + return ctx.Err() + } byteSize = file.GetSize() - finish if DEFAULT < byteSize { byteSize = DEFAULT @@ -364,12 +368,15 @@ func (d *Cloud189) newUpload(dstDir model.Obj, file model.FileStreamer, up drive log.Debugf("uploadData: %+v", uploadData) requestURL := uploadData.RequestURL uploadHeaders := strings.Split(decodeURIComponent(uploadData.RequestHeader), "&") - req, _ := http.NewRequest(http.MethodPut, requestURL, bytes.NewReader(byteData)) + req, err := http.NewRequest(http.MethodPut, requestURL, bytes.NewReader(byteData)) + if err != nil { + return err + } + req = req.WithContext(ctx) for _, v := range uploadHeaders { i := strings.Index(v, "=") req.Header.Set(v[0:i], v[i+1:]) } - r, err := base.HttpClient.Do(req) log.Debugf("%+v %+v", r, r.Request.Header) r.Body.Close() diff --git a/drivers/189pc/driver.go b/drivers/189pc/driver.go index 7297057a..e757c200 100644 --- a/drivers/189pc/driver.go +++ b/drivers/189pc/driver.go @@ -13,7 +13,7 @@ import ( "github.com/go-resty/resty/v2" ) -type Yun189PC struct { +type Cloud189PC struct { model.Storage Addition @@ -26,15 +26,15 @@ type Yun189PC struct { tokenInfo *AppSessionResp } -func (y *Yun189PC) Config() driver.Config { +func (y *Cloud189PC) Config() driver.Config { return config } -func (y *Yun189PC) GetAddition() driver.Additional { +func (y *Cloud189PC) GetAddition() driver.Additional { return &y.Addition } -func (y *Yun189PC) Init(ctx context.Context) (err error) { +func (y *Cloud189PC) Init(ctx context.Context) (err error) { // 处理个人云和家庭云参数 if y.isFamily() && y.RootFolderID == "-11" { y.RootFolderID = "" @@ -73,15 +73,15 @@ func (y *Yun189PC) Init(ctx context.Context) (err error) { return } -func (y *Yun189PC) Drop(ctx context.Context) error { +func (y *Cloud189PC) Drop(ctx context.Context) error { return nil } -func (y *Yun189PC) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) { +func (y *Cloud189PC) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) { return y.getFiles(ctx, dir.GetID()) } -func (y *Yun189PC) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) { +func (y *Cloud189PC) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) { var downloadUrl struct { URL string `json:"fileDownloadUrl"` } @@ -140,7 +140,7 @@ func (y *Yun189PC) Link(ctx context.Context, file model.Obj, args model.LinkArgs return like, nil } -func (y *Yun189PC) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) error { +func (y *Cloud189PC) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) error { fullUrl := API_URL if y.isFamily() { fullUrl += "/family/file" @@ -167,7 +167,7 @@ func (y *Yun189PC) MakeDir(ctx context.Context, parentDir model.Obj, dirName str return err } -func (y *Yun189PC) Move(ctx context.Context, srcObj, dstDir model.Obj) error { +func (y *Cloud189PC) Move(ctx context.Context, srcObj, dstDir model.Obj) error { _, err := y.post(API_URL+"/batch/createBatchTask.action", func(req *resty.Request) { req.SetContext(ctx) req.SetFormData(map[string]string{ @@ -191,7 +191,7 @@ func (y *Yun189PC) Move(ctx context.Context, srcObj, dstDir model.Obj) error { return err } -func (y *Yun189PC) Rename(ctx context.Context, srcObj model.Obj, newName string) error { +func (y *Cloud189PC) Rename(ctx context.Context, srcObj model.Obj, newName string) error { queryParam := make(map[string]string) fullUrl := API_URL method := http.MethodPost @@ -216,7 +216,7 @@ func (y *Yun189PC) Rename(ctx context.Context, srcObj model.Obj, newName string) return err } -func (y *Yun189PC) Copy(ctx context.Context, srcObj, dstDir model.Obj) error { +func (y *Cloud189PC) Copy(ctx context.Context, srcObj, dstDir model.Obj) error { _, err := y.post(API_URL+"/batch/createBatchTask.action", func(req *resty.Request) { req.SetContext(ctx) req.SetFormData(map[string]string{ @@ -241,7 +241,7 @@ func (y *Yun189PC) Copy(ctx context.Context, srcObj, dstDir model.Obj) error { return err } -func (y *Yun189PC) Remove(ctx context.Context, obj model.Obj) error { +func (y *Cloud189PC) Remove(ctx context.Context, obj model.Obj) error { _, err := y.post(API_URL+"/batch/createBatchTask.action", func(req *resty.Request) { req.SetContext(ctx) req.SetFormData(map[string]string{ @@ -265,7 +265,7 @@ func (y *Yun189PC) Remove(ctx context.Context, obj model.Obj) error { return err } -func (y *Yun189PC) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error { +func (y *Cloud189PC) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error { if y.RapidUpload { return y.FastUpload(ctx, dstDir, stream, up) } diff --git a/drivers/189pc/meta.go b/drivers/189pc/meta.go index c270041c..3cf535a7 100644 --- a/drivers/189pc/meta.go +++ b/drivers/189pc/meta.go @@ -25,6 +25,6 @@ var config = driver.Config{ func init() { op.RegisterDriver(func() driver.Driver { - return &Yun189PC{} + return &Cloud189PC{} }) } diff --git a/drivers/189pc/utils.go b/drivers/189pc/utils.go index 7a8739ed..eab497dd 100644 --- a/drivers/189pc/utils.go +++ b/drivers/189pc/utils.go @@ -47,7 +47,7 @@ const ( CHANNEL_ID = "web_cloud.189.cn" ) -func (y *Yun189PC) request(url, method string, callback base.ReqCallback, params Params, resp interface{}) ([]byte, error) { +func (y *Cloud189PC) request(url, method string, callback base.ReqCallback, params Params, resp interface{}) ([]byte, error) { dateOfGmt := getHttpDateStr() sessionKey := y.tokenInfo.SessionKey sessionSecret := y.tokenInfo.SessionSecret @@ -124,15 +124,15 @@ func (y *Yun189PC) request(url, method string, callback base.ReqCallback, params } } -func (y *Yun189PC) get(url string, callback base.ReqCallback, resp interface{}) ([]byte, error) { +func (y *Cloud189PC) get(url string, callback base.ReqCallback, resp interface{}) ([]byte, error) { return y.request(url, http.MethodGet, callback, nil, resp) } -func (y *Yun189PC) post(url string, callback base.ReqCallback, resp interface{}) ([]byte, error) { +func (y *Cloud189PC) post(url string, callback base.ReqCallback, resp interface{}) ([]byte, error) { return y.request(url, http.MethodPost, callback, nil, resp) } -func (y *Yun189PC) getFiles(ctx context.Context, fileId string) ([]model.Obj, error) { +func (y *Cloud189PC) getFiles(ctx context.Context, fileId string) ([]model.Obj, error) { fullUrl := API_URL if y.isFamily() { fullUrl += "/family/file" @@ -184,7 +184,7 @@ func (y *Yun189PC) getFiles(ctx context.Context, fileId string) ([]model.Obj, er return res, nil } -func (y *Yun189PC) login() (err error) { +func (y *Cloud189PC) login() (err error) { // 初始化登陆所需参数 if y.loginParam == nil || !y.NoUseOcr { if err = y.initLoginParam(); err != nil { @@ -264,7 +264,7 @@ func (y *Yun189PC) login() (err error) { /* 初始化登陆需要的参数 * 如果遇到验证码返回错误 */ -func (y *Yun189PC) initLoginParam() error { +func (y *Cloud189PC) initLoginParam() error { // 清除cookie jar, _ := cookiejar.New(nil) y.client.SetCookieJar(jar) @@ -335,7 +335,7 @@ func (y *Yun189PC) initLoginParam() error { } // 刷新会话 -func (y *Yun189PC) refreshSession() (err error) { +func (y *Cloud189PC) refreshSession() (err error) { var erron RespErr var userSessionResp UserSessionResp _, err = y.client.R(). @@ -381,7 +381,7 @@ func (y *Yun189PC) refreshSession() (err error) { } // 普通上传 -func (y *Yun189PC) CommonUpload(ctx context.Context, dstDir model.Obj, file model.FileStreamer, up driver.UpdateProgress) (err error) { +func (y *Cloud189PC) CommonUpload(ctx context.Context, dstDir model.Obj, file model.FileStreamer, up driver.UpdateProgress) (err error) { const DEFAULT int64 = 10485760 var count = int64(math.Ceil(float64(file.GetSize()) / float64(DEFAULT))) @@ -418,10 +418,8 @@ func (y *Yun189PC) CommonUpload(ctx context.Context, dstDir model.Obj, file mode silceMd5Hexs := make([]string, 0, count) byteData := bytes.NewBuffer(make([]byte, DEFAULT)) for i := int64(1); i <= count; i++ { - select { - case <-ctx.Done(): + if utils.IsCanceled(ctx) { return ctx.Err() - default: } // 读取块 @@ -491,7 +489,7 @@ func (y *Yun189PC) CommonUpload(ctx context.Context, dstDir model.Obj, file mode } // 快传 -func (y *Yun189PC) FastUpload(ctx context.Context, dstDir model.Obj, file model.FileStreamer, up driver.UpdateProgress) (err error) { +func (y *Cloud189PC) FastUpload(ctx context.Context, dstDir model.Obj, file model.FileStreamer, up driver.UpdateProgress) (err error) { // 需要获取完整文件md5,必须支持 io.Seek tempFile, err := utils.CreateTempFile(file.GetReadCloser()) if err != nil { @@ -511,10 +509,8 @@ func (y *Yun189PC) FastUpload(ctx context.Context, dstDir model.Obj, file model. silceMd5Hexs := make([]string, 0, count) silceMd5Base64s := make([]string, 0, count) for i := 1; i <= count; i++ { - select { - case <-ctx.Done(): + if utils.IsCanceled(ctx) { return ctx.Err() - default: } silceMd5.Reset() @@ -616,11 +612,11 @@ func (y *Yun189PC) FastUpload(ctx context.Context, dstDir model.Obj, file model. return err } -func (y *Yun189PC) isFamily() bool { +func (y *Cloud189PC) isFamily() bool { return y.Type == "family" } -func (y *Yun189PC) isLogin() bool { +func (y *Cloud189PC) isLogin() bool { if y.tokenInfo == nil { return false } @@ -629,7 +625,7 @@ func (y *Yun189PC) isLogin() bool { } // 获取家庭云所有用户信息 -func (y *Yun189PC) getFamilyInfoList() ([]FamilyInfoResp, error) { +func (y *Cloud189PC) getFamilyInfoList() ([]FamilyInfoResp, error) { var resp FamilyInfoListResp _, err := y.get(API_URL+"/family/manage/getFamilyList.action", nil, &resp) if err != nil { @@ -639,7 +635,7 @@ func (y *Yun189PC) getFamilyInfoList() ([]FamilyInfoResp, error) { } // 抽取家庭云ID -func (y *Yun189PC) getFamilyID() (string, error) { +func (y *Cloud189PC) getFamilyID() (string, error) { infos, err := y.getFamilyInfoList() if err != nil { return "", err diff --git a/drivers/alist_v3/driver.go b/drivers/alist_v3/driver.go index 7e6fb088..7de65a30 100644 --- a/drivers/alist_v3/driver.go +++ b/drivers/alist_v3/driver.go @@ -162,7 +162,7 @@ func (d *AListV3) Put(ctx context.Context, dstDir model.Obj, stream model.FileSt if err != nil { return nil } - _, err = base.RestyClient.R(). + _, err = base.RestyClient.R().SetContext(ctx). SetResult(&resp). SetHeader("Authorization", d.AccessToken). SetHeader("File-Path", path.Join(dstDir.GetPath(), stream.GetName())). diff --git a/drivers/aliyundrive/driver.go b/drivers/aliyundrive/driver.go index 98bec318..696e553e 100644 --- a/drivers/aliyundrive/driver.go +++ b/drivers/aliyundrive/driver.go @@ -248,10 +248,14 @@ func (d *AliDrive) Put(ctx context.Context, dstDir model.Obj, stream model.FileS } for i, partInfo := range resp.PartInfoList { + if utils.IsCanceled(ctx) { + return ctx.Err() + } req, err := http.NewRequest("PUT", partInfo.UploadUrl, io.LimitReader(file, DEFAULT)) if err != nil { return err } + req = req.WithContext(ctx) res, err := base.HttpClient.Do(req) if err != nil { return err diff --git a/drivers/aliyundrive_share/driver.go b/drivers/aliyundrive_share/driver.go index b3a083c6..9d920683 100644 --- a/drivers/aliyundrive_share/driver.go +++ b/drivers/aliyundrive_share/driver.go @@ -8,7 +8,6 @@ import ( "github.com/alist-org/alist/v3/drivers/base" "github.com/alist-org/alist/v3/internal/driver" - "github.com/alist-org/alist/v3/internal/errs" "github.com/alist-org/alist/v3/internal/model" "github.com/alist-org/alist/v3/pkg/cron" "github.com/alist-org/alist/v3/pkg/utils" @@ -113,34 +112,4 @@ func (d *AliyundriveShare) Link(ctx context.Context, file model.Obj, args model. }, nil } -func (d *AliyundriveShare) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) error { - // TODO create folder - return errs.NotSupport -} - -func (d *AliyundriveShare) Move(ctx context.Context, srcObj, dstDir model.Obj) error { - // TODO move obj - return errs.NotSupport -} - -func (d *AliyundriveShare) Rename(ctx context.Context, srcObj model.Obj, newName string) error { - // TODO rename obj - return errs.NotSupport -} - -func (d *AliyundriveShare) Copy(ctx context.Context, srcObj, dstDir model.Obj) error { - // TODO copy obj - return errs.NotSupport -} - -func (d *AliyundriveShare) Remove(ctx context.Context, obj model.Obj) error { - // TODO remove obj - return errs.NotSupport -} - -func (d *AliyundriveShare) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error { - // TODO upload file - return errs.NotSupport -} - var _ driver.Driver = (*AliyundriveShare)(nil) diff --git a/drivers/baidu_netdisk/driver.go b/drivers/baidu_netdisk/driver.go index aedb05bd..ae67f459 100644 --- a/drivers/baidu_netdisk/driver.go +++ b/drivers/baidu_netdisk/driver.go @@ -192,6 +192,9 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F } left = stream.GetSize() for i, partseq := range precreateResp.BlockList { + if utils.IsCanceled(ctx) { + return ctx.Err() + } byteSize := Default var byteData []byte if left < Default { @@ -207,7 +210,11 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F } u := "https://d.pcs.baidu.com/rest/2.0/pcs/superfile2" params["partseq"] = strconv.Itoa(partseq) - res, err := base.RestyClient.R().SetQueryParams(params).SetFileReader("file", stream.GetName(), bytes.NewReader(byteData)).Post(u) + res, err := base.RestyClient.R(). + SetContext(ctx). + SetQueryParams(params). + SetFileReader("file", stream.GetName(), bytes.NewReader(byteData)). + Post(u) if err != nil { return err } diff --git a/drivers/baidu_photo/driver.go b/drivers/baidu_photo/driver.go index c9c85f27..828e6ec6 100644 --- a/drivers/baidu_photo/driver.go +++ b/drivers/baidu_photo/driver.go @@ -240,6 +240,9 @@ func (d *BaiduPhoto) Put(ctx context.Context, dstDir model.Obj, stream model.Fil } for i := 0; i < count; i++ { + if utils.IsCanceled(ctx) { + return ctx.Err() + } uploadParams["partseq"] = fmt.Sprint(i) _, err = d.Post("https://c3.pcs.baidu.com/rest/2.0/pcs/superfile2", func(r *resty.Request) { r.SetContext(ctx) diff --git a/drivers/ftp/driver.go b/drivers/ftp/driver.go index f3c9b43d..682d7ef4 100644 --- a/drivers/ftp/driver.go +++ b/drivers/ftp/driver.go @@ -113,6 +113,7 @@ func (d *FTP) Put(ctx context.Context, dstDir model.Obj, stream model.FileStream if err := d.login(); err != nil { return err } + // TODO: support cancel return d.conn.Stor(stdpath.Join(dstDir.GetPath(), stream.GetName()), stream) } diff --git a/drivers/google_drive/driver.go b/drivers/google_drive/driver.go index dc2d9a43..277566eb 100644 --- a/drivers/google_drive/driver.go +++ b/drivers/google_drive/driver.go @@ -134,7 +134,7 @@ func (d *GoogleDrive) Put(ctx context.Context, dstDir model.Obj, stream model.Fi "X-Upload-Content-Type": stream.GetMimetype(), "X-Upload-Content-Length": strconv.FormatInt(stream.GetSize(), 10), }). - SetError(&e).SetBody(data) + SetError(&e).SetBody(data).SetContext(ctx) if obj != nil { res, err = req.Patch(url) } else { diff --git a/drivers/google_drive/util.go b/drivers/google_drive/util.go index 5a23d8ed..38d9dd6e 100644 --- a/drivers/google_drive/util.go +++ b/drivers/google_drive/util.go @@ -9,6 +9,7 @@ import ( "github.com/alist-org/alist/v3/drivers/base" "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/pkg/utils" "github.com/go-resty/resty/v2" log "github.com/sirupsen/logrus" ) @@ -104,6 +105,9 @@ func (d *GoogleDrive) chunkUpload(ctx context.Context, stream model.FileStreamer var defaultChunkSize = d.ChunkSize * 1024 * 1024 var finish int64 = 0 for finish < stream.GetSize() { + if utils.IsCanceled(ctx) { + return ctx.Err() + } chunkSize := stream.GetSize() - finish if chunkSize > defaultChunkSize { chunkSize = defaultChunkSize @@ -112,7 +116,7 @@ func (d *GoogleDrive) chunkUpload(ctx context.Context, stream model.FileStreamer req.SetHeaders(map[string]string{ "Content-Length": strconv.FormatInt(chunkSize, 10), "Content-Range": fmt.Sprintf("bytes %d-%d/%d", finish, finish+chunkSize-1, stream.GetSize()), - }).SetBody(io.LimitReader(stream.GetReadCloser(), chunkSize)) + }).SetBody(io.LimitReader(stream.GetReadCloser(), chunkSize)).SetContext(ctx) }, nil) if err != nil { return err diff --git a/drivers/google_photo/driver.go b/drivers/google_photo/driver.go index f2d5acea..aab3b5d9 100644 --- a/drivers/google_photo/driver.go +++ b/drivers/google_photo/driver.go @@ -124,7 +124,7 @@ func (d *GooglePhoto) Put(ctx context.Context, dstDir model.Obj, stream model.Fi } resp, err := d.request(postUrl, http.MethodPost, func(req *resty.Request) { - req.SetBody(stream.GetReadCloser()) + req.SetBody(stream.GetReadCloser()).SetContext(ctx) }, nil, postHeaders) if err != nil { diff --git a/drivers/lanzou/driver.go b/drivers/lanzou/driver.go index aaedaf69..1ca451ce 100644 --- a/drivers/lanzou/driver.go +++ b/drivers/lanzou/driver.go @@ -157,7 +157,7 @@ func (d *LanZou) Put(ctx context.Context, dstDir model.Obj, stream model.FileStr "id": "WU_FILE_0", "name": stream.GetName(), "folder_id": dstDir.GetID(), - }).SetFileReader("upload_file", stream.GetName(), stream) + }).SetFileReader("upload_file", stream.GetName(), stream).SetContext(ctx) }, nil, true) return err } diff --git a/drivers/mediatrack/driver.go b/drivers/mediatrack/driver.go index 2246aaa2..c9937505 100644 --- a/drivers/mediatrack/driver.go +++ b/drivers/mediatrack/driver.go @@ -195,7 +195,7 @@ func (d *MediaTrack) Put(ctx context.Context, dstDir model.Obj, stream model.Fil Key: &resp.Data.Object, Body: tempFile, } - _, err = uploader.Upload(input) + _, err = uploader.UploadWithContext(ctx, input) if err != nil { return err } diff --git a/drivers/mega/driver.go b/drivers/mega/driver.go index ad0cfe6d..c202a313 100644 --- a/drivers/mega/driver.go +++ b/drivers/mega/driver.go @@ -9,6 +9,7 @@ import ( "github.com/alist-org/alist/v3/internal/driver" "github.com/alist-org/alist/v3/internal/errs" "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/pkg/utils" log "github.com/sirupsen/logrus" "github.com/t3rm1n4l/go-mega" ) @@ -155,6 +156,9 @@ func (d *Mega) Put(ctx context.Context, dstDir model.Obj, stream model.FileStrea } for id := 0; id < u.Chunks(); id++ { + if utils.IsCanceled(ctx) { + return ctx.Err() + } _, chkSize, err := u.ChunkLocation(id) if err != nil { return err diff --git a/drivers/onedrive/driver.go b/drivers/onedrive/driver.go index 14d4f69f..65dfa075 100644 --- a/drivers/onedrive/driver.go +++ b/drivers/onedrive/driver.go @@ -137,7 +137,7 @@ func (d *Onedrive) Remove(ctx context.Context, obj model.Obj) error { func (d *Onedrive) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error { var err error if stream.GetSize() <= 4*1024*1024 { - err = d.upSmall(dstDir, stream) + err = d.upSmall(ctx, dstDir, stream) } else { err = d.upBig(ctx, dstDir, stream, up) } diff --git a/drivers/onedrive/util.go b/drivers/onedrive/util.go index 32cbfce9..289266b6 100644 --- a/drivers/onedrive/util.go +++ b/drivers/onedrive/util.go @@ -147,14 +147,14 @@ func (d *Onedrive) GetFile(path string) (*File, error) { return &file, err } -func (d *Onedrive) upSmall(dstDir model.Obj, stream model.FileStreamer) error { +func (d *Onedrive) upSmall(ctx context.Context, dstDir model.Obj, stream model.FileStreamer) error { url := d.GetMetaUrl(false, stdpath.Join(dstDir.GetPath(), stream.GetName())) + "/content" data, err := io.ReadAll(stream) if err != nil { return err } _, err = d.Request(url, http.MethodPut, func(req *resty.Request) { - req.SetBody(data) + req.SetBody(data).SetContext(ctx) }, nil) return err } @@ -185,6 +185,10 @@ func (d *Onedrive) upBig(ctx context.Context, dstDir model.Obj, stream model.Fil return err } req, err := http.NewRequest("PUT", uploadUrl, bytes.NewBuffer(byteData)) + if err != nil { + return err + } + req = req.WithContext(ctx) req.Header.Set("Content-Length", strconv.Itoa(int(byteSize))) req.Header.Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", finish, finish+byteSize-1, stream.GetSize())) finish += byteSize diff --git a/drivers/pikpak/driver.go b/drivers/pikpak/driver.go index f08a4067..88a4803c 100644 --- a/drivers/pikpak/driver.go +++ b/drivers/pikpak/driver.go @@ -189,7 +189,7 @@ func (d *PikPak) Put(ctx context.Context, dstDir model.Obj, stream model.FileStr Key: &key, Body: tempFile, } - _, err = uploader.Upload(input) + _, err = uploader.UploadWithContext(ctx, input) return err } diff --git a/drivers/pikpak_share/driver.go b/drivers/pikpak_share/driver.go index 8c625c1c..f476900e 100644 --- a/drivers/pikpak_share/driver.go +++ b/drivers/pikpak_share/driver.go @@ -6,7 +6,6 @@ import ( "github.com/alist-org/alist/v3/internal/driver" "github.com/alist-org/alist/v3/internal/model" - "github.com/alist-org/alist/v3/internal/errs" "github.com/alist-org/alist/v3/pkg/utils" "github.com/go-resty/resty/v2" log "github.com/sirupsen/logrus" @@ -79,34 +78,4 @@ func (d *PikPakShare) Link(ctx context.Context, file model.Obj, args model.LinkA return &link, nil } -func (d *PikPakShare) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) error { - // TODO create folder - return errs.NotSupport -} - -func (d *PikPakShare) Move(ctx context.Context, srcObj, dstDir model.Obj) error { - // TODO move obj - return errs.NotSupport -} - -func (d *PikPakShare) Rename(ctx context.Context, srcObj model.Obj, newName string) error { - // TODO rename obj - return errs.NotSupport -} - -func (d *PikPakShare) Copy(ctx context.Context, srcObj, dstDir model.Obj) error { - // TODO copy obj - return errs.NotSupport -} - -func (d *PikPakShare) Remove(ctx context.Context, obj model.Obj) error { - // TODO remove obj - return errs.NotSupport -} - -func (d *PikPakShare) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error { - // TODO upload file - return errs.NotSupport -} - var _ driver.Driver = (*PikPakShare)(nil) diff --git a/drivers/quark/driver.go b/drivers/quark/driver.go index fb3731fc..599ed4f5 100644 --- a/drivers/quark/driver.go +++ b/drivers/quark/driver.go @@ -179,6 +179,9 @@ func (d *Quark) Put(ctx context.Context, dstDir model.Obj, stream model.FileStre partNumber := 1 sizeDivide100 := stream.GetSize() / 100 for left > 0 { + if utils.IsCanceled(ctx) { + return ctx.Err() + } if left > int64(partSize) { bytes = defaultBytes } else { @@ -190,7 +193,7 @@ func (d *Quark) Put(ctx context.Context, dstDir model.Obj, stream model.FileStre } left -= int64(partSize) log.Debugf("left: %d", left) - m, err := d.upPart(pre, stream.GetMimetype(), partNumber, bytes) + m, err := d.upPart(ctx, pre, stream.GetMimetype(), partNumber, bytes) //m, err := driver.UpPart(pre, file.GetMIMEType(), partNumber, bytes, account, md5Str, sha1Str) if err != nil { return err diff --git a/drivers/quark/util.go b/drivers/quark/util.go index 0627c904..50f1eb8d 100644 --- a/drivers/quark/util.go +++ b/drivers/quark/util.go @@ -1,6 +1,7 @@ package quark import ( + "context" "crypto/md5" "encoding/base64" "errors" @@ -118,7 +119,7 @@ func (d *Quark) upHash(md5, sha1, taskId string) (bool, error) { return resp.Data.Finish, err } -func (d *Quark) upPart(pre UpPreResp, mineType string, partNumber int, bytes []byte) (string, error) { +func (d *Quark) upPart(ctx context.Context, pre UpPreResp, mineType string, partNumber int, bytes []byte) (string, error) { //func (driver Quark) UpPart(pre UpPreResp, mineType string, partNumber int, bytes []byte, account *model.Account, md5Str, sha1Str string) (string, error) { timeStr := time.Now().UTC().Format(http.TimeFormat) data := base.Json{ @@ -135,7 +136,7 @@ x-oss-user-agent:aliyun-sdk-js/6.6.1 Chrome 98.0.4758.80 on Windows 10 64-bit } var resp UpAuthResp _, err := d.request("/file/upload/auth", http.MethodPost, func(req *resty.Request) { - req.SetBody(data) + req.SetBody(data).SetContext(ctx) }, &resp) if err != nil { return "", err @@ -150,7 +151,7 @@ x-oss-user-agent:aliyun-sdk-js/6.6.1 Chrome 98.0.4758.80 on Windows 10 64-bit // } //} u := fmt.Sprintf("https://%s.%s/%s", pre.Data.Bucket, pre.Data.UploadUrl[7:], pre.Data.ObjKey) - res, err := base.RestyClient.R(). + res, err := base.RestyClient.R().SetContext(ctx). SetHeaders(map[string]string{ "Authorization": resp.Data.AuthKey, "Content-Type": mineType, diff --git a/drivers/s3/driver.go b/drivers/s3/driver.go index 9ab3c4c8..bdcf7ab5 100644 --- a/drivers/s3/driver.go +++ b/drivers/s3/driver.go @@ -134,7 +134,7 @@ func (d *S3) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreame Key: &key, Body: stream, } - _, err := uploader.Upload(input) + _, err := uploader.UploadWithContext(ctx, input) return err } diff --git a/drivers/teambition/driver.go b/drivers/teambition/driver.go index 62855f0c..f2dfaa3c 100644 --- a/drivers/teambition/driver.go +++ b/drivers/teambition/driver.go @@ -132,11 +132,11 @@ func (d *Teambition) Put(ctx context.Context, dstDir model.Obj, stream model.Fil var newFile *FileUpload if stream.GetSize() <= 20971520 { // post upload - newFile, err = d.upload(stream, token) + newFile, err = d.upload(ctx, stream, token) } else { // chunk upload //err = base.ErrNotImplement - newFile, err = d.chunkUpload(stream, token, up) + newFile, err = d.chunkUpload(ctx, stream, token, up) } if err != nil { return err diff --git a/drivers/teambition/util.go b/drivers/teambition/util.go index 4f5fd27f..e95ce30d 100644 --- a/drivers/teambition/util.go +++ b/drivers/teambition/util.go @@ -1,6 +1,7 @@ package teambition import ( + "context" "errors" "fmt" "io" @@ -12,6 +13,7 @@ import ( "github.com/alist-org/alist/v3/drivers/base" "github.com/alist-org/alist/v3/internal/driver" "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/pkg/utils" "github.com/go-resty/resty/v2" log "github.com/sirupsen/logrus" ) @@ -115,13 +117,15 @@ func (d *Teambition) getFiles(parentId string) ([]model.Obj, error) { return files, nil } -func (d *Teambition) upload(file model.FileStreamer, token string) (*FileUpload, error) { +func (d *Teambition) upload(ctx context.Context, file model.FileStreamer, token string) (*FileUpload, error) { prefix := "tcs" if d.isInternational() { prefix = "us-tcs" } var newFile FileUpload - _, err := base.RestyClient.R().SetResult(&newFile).SetHeader("Authorization", token). + _, err := base.RestyClient.R(). + SetContext(ctx). + SetResult(&newFile).SetHeader("Authorization", token). SetMultipartFormData(map[string]string{ "name": file.GetName(), "type": file.GetMimetype(), @@ -135,7 +139,7 @@ func (d *Teambition) upload(file model.FileStreamer, token string) (*FileUpload, return &newFile, nil } -func (d *Teambition) chunkUpload(file model.FileStreamer, token string, up driver.UpdateProgress) (*FileUpload, error) { +func (d *Teambition) chunkUpload(ctx context.Context, file model.FileStreamer, token string, up driver.UpdateProgress) (*FileUpload, error) { prefix := "tcs" referer := "https://www.teambition.com/" if d.isInternational() { @@ -153,6 +157,9 @@ func (d *Teambition) chunkUpload(file model.FileStreamer, token string, up drive return nil, err } for i := 0; i < newChunk.Chunks; i++ { + if utils.IsCanceled(ctx) { + return nil, ctx.Err() + } chunkSize := newChunk.ChunkSize if i == newChunk.Chunks-1 { chunkSize = int(file.GetSize()) - i*chunkSize @@ -166,11 +173,13 @@ func (d *Teambition) chunkUpload(file model.FileStreamer, token string, up drive u := fmt.Sprintf("https://%s.teambition.net/upload/chunk/%s?chunk=%d&chunks=%d", prefix, newChunk.FileKey, i+1, newChunk.Chunks) log.Debugf("url: %s", u) - _, err := base.RestyClient.R().SetHeaders(map[string]string{ - "Authorization": token, - "Content-Type": "application/octet-stream", - "Referer": referer, - }).SetBody(chunkData).Post(u) + _, err := base.RestyClient.R(). + SetContext(ctx). + SetHeaders(map[string]string{ + "Authorization": token, + "Content-Type": "application/octet-stream", + "Referer": referer, + }).SetBody(chunkData).Post(u) if err != nil { return nil, err } diff --git a/drivers/uss/driver.go b/drivers/uss/driver.go index bce18b46..56c04f82 100644 --- a/drivers/uss/driver.go +++ b/drivers/uss/driver.go @@ -123,6 +123,7 @@ func (d *USS) Remove(ctx context.Context, obj model.Obj) error { } func (d *USS) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error { + // TODO not support cancel?? return d.client.Put(&upyun.PutObjectConfig{ Path: getKey(path.Join(dstDir.GetPath(), stream.GetName()), false), Reader: stream, diff --git a/drivers/webdav/driver.go b/drivers/webdav/driver.go index 7c29bcb8..0f9340ab 100644 --- a/drivers/webdav/driver.go +++ b/drivers/webdav/driver.go @@ -98,6 +98,7 @@ func (d *WebDav) Put(ctx context.Context, dstDir model.Obj, stream model.FileStr r.Header.Set("Content-Type", stream.GetMimetype()) r.ContentLength = stream.GetSize() } + // TODO: support cancel err := d.client.WriteStream(path.Join(dstDir.GetPath(), stream.GetName()), stream, 0644, callback) return err } diff --git a/drivers/yandex_disk/driver.go b/drivers/yandex_disk/driver.go index ddccf057..5af9f2e4 100644 --- a/drivers/yandex_disk/driver.go +++ b/drivers/yandex_disk/driver.go @@ -121,10 +121,11 @@ func (d *YandexDisk) Put(ctx context.Context, dstDir model.Obj, stream model.Fil if err != nil { return err } + req = req.WithContext(ctx) req.Header.Set("Content-Length", strconv.FormatInt(stream.GetSize(), 10)) req.Header.Set("Content-Type", "application/octet-stream") res, err := base.HttpClient.Do(req) - res.Body.Close() + _ = res.Body.Close() return err }