diff --git a/cmd/common.go b/cmd/common.go index 47a25f3f..8a73f9b0 100644 --- a/cmd/common.go +++ b/cmd/common.go @@ -17,6 +17,7 @@ func Init() { bootstrap.Log() bootstrap.InitDB() data.InitData() + bootstrap.InitStreamLimit() bootstrap.InitIndex() bootstrap.InitUpgradePatch() } diff --git a/drivers/115/util.go b/drivers/115/util.go index 4d3cdd93..7298f565 100644 --- a/drivers/115/util.go +++ b/drivers/115/util.go @@ -8,8 +8,6 @@ import ( "encoding/hex" "encoding/json" "fmt" - "github.com/alist-org/alist/v3/internal/driver" - "github.com/alist-org/alist/v3/internal/stream" "io" "net/http" "net/url" @@ -20,6 +18,7 @@ import ( "time" "github.com/alist-org/alist/v3/internal/conf" + "github.com/alist-org/alist/v3/internal/driver" "github.com/alist-org/alist/v3/internal/model" "github.com/alist-org/alist/v3/pkg/http_range" "github.com/alist-org/alist/v3/pkg/utils" @@ -144,7 +143,7 @@ func (d *Pan115) DownloadWithUA(pickCode, ua string) (*driver115.DownloadInfo, e return nil, err } - bytes, err := crypto.Decode(string(result.EncodedData), key) + b, err := crypto.Decode(string(result.EncodedData), key) if err != nil { return nil, err } @@ -152,7 +151,7 @@ func (d *Pan115) DownloadWithUA(pickCode, ua string) (*driver115.DownloadInfo, e downloadInfo := struct { Url string `json:"url"` }{} - if err := utils.Json.Unmarshal(bytes, &downloadInfo); err != nil { + if err := utils.Json.Unmarshal(b, &downloadInfo); err != nil { return nil, err } @@ -290,13 +289,10 @@ func (c *Pan115) UploadByOSS(ctx context.Context, params *driver115.UploadOSSPar } var bodyBytes []byte - r := &stream.ReaderWithCtx{ - Reader: &stream.ReaderUpdatingProgress{ - Reader: s, - UpdateProgress: up, - }, - Ctx: ctx, - } + r := driver.NewLimitedUploadStream(ctx, &driver.ReaderUpdatingProgress{ + Reader: s, + UpdateProgress: up, + }) if err = bucket.PutObject(params.Object, r, append( driver115.OssOption(params, ossToken), oss.CallbackResult(&bodyBytes), @@ -405,16 +401,12 @@ func (d *Pan115) UploadByMultipart(ctx context.Context, params *driver115.Upload } default: } - buf := make([]byte, chunk.Size) if _, err = tmpF.ReadAt(buf, chunk.Offset); err != nil && !errors.Is(err, io.EOF) { continue } - - if part, err = bucket.UploadPart(imur, &stream.ReaderWithCtx{ - Reader: bytes.NewBuffer(buf), - Ctx: ctx, - }, chunk.Size, chunk.Number, driver115.OssOption(params, ossToken)...); err == nil { + if part, err = bucket.UploadPart(imur, driver.NewLimitedUploadStream(ctx, bytes.NewBuffer(buf)), + chunk.Size, chunk.Number, driver115.OssOption(params, ossToken)...); err == nil { break } } diff --git a/drivers/123/driver.go b/drivers/123/driver.go index 1bf71ae6..7d457138 100644 --- a/drivers/123/driver.go +++ b/drivers/123/driver.go @@ -6,7 +6,6 @@ import ( "encoding/base64" "encoding/hex" "fmt" - "github.com/alist-org/alist/v3/internal/stream" "io" "net/http" "net/url" @@ -249,10 +248,10 @@ func (d *Pan123) Put(ctx context.Context, dstDir model.Obj, file model.FileStrea input := &s3manager.UploadInput{ Bucket: &resp.Data.Bucket, Key: &resp.Data.Key, - Body: &stream.ReaderUpdatingProgress{ + Body: driver.NewLimitedUploadStream(ctx, &driver.ReaderUpdatingProgress{ Reader: file, UpdateProgress: up, - }, + }), } _, err = uploader.UploadWithContext(ctx, input) if err != nil { diff --git a/drivers/123/upload.go b/drivers/123/upload.go index a472df55..dc148c4c 100644 --- a/drivers/123/upload.go +++ b/drivers/123/upload.go @@ -81,6 +81,7 @@ func (d *Pan123) newUpload(ctx context.Context, upReq *UploadResp, file model.Fi batchSize = 10 getS3UploadUrl = d.getS3PreSignedUrls } + limited := driver.NewLimitedUploadStream(ctx, file) for i := 1; i <= chunkCount; i += batchSize { if utils.IsCanceled(ctx) { return ctx.Err() @@ -103,7 +104,7 @@ func (d *Pan123) newUpload(ctx context.Context, upReq *UploadResp, file model.Fi if j == chunkCount { curSize = file.GetSize() - (int64(chunkCount)-1)*chunkSize } - err = d.uploadS3Chunk(ctx, upReq, s3PreSignedUrls, j, end, io.LimitReader(file, chunkSize), curSize, false, getS3UploadUrl) + err = d.uploadS3Chunk(ctx, upReq, s3PreSignedUrls, j, end, io.LimitReader(limited, chunkSize), curSize, false, getS3UploadUrl) if err != nil { return err } diff --git a/drivers/139/driver.go b/drivers/139/driver.go index 1e2ba9c4..c6b30335 100644 --- a/drivers/139/driver.go +++ b/drivers/139/driver.go @@ -631,12 +631,13 @@ func (d *Yun139) Put(ctx context.Context, dstDir model.Obj, stream model.FileStr // Progress p := driver.NewProgress(stream.GetSize(), up) + rateLimited := driver.NewLimitedUploadStream(ctx, stream) // 上传所有分片 for _, uploadPartInfo := range uploadPartInfos { index := uploadPartInfo.PartNumber - 1 partSize := partInfos[index].PartSize log.Debugf("[139] uploading part %+v/%+v", index, len(uploadPartInfos)) - limitReader := io.LimitReader(stream, partSize) + limitReader := io.LimitReader(rateLimited, partSize) // Update Progress r := io.TeeReader(limitReader, p) @@ -787,6 +788,7 @@ func (d *Yun139) Put(ctx context.Context, dstDir model.Obj, stream model.FileStr if part == 0 { part = 1 } + rateLimited := driver.NewLimitedUploadStream(ctx, stream) for i := int64(0); i < part; i++ { if utils.IsCanceled(ctx) { return ctx.Err() @@ -798,7 +800,7 @@ func (d *Yun139) Put(ctx context.Context, dstDir model.Obj, stream model.FileStr byteSize = partSize } - limitReader := io.LimitReader(stream, byteSize) + limitReader := io.LimitReader(rateLimited, byteSize) // Update Progress r := io.TeeReader(limitReader, p) req, err := http.NewRequest("POST", resp.Data.UploadResult.RedirectionURL, r) diff --git a/drivers/189/util.go b/drivers/189/util.go index 0b4c0633..16a5aa39 100644 --- a/drivers/189/util.go +++ b/drivers/189/util.go @@ -365,7 +365,7 @@ func (d *Cloud189) newUpload(ctx context.Context, dstDir model.Obj, file model.F log.Debugf("uploadData: %+v", uploadData) requestURL := uploadData.RequestURL uploadHeaders := strings.Split(decodeURIComponent(uploadData.RequestHeader), "&") - req, err := http.NewRequest(http.MethodPut, requestURL, bytes.NewReader(byteData)) + req, err := http.NewRequest(http.MethodPut, requestURL, driver.NewLimitedUploadStream(ctx, bytes.NewReader(byteData))) if err != nil { return err } @@ -375,11 +375,11 @@ func (d *Cloud189) newUpload(ctx context.Context, dstDir model.Obj, file model.F req.Header.Set(v[0:i], v[i+1:]) } r, err := base.HttpClient.Do(req) - log.Debugf("%+v %+v", r, r.Request.Header) - r.Body.Close() if err != nil { return err } + log.Debugf("%+v %+v", r, r.Request.Header) + _ = r.Body.Close() up(float64(i) * 100 / float64(count)) } fileMd5 := hex.EncodeToString(md5Sum.Sum(nil)) diff --git a/drivers/189pc/utils.go b/drivers/189pc/utils.go index 6f3c4dcf..290d2e56 100644 --- a/drivers/189pc/utils.go +++ b/drivers/189pc/utils.go @@ -19,6 +19,8 @@ import ( "strings" "time" + "golang.org/x/sync/semaphore" + "github.com/alist-org/alist/v3/drivers/base" "github.com/alist-org/alist/v3/internal/conf" "github.com/alist-org/alist/v3/internal/driver" @@ -174,8 +176,8 @@ func (y *Cloud189PC) put(ctx context.Context, url string, headers map[string]str } var erron RespErr - jsoniter.Unmarshal(body, &erron) - xml.Unmarshal(body, &erron) + _ = jsoniter.Unmarshal(body, &erron) + _ = xml.Unmarshal(body, &erron) if erron.HasError() { return nil, &erron } @@ -508,6 +510,7 @@ func (y *Cloud189PC) StreamUpload(ctx context.Context, dstDir model.Obj, file mo retry.Attempts(3), retry.Delay(time.Second), retry.DelayType(retry.BackOffDelay)) + sem := semaphore.NewWeighted(3) fileMd5 := md5.New() silceMd5 := md5.New() @@ -517,7 +520,9 @@ func (y *Cloud189PC) StreamUpload(ctx context.Context, dstDir model.Obj, file mo if utils.IsCanceled(upCtx) { break } - + if err = sem.Acquire(ctx, 1); err != nil { + break + } byteData := make([]byte, sliceSize) if i == count { byteData = byteData[:lastPartSize] @@ -526,6 +531,7 @@ func (y *Cloud189PC) StreamUpload(ctx context.Context, dstDir model.Obj, file mo // 读取块 silceMd5.Reset() if _, err := io.ReadFull(io.TeeReader(file, io.MultiWriter(fileMd5, silceMd5)), byteData); err != io.EOF && err != nil { + sem.Release(1) return nil, err } @@ -535,6 +541,7 @@ func (y *Cloud189PC) StreamUpload(ctx context.Context, dstDir model.Obj, file mo partInfo := fmt.Sprintf("%d-%s", i, base64.StdEncoding.EncodeToString(md5Bytes)) threadG.Go(func(ctx context.Context) error { + defer sem.Release(1) uploadUrls, err := y.GetMultiUploadUrls(ctx, isFamily, initMultiUpload.Data.UploadFileID, partInfo) if err != nil { return err @@ -542,7 +549,8 @@ func (y *Cloud189PC) StreamUpload(ctx context.Context, dstDir model.Obj, file mo // step.4 上传切片 uploadUrl := uploadUrls[0] - _, err = y.put(ctx, uploadUrl.RequestURL, uploadUrl.Headers, false, bytes.NewReader(byteData), isFamily) + _, err = y.put(ctx, uploadUrl.RequestURL, uploadUrl.Headers, false, + driver.NewLimitedUploadStream(ctx, bytes.NewReader(byteData)), isFamily) if err != nil { return err } @@ -794,6 +802,7 @@ func (y *Cloud189PC) OldUpload(ctx context.Context, dstDir model.Obj, file model if err != nil { return nil, err } + rateLimited := driver.NewLimitedUploadStream(ctx, io.NopCloser(tempFile)) // 创建上传会话 uploadInfo, err := y.OldUploadCreate(ctx, dstDir.GetID(), fileMd5, file.GetName(), fmt.Sprint(file.GetSize()), isFamily) @@ -820,7 +829,7 @@ func (y *Cloud189PC) OldUpload(ctx context.Context, dstDir model.Obj, file model header["Edrive-UploadFileId"] = fmt.Sprint(status.UploadFileId) } - _, err := y.put(ctx, status.FileUploadUrl, header, true, io.NopCloser(tempFile), isFamily) + _, err := y.put(ctx, status.FileUploadUrl, header, true, rateLimited, isFamily) if err, ok := err.(*RespErr); ok && err.Code != "InputStreamReadError" { return nil, err } diff --git a/drivers/alist_v3/driver.go b/drivers/alist_v3/driver.go index 679285e0..5a299ea0 100644 --- a/drivers/alist_v3/driver.go +++ b/drivers/alist_v3/driver.go @@ -3,7 +3,6 @@ package alist_v3 import ( "context" "fmt" - "github.com/alist-org/alist/v3/internal/stream" "io" "net/http" "path" @@ -183,10 +182,11 @@ func (d *AListV3) Remove(ctx context.Context, obj model.Obj) error { } func (d *AListV3) Put(ctx context.Context, dstDir model.Obj, s model.FileStreamer, up driver.UpdateProgress) error { - req, err := http.NewRequestWithContext(ctx, http.MethodPut, d.Address+"/api/fs/put", &stream.ReaderUpdatingProgress{ + reader := driver.NewLimitedUploadStream(ctx, &driver.ReaderUpdatingProgress{ Reader: s, UpdateProgress: up, }) + req, err := http.NewRequestWithContext(ctx, http.MethodPut, d.Address+"/api/fs/put", reader) if err != nil { return err } diff --git a/drivers/aliyundrive/driver.go b/drivers/aliyundrive/driver.go index 2a977aa3..105e28b2 100644 --- a/drivers/aliyundrive/driver.go +++ b/drivers/aliyundrive/driver.go @@ -14,13 +14,12 @@ import ( "os" "time" - "github.com/alist-org/alist/v3/internal/stream" - "github.com/alist-org/alist/v3/drivers/base" "github.com/alist-org/alist/v3/internal/conf" "github.com/alist-org/alist/v3/internal/driver" "github.com/alist-org/alist/v3/internal/errs" "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/internal/stream" "github.com/alist-org/alist/v3/pkg/cron" "github.com/alist-org/alist/v3/pkg/utils" "github.com/go-resty/resty/v2" @@ -194,7 +193,10 @@ func (d *AliDrive) Put(ctx context.Context, dstDir model.Obj, streamer model.Fil } if d.RapidUpload { buf := bytes.NewBuffer(make([]byte, 0, 1024)) - utils.CopyWithBufferN(buf, file, 1024) + _, err := utils.CopyWithBufferN(buf, file, 1024) + if err != nil { + return err + } reqBody["pre_hash"] = utils.HashData(utils.SHA1, buf.Bytes()) if localFile != nil { if _, err := localFile.Seek(0, io.SeekStart); err != nil { @@ -286,6 +288,7 @@ func (d *AliDrive) Put(ctx context.Context, dstDir model.Obj, streamer model.Fil file.Reader = localFile } + rateLimited := driver.NewLimitedUploadStream(ctx, file) for i, partInfo := range resp.PartInfoList { if utils.IsCanceled(ctx) { return ctx.Err() @@ -294,7 +297,7 @@ func (d *AliDrive) Put(ctx context.Context, dstDir model.Obj, streamer model.Fil if d.InternalUpload { url = partInfo.InternalUploadUrl } - req, err := http.NewRequest("PUT", url, io.LimitReader(file, DEFAULT)) + req, err := http.NewRequest("PUT", url, io.LimitReader(rateLimited, DEFAULT)) if err != nil { return err } @@ -303,7 +306,7 @@ func (d *AliDrive) Put(ctx context.Context, dstDir model.Obj, streamer model.Fil if err != nil { return err } - res.Body.Close() + _ = res.Body.Close() if count > 0 { up(float64(i) * 100 / float64(count)) } diff --git a/drivers/aliyundrive_open/upload.go b/drivers/aliyundrive_open/upload.go index 653a2442..fb730de6 100644 --- a/drivers/aliyundrive_open/upload.go +++ b/drivers/aliyundrive_open/upload.go @@ -77,7 +77,7 @@ func (d *AliyundriveOpen) uploadPart(ctx context.Context, r io.Reader, partInfo if err != nil { return err } - res.Body.Close() + _ = res.Body.Close() if res.StatusCode != http.StatusOK && res.StatusCode != http.StatusConflict { return fmt.Errorf("upload status: %d", res.StatusCode) } @@ -251,8 +251,9 @@ func (d *AliyundriveOpen) upload(ctx context.Context, dstDir model.Obj, stream m rd = utils.NewMultiReadable(srd) } err = retry.Do(func() error { - rd.Reset() - return d.uploadPart(ctx, rd, createResp.PartInfoList[i]) + _ = rd.Reset() + rateLimitedRd := driver.NewLimitedUploadStream(ctx, rd) + return d.uploadPart(ctx, rateLimitedRd, createResp.PartInfoList[i]) }, retry.Attempts(3), retry.DelayType(retry.BackOffDelay), diff --git a/drivers/baidu_netdisk/driver.go b/drivers/baidu_netdisk/driver.go index ad52a4b5..e0ba98fa 100644 --- a/drivers/baidu_netdisk/driver.go +++ b/drivers/baidu_netdisk/driver.go @@ -12,6 +12,8 @@ import ( "strconv" "time" + "golang.org/x/sync/semaphore" + "github.com/alist-org/alist/v3/drivers/base" "github.com/alist-org/alist/v3/internal/driver" "github.com/alist-org/alist/v3/internal/errs" @@ -263,16 +265,21 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F retry.Attempts(3), retry.Delay(time.Second), retry.DelayType(retry.BackOffDelay)) + sem := semaphore.NewWeighted(3) for i, partseq := range precreateResp.BlockList { if utils.IsCanceled(upCtx) { break } + if err = sem.Acquire(ctx, 1); err != nil { + break + } i, partseq, offset, byteSize := i, partseq, int64(partseq)*sliceSize, sliceSize if partseq+1 == count { byteSize = lastBlockSize } threadG.Go(func(ctx context.Context) error { + defer sem.Release(1) params := map[string]string{ "method": "upload", "access_token": d.AccessToken, @@ -281,7 +288,8 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F "uploadid": precreateResp.Uploadid, "partseq": strconv.Itoa(partseq), } - err := d.uploadSlice(ctx, params, stream.GetName(), io.NewSectionReader(tempFile, offset, byteSize)) + err := d.uploadSlice(ctx, params, stream.GetName(), + driver.NewLimitedUploadStream(ctx, io.NewSectionReader(tempFile, offset, byteSize))) if err != nil { return err } diff --git a/drivers/baidu_photo/driver.go b/drivers/baidu_photo/driver.go index b584c9a3..9ee0a7ae 100644 --- a/drivers/baidu_photo/driver.go +++ b/drivers/baidu_photo/driver.go @@ -13,6 +13,8 @@ import ( "strings" "time" + "golang.org/x/sync/semaphore" + "github.com/alist-org/alist/v3/drivers/base" "github.com/alist-org/alist/v3/internal/driver" "github.com/alist-org/alist/v3/internal/errs" @@ -314,10 +316,14 @@ func (d *BaiduPhoto) Put(ctx context.Context, dstDir model.Obj, stream model.Fil retry.Attempts(3), retry.Delay(time.Second), retry.DelayType(retry.BackOffDelay)) + sem := semaphore.NewWeighted(3) for i, partseq := range precreateResp.BlockList { if utils.IsCanceled(upCtx) { break } + if err = sem.Acquire(ctx, 1); err != nil { + break + } i, partseq, offset, byteSize := i, partseq, int64(partseq)*DEFAULT, DEFAULT if partseq+1 == count { @@ -325,6 +331,7 @@ func (d *BaiduPhoto) Put(ctx context.Context, dstDir model.Obj, stream model.Fil } threadG.Go(func(ctx context.Context) error { + defer sem.Release(1) uploadParams := map[string]string{ "method": "upload", "path": params["path"], @@ -335,7 +342,8 @@ func (d *BaiduPhoto) Put(ctx context.Context, dstDir model.Obj, stream model.Fil _, err = d.Post("https://c3.pcs.baidu.com/rest/2.0/pcs/superfile2", func(r *resty.Request) { r.SetContext(ctx) r.SetQueryParams(uploadParams) - r.SetFileReader("file", stream.GetName(), io.NewSectionReader(tempFile, offset, byteSize)) + r.SetFileReader("file", stream.GetName(), + driver.NewLimitedUploadStream(ctx, io.NewSectionReader(tempFile, offset, byteSize))) }, nil) if err != nil { return err diff --git a/drivers/base/client.go b/drivers/base/client.go index 8bf8f421..538c43a6 100644 --- a/drivers/base/client.go +++ b/drivers/base/client.go @@ -6,6 +6,7 @@ import ( "time" "github.com/alist-org/alist/v3/internal/conf" + "github.com/alist-org/alist/v3/internal/net" "github.com/go-resty/resty/v2" ) @@ -26,7 +27,7 @@ func InitClient() { NoRedirectClient.SetHeader("user-agent", UserAgent) RestyClient = NewRestyClient() - HttpClient = NewHttpClient() + HttpClient = net.NewHttpClient() } func NewRestyClient() *resty.Client { @@ -38,13 +39,3 @@ func NewRestyClient() *resty.Client { SetTLSClientConfig(&tls.Config{InsecureSkipVerify: conf.Conf.TlsInsecureSkipVerify}) return client } - -func NewHttpClient() *http.Client { - return &http.Client{ - Timeout: time.Hour * 48, - Transport: &http.Transport{ - Proxy: http.ProxyFromEnvironment, - TLSClientConfig: &tls.Config{InsecureSkipVerify: conf.Conf.TlsInsecureSkipVerify}, - }, - } -} diff --git a/drivers/chaoxing/driver.go b/drivers/chaoxing/driver.go index 9b526f8a..bf01a83b 100644 --- a/drivers/chaoxing/driver.go +++ b/drivers/chaoxing/driver.go @@ -6,7 +6,6 @@ import ( "encoding/json" "errors" "fmt" - "github.com/alist-org/alist/v3/internal/stream" "io" "mime/multipart" "net/http" @@ -249,13 +248,13 @@ func (d *ChaoXing) Put(ctx context.Context, dstDir model.Obj, file model.FileStr if err != nil { return err } - r := &stream.ReaderUpdatingProgress{ - Reader: &stream.SimpleReaderWithSize{ + r := driver.NewLimitedUploadStream(ctx, &driver.ReaderUpdatingProgress{ + Reader: &driver.SimpleReaderWithSize{ Reader: body, Size: int64(body.Len()), }, UpdateProgress: up, - } + }) req, err := http.NewRequestWithContext(ctx, "POST", "https://pan-yz.chaoxing.com/upload", r) if err != nil { return err diff --git a/drivers/cloudreve/driver.go b/drivers/cloudreve/driver.go index 8fc117ac..73fc3fea 100644 --- a/drivers/cloudreve/driver.go +++ b/drivers/cloudreve/driver.go @@ -1,7 +1,9 @@ package cloudreve import ( + "bytes" "context" + "errors" "io" "net/http" "path" @@ -173,7 +175,7 @@ func (d *Cloudreve) Put(ctx context.Context, dstDir model.Obj, stream model.File var n int buf = make([]byte, chunkSize) n, err = io.ReadAtLeast(stream, buf, chunkSize) - if err != nil && err != io.ErrUnexpectedEOF { + if err != nil && !errors.Is(err, io.ErrUnexpectedEOF) { if err == io.EOF { return nil } @@ -186,7 +188,7 @@ func (d *Cloudreve) Put(ctx context.Context, dstDir model.Obj, stream model.File err = d.request(http.MethodPost, "/file/upload/"+u.SessionID+"/"+strconv.Itoa(chunk), func(req *resty.Request) { req.SetHeader("Content-Type", "application/octet-stream") req.SetHeader("Content-Length", strconv.Itoa(n)) - req.SetBody(buf) + req.SetBody(driver.NewLimitedUploadStream(ctx, bytes.NewReader(buf))) }, nil) if err != nil { break diff --git a/drivers/cloudreve/util.go b/drivers/cloudreve/util.go index b5b71153..8a90a42f 100644 --- a/drivers/cloudreve/util.go +++ b/drivers/cloudreve/util.go @@ -100,7 +100,7 @@ func (d *Cloudreve) login() error { if err == nil { break } - if err != nil && err.Error() != "CAPTCHA not match." { + if err.Error() != "CAPTCHA not match." { break } } @@ -202,7 +202,8 @@ func (d *Cloudreve) upRemote(ctx context.Context, stream model.FileStreamer, u U if err != nil { return err } - req, err := http.NewRequest("POST", uploadUrl+"?chunk="+strconv.Itoa(chunk), bytes.NewBuffer(byteData)) + req, err := http.NewRequest("POST", uploadUrl+"?chunk="+strconv.Itoa(chunk), + driver.NewLimitedUploadStream(ctx, bytes.NewBuffer(byteData))) if err != nil { return err } @@ -214,7 +215,7 @@ func (d *Cloudreve) upRemote(ctx context.Context, stream model.FileStreamer, u U if err != nil { return err } - res.Body.Close() + _ = res.Body.Close() up(float64(finish) * 100 / float64(stream.GetSize())) chunk++ } @@ -241,7 +242,7 @@ func (d *Cloudreve) upOneDrive(ctx context.Context, stream model.FileStreamer, u if err != nil { return err } - req, err := http.NewRequest("PUT", uploadUrl, bytes.NewBuffer(byteData)) + req, err := http.NewRequest("PUT", uploadUrl, driver.NewLimitedUploadStream(ctx, bytes.NewBuffer(byteData))) if err != nil { return err } @@ -256,10 +257,10 @@ func (d *Cloudreve) upOneDrive(ctx context.Context, stream model.FileStreamer, u // https://learn.microsoft.com/zh-cn/onedrive/developer/rest-api/api/driveitem_createuploadsession if res.StatusCode != 201 && res.StatusCode != 202 && res.StatusCode != 200 { data, _ := io.ReadAll(res.Body) - res.Body.Close() + _ = res.Body.Close() return errors.New(string(data)) } - res.Body.Close() + _ = res.Body.Close() up(float64(finish) * 100 / float64(stream.GetSize())) } // 上传成功发送回调请求 diff --git a/drivers/dropbox/driver.go b/drivers/dropbox/driver.go index 9b1717b0..fbaecc4a 100644 --- a/drivers/dropbox/driver.go +++ b/drivers/dropbox/driver.go @@ -191,7 +191,7 @@ func (d *Dropbox) Put(ctx context.Context, dstDir model.Obj, stream model.FileSt } url := d.contentBase + "/2/files/upload_session/append_v2" - reader := io.LimitReader(stream, PartSize) + reader := driver.NewLimitedUploadStream(ctx, io.LimitReader(stream, PartSize)) req, err := http.NewRequest(http.MethodPost, url, reader) if err != nil { log.Errorf("failed to update file when append to upload session, err: %+v", err) @@ -219,13 +219,8 @@ func (d *Dropbox) Put(ctx context.Context, dstDir model.Obj, stream model.FileSt return err } _ = res.Body.Close() - - if count > 0 { - up(float64(i+1) * 100 / float64(count)) - } - + up(float64(i+1) * 100 / float64(count)) offset += byteSize - } // 3.finish toPath := dstDir.GetPath() + "/" + stream.GetName() diff --git a/drivers/ftp/driver.go b/drivers/ftp/driver.go index b3e95f93..8f30b780 100644 --- a/drivers/ftp/driver.go +++ b/drivers/ftp/driver.go @@ -2,7 +2,6 @@ package ftp import ( "context" - "github.com/alist-org/alist/v3/internal/stream" stdpath "path" "github.com/alist-org/alist/v3/internal/driver" @@ -120,13 +119,10 @@ func (d *FTP) Put(ctx context.Context, dstDir model.Obj, s model.FileStreamer, u return err } path := stdpath.Join(dstDir.GetPath(), s.GetName()) - return d.conn.Stor(encode(path, d.Encoding), &stream.ReaderWithCtx{ - Reader: &stream.ReaderUpdatingProgress{ - Reader: s, - UpdateProgress: up, - }, - Ctx: ctx, - }) + return d.conn.Stor(encode(path, d.Encoding), driver.NewLimitedUploadStream(ctx, &driver.ReaderUpdatingProgress{ + Reader: s, + UpdateProgress: up, + })) } var _ driver.Driver = (*FTP)(nil) diff --git a/drivers/github/driver.go b/drivers/github/driver.go index dee4cbbf..d1cfd9fb 100644 --- a/drivers/github/driver.go +++ b/drivers/github/driver.go @@ -16,7 +16,6 @@ import ( "github.com/alist-org/alist/v3/internal/driver" "github.com/alist-org/alist/v3/internal/errs" "github.com/alist-org/alist/v3/internal/model" - "github.com/alist-org/alist/v3/internal/stream" "github.com/alist-org/alist/v3/pkg/utils" "github.com/go-resty/resty/v2" log "github.com/sirupsen/logrus" @@ -676,13 +675,13 @@ func (d *Github) putBlob(ctx context.Context, s model.FileStreamer, up driver.Up afterContentReader := strings.NewReader(afterContent) req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("https://api.github.com/repos/%s/%s/git/blobs", d.Owner, d.Repo), - &stream.ReaderUpdatingProgress{ - Reader: &stream.SimpleReaderWithSize{ + driver.NewLimitedUploadStream(ctx, &driver.ReaderUpdatingProgress{ + Reader: &driver.SimpleReaderWithSize{ Reader: io.MultiReader(beforeContentReader, contentReader, afterContentReader), Size: length, }, UpdateProgress: up, - }) + })) if err != nil { return "", err } @@ -698,6 +697,7 @@ func (d *Github) putBlob(ctx context.Context, s model.FileStreamer, up driver.Up if err != nil { return "", err } + defer res.Body.Close() resBody, err := io.ReadAll(res.Body) if err != nil { return "", err diff --git a/drivers/google_drive/driver.go b/drivers/google_drive/driver.go index dccdcea9..c8afb084 100644 --- a/drivers/google_drive/driver.go +++ b/drivers/google_drive/driver.go @@ -158,7 +158,8 @@ func (d *GoogleDrive) Put(ctx context.Context, dstDir model.Obj, stream model.Fi putUrl := res.Header().Get("location") if stream.GetSize() < d.ChunkSize*1024*1024 { _, err = d.request(putUrl, http.MethodPut, func(req *resty.Request) { - req.SetHeader("Content-Length", strconv.FormatInt(stream.GetSize(), 10)).SetBody(stream) + req.SetHeader("Content-Length", strconv.FormatInt(stream.GetSize(), 10)). + SetBody(driver.NewLimitedUploadStream(ctx, stream)) }, nil) } else { err = d.chunkUpload(ctx, stream, putUrl) diff --git a/drivers/google_drive/util.go b/drivers/google_drive/util.go index 0d380112..0fe54346 100644 --- a/drivers/google_drive/util.go +++ b/drivers/google_drive/util.go @@ -11,10 +11,10 @@ import ( "strconv" "time" - "github.com/alist-org/alist/v3/pkg/http_range" - "github.com/alist-org/alist/v3/drivers/base" + "github.com/alist-org/alist/v3/internal/driver" "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/pkg/http_range" "github.com/alist-org/alist/v3/pkg/utils" "github.com/go-resty/resty/v2" "github.com/golang-jwt/jwt/v4" @@ -126,8 +126,7 @@ func (d *GoogleDrive) refreshToken() error { } d.AccessToken = resp.AccessToken return nil - } - if gdsaFileErr != nil && os.IsExist(gdsaFileErr) { + } else if os.IsExist(gdsaFileErr) { return gdsaFileErr } url := "https://www.googleapis.com/oauth2/v4/token" @@ -229,6 +228,7 @@ func (d *GoogleDrive) chunkUpload(ctx context.Context, stream model.FileStreamer if err != nil { return err } + reader = driver.NewLimitedUploadStream(ctx, reader) _, err = d.request(url, http.MethodPut, func(req *resty.Request) { req.SetHeaders(map[string]string{ "Content-Length": strconv.FormatInt(chunkSize, 10), diff --git a/drivers/google_photo/driver.go b/drivers/google_photo/driver.go index b54132ef..e6f0abc6 100644 --- a/drivers/google_photo/driver.go +++ b/drivers/google_photo/driver.go @@ -124,7 +124,7 @@ func (d *GooglePhoto) Put(ctx context.Context, dstDir model.Obj, stream model.Fi } resp, err := d.request(postUrl, http.MethodPost, func(req *resty.Request) { - req.SetBody(stream).SetContext(ctx) + req.SetBody(driver.NewLimitedUploadStream(ctx, stream)).SetContext(ctx) }, nil, postHeaders) if err != nil { diff --git a/drivers/halalcloud/driver.go b/drivers/halalcloud/driver.go index d3235828..26832760 100644 --- a/drivers/halalcloud/driver.go +++ b/drivers/halalcloud/driver.go @@ -392,10 +392,11 @@ func (d *HalalCloud) put(ctx context.Context, dstDir model.Obj, fileStream model if fileStream.GetSize() > s3manager.MaxUploadParts*s3manager.DefaultUploadPartSize { uploader.PartSize = fileStream.GetSize() / (s3manager.MaxUploadParts - 1) } + reader := driver.NewLimitedUploadStream(ctx, fileStream) _, err = uploader.UploadWithContext(ctx, &s3manager.UploadInput{ Bucket: aws.String(result.Bucket), Key: aws.String(result.Key), - Body: io.TeeReader(fileStream, driver.NewProgress(fileStream.GetSize(), up)), + Body: io.TeeReader(reader, driver.NewProgress(fileStream.GetSize(), up)), }) return nil, err diff --git a/drivers/ilanzou/driver.go b/drivers/ilanzou/driver.go index 22d1589f..697d85b1 100644 --- a/drivers/ilanzou/driver.go +++ b/drivers/ilanzou/driver.go @@ -309,13 +309,13 @@ func (d *ILanZou) Put(ctx context.Context, dstDir model.Obj, s model.FileStreame upToken := utils.Json.Get(res, "upToken").ToString() now := time.Now() key := fmt.Sprintf("disk/%d/%d/%d/%s/%016d", now.Year(), now.Month(), now.Day(), d.account, now.UnixMilli()) - reader := &stream.ReaderUpdatingProgress{ - Reader: &stream.SimpleReaderWithSize{ + reader := driver.NewLimitedUploadStream(ctx, &driver.ReaderUpdatingProgress{ + Reader: &driver.SimpleReaderWithSize{ Reader: tempFile, Size: s.GetSize(), }, UpdateProgress: up, - } + }) var token string if s.GetSize() <= DefaultPartSize { res, err := d.upClient.R().SetContext(ctx).SetMultipartFormData(map[string]string{ diff --git a/drivers/ipfs_api/driver.go b/drivers/ipfs_api/driver.go index 61886b38..77760656 100644 --- a/drivers/ipfs_api/driver.go +++ b/drivers/ipfs_api/driver.go @@ -3,7 +3,6 @@ package ipfs import ( "context" "fmt" - "github.com/alist-org/alist/v3/internal/stream" "net/url" stdpath "path" "path/filepath" @@ -111,13 +110,10 @@ func (d *IPFS) Remove(ctx context.Context, obj model.Obj) error { func (d *IPFS) Put(ctx context.Context, dstDir model.Obj, s model.FileStreamer, up driver.UpdateProgress) error { // TODO upload file, optional - _, err := d.sh.Add(&stream.ReaderWithCtx{ - Reader: &stream.ReaderUpdatingProgress{ - Reader: s, - UpdateProgress: up, - }, - Ctx: ctx, - }, ToFiles(stdpath.Join(dstDir.GetPath(), s.GetName()))) + _, err := d.sh.Add(driver.NewLimitedUploadStream(ctx, &driver.ReaderUpdatingProgress{ + Reader: s, + UpdateProgress: up, + }), ToFiles(stdpath.Join(dstDir.GetPath(), s.GetName()))) return err } diff --git a/drivers/kodbox/driver.go b/drivers/kodbox/driver.go index ff48ffb2..c536c916 100644 --- a/drivers/kodbox/driver.go +++ b/drivers/kodbox/driver.go @@ -3,9 +3,6 @@ package kodbox import ( "context" "fmt" - "github.com/alist-org/alist/v3/internal/stream" - "github.com/alist-org/alist/v3/pkg/utils" - "github.com/go-resty/resty/v2" "net/http" "path/filepath" "strings" @@ -13,6 +10,8 @@ import ( "github.com/alist-org/alist/v3/internal/driver" "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/go-resty/resty/v2" ) type KodBox struct { @@ -229,10 +228,10 @@ func (d *KodBox) Remove(ctx context.Context, obj model.Obj) error { func (d *KodBox) Put(ctx context.Context, dstDir model.Obj, s model.FileStreamer, up driver.UpdateProgress) (model.Obj, error) { var resp *CommonResp _, err := d.request(http.MethodPost, "/?explorer/upload/fileUpload", func(req *resty.Request) { - r := &stream.ReaderUpdatingProgress{ + r := driver.NewLimitedUploadStream(ctx, &driver.ReaderUpdatingProgress{ Reader: s, UpdateProgress: up, - } + }) req.SetFileReader("file", s.GetName(), r). SetResult(&resp). SetFormData(map[string]string{ diff --git a/drivers/lanzou/driver.go b/drivers/lanzou/driver.go index 90635d16..877e72bb 100644 --- a/drivers/lanzou/driver.go +++ b/drivers/lanzou/driver.go @@ -2,7 +2,6 @@ package lanzou import ( "context" - "github.com/alist-org/alist/v3/internal/stream" "net/http" "github.com/alist-org/alist/v3/drivers/base" @@ -213,6 +212,10 @@ func (d *LanZou) Put(ctx context.Context, dstDir model.Obj, s model.FileStreamer if d.IsCookie() || d.IsAccount() { var resp RespText[[]FileOrFolder] _, err := d._post(d.BaseUrl+"/html5up.php", func(req *resty.Request) { + reader := driver.NewLimitedUploadStream(ctx, &driver.ReaderUpdatingProgress{ + Reader: s, + UpdateProgress: up, + }) req.SetFormData(map[string]string{ "task": "1", "vie": "2", @@ -220,10 +223,7 @@ func (d *LanZou) Put(ctx context.Context, dstDir model.Obj, s model.FileStreamer "id": "WU_FILE_0", "name": s.GetName(), "folder_id_bb_n": dstDir.GetID(), - }).SetFileReader("upload_file", s.GetName(), &stream.ReaderUpdatingProgress{ - Reader: s, - UpdateProgress: up, - }).SetContext(ctx) + }).SetFileReader("upload_file", s.GetName(), reader).SetContext(ctx) }, &resp, true) if err != nil { return nil, err diff --git a/drivers/lark/driver.go b/drivers/lark/driver.go index d2672300..fbf7529a 100644 --- a/drivers/lark/driver.go +++ b/drivers/lark/driver.go @@ -320,7 +320,10 @@ func (c *Lark) Put(ctx context.Context, dstDir model.Obj, stream model.FileStrea Build() // 发起请求 - uploadLimit.Wait(ctx) + err := uploadLimit.Wait(ctx) + if err != nil { + return nil, err + } resp, err := c.client.Drive.File.UploadPrepare(ctx, req) if err != nil { return nil, err @@ -341,7 +344,7 @@ func (c *Lark) Put(ctx context.Context, dstDir model.Obj, stream model.FileStrea length = stream.GetSize() - int64(i*blockSize) } - reader := io.LimitReader(stream, length) + reader := driver.NewLimitedUploadStream(ctx, io.LimitReader(stream, length)) req := larkdrive.NewUploadPartFileReqBuilder(). Body(larkdrive.NewUploadPartFileReqBodyBuilder(). @@ -353,7 +356,10 @@ func (c *Lark) Put(ctx context.Context, dstDir model.Obj, stream model.FileStrea Build() // 发起请求 - uploadLimit.Wait(ctx) + err = uploadLimit.Wait(ctx) + if err != nil { + return nil, err + } resp, err := c.client.Drive.File.UploadPart(ctx, req) if err != nil { diff --git a/drivers/mediatrack/driver.go b/drivers/mediatrack/driver.go index ed53f8ee..50ef9799 100644 --- a/drivers/mediatrack/driver.go +++ b/drivers/mediatrack/driver.go @@ -5,7 +5,6 @@ import ( "crypto/md5" "encoding/hex" "fmt" - "github.com/alist-org/alist/v3/internal/stream" "io" "net/http" "strconv" @@ -195,13 +194,13 @@ func (d *MediaTrack) Put(ctx context.Context, dstDir model.Obj, file model.FileS input := &s3manager.UploadInput{ Bucket: &resp.Data.Bucket, Key: &resp.Data.Object, - Body: &stream.ReaderUpdatingProgress{ - Reader: &stream.SimpleReaderWithSize{ + Body: driver.NewLimitedUploadStream(ctx, &driver.ReaderUpdatingProgress{ + Reader: &driver.SimpleReaderWithSize{ Reader: tempFile, Size: file.GetSize(), }, UpdateProgress: up, - }, + }), } _, err = uploader.UploadWithContext(ctx, input) if err != nil { diff --git a/drivers/mega/driver.go b/drivers/mega/driver.go index 198c1f98..f76bfeef 100644 --- a/drivers/mega/driver.go +++ b/drivers/mega/driver.go @@ -156,6 +156,7 @@ func (d *Mega) Put(ctx context.Context, dstDir model.Obj, stream model.FileStrea return err } + reader := driver.NewLimitedUploadStream(ctx, stream) for id := 0; id < u.Chunks(); id++ { if utils.IsCanceled(ctx) { return ctx.Err() @@ -165,7 +166,7 @@ func (d *Mega) Put(ctx context.Context, dstDir model.Obj, stream model.FileStrea return err } chunk := make([]byte, chkSize) - n, err := io.ReadFull(stream, chunk) + n, err := io.ReadFull(reader, chunk) if err != nil && err != io.EOF { return err } diff --git a/drivers/misskey/driver.go b/drivers/misskey/driver.go index 29797a01..b5c753f3 100644 --- a/drivers/misskey/driver.go +++ b/drivers/misskey/driver.go @@ -64,7 +64,7 @@ func (d *Misskey) Remove(ctx context.Context, obj model.Obj) error { } func (d *Misskey) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) (model.Obj, error) { - return d.put(dstDir, stream, up) + return d.put(ctx, dstDir, stream, up) } //func (d *Template) Other(ctx context.Context, args model.OtherArgs) (interface{}, error) { diff --git a/drivers/misskey/util.go b/drivers/misskey/util.go index 4d5a3b4d..f8baeafa 100644 --- a/drivers/misskey/util.go +++ b/drivers/misskey/util.go @@ -1,7 +1,6 @@ package misskey import ( - "bytes" "context" "errors" "io" @@ -190,16 +189,16 @@ func (d *Misskey) remove(obj model.Obj) error { } } -func (d *Misskey) put(dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) (model.Obj, error) { +func (d *Misskey) put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) (model.Obj, error) { var file MFile - fileContent, err := io.ReadAll(stream) - if err != nil { - return nil, err - } - + reader := driver.NewLimitedUploadStream(ctx, &driver.ReaderUpdatingProgress{ + Reader: stream, + UpdateProgress: up, + }) req := base.RestyClient.R(). - SetFileReader("file", stream.GetName(), io.NopCloser(bytes.NewReader(fileContent))). + SetContext(ctx). + SetFileReader("file", stream.GetName(), reader). SetFormData(map[string]string{ "folderId": handleFolderId(dstDir).(string), "name": stream.GetName(), @@ -207,7 +206,8 @@ func (d *Misskey) put(dstDir model.Obj, stream model.FileStreamer, up driver.Upd "isSensitive": "false", "force": "false", }). - SetResult(&file).SetAuthToken(d.AccessToken) + SetResult(&file). + SetAuthToken(d.AccessToken) resp, err := req.Post(d.Endpoint + "/api/drive/files/create") if err != nil { diff --git a/drivers/mopan/driver.go b/drivers/mopan/driver.go index 369ec83b..2cbabe46 100644 --- a/drivers/mopan/driver.go +++ b/drivers/mopan/driver.go @@ -10,6 +10,8 @@ import ( "strings" "time" + "golang.org/x/sync/semaphore" + "github.com/alist-org/alist/v3/drivers/base" "github.com/alist-org/alist/v3/internal/driver" "github.com/alist-org/alist/v3/internal/model" @@ -301,6 +303,7 @@ func (d *MoPan) Put(ctx context.Context, dstDir model.Obj, stream model.FileStre retry.Attempts(3), retry.Delay(time.Second), retry.DelayType(retry.BackOffDelay)) + sem := semaphore.NewWeighted(3) // step.3 parts, err := d.client.GetAllMultiUploadUrls(initUpdload.UploadFileID, initUpdload.PartInfos) @@ -312,6 +315,9 @@ func (d *MoPan) Put(ctx context.Context, dstDir model.Obj, stream model.FileStre if utils.IsCanceled(upCtx) { break } + if err = sem.Acquire(ctx, 1); err != nil { + break + } i, part, byteSize := i, part, initUpdload.PartSize if part.PartNumber == uploadPartData.PartTotal { byteSize = initUpdload.LastPartSize @@ -319,7 +325,9 @@ func (d *MoPan) Put(ctx context.Context, dstDir model.Obj, stream model.FileStre // step.4 threadG.Go(func(ctx context.Context) error { - req, err := part.NewRequest(ctx, io.NewSectionReader(file, int64(part.PartNumber-1)*initUpdload.PartSize, byteSize)) + defer sem.Release(1) + reader := io.NewSectionReader(file, int64(part.PartNumber-1)*initUpdload.PartSize, byteSize) + req, err := part.NewRequest(ctx, driver.NewLimitedUploadStream(ctx, reader)) if err != nil { return err } @@ -328,7 +336,7 @@ func (d *MoPan) Put(ctx context.Context, dstDir model.Obj, stream model.FileStre if err != nil { return err } - resp.Body.Close() + _ = resp.Body.Close() if resp.StatusCode != http.StatusOK { return fmt.Errorf("upload err,code=%d", resp.StatusCode) } diff --git a/drivers/netease_music/types.go b/drivers/netease_music/types.go index 332f75e9..12afeb7a 100644 --- a/drivers/netease_music/types.go +++ b/drivers/netease_music/types.go @@ -116,16 +116,3 @@ func (ch *Characteristic) merge(data map[string]string) map[string]interface{} { } return body } - -type InlineReadCloser struct { - io.Reader - io.Closer -} - -func (rc *InlineReadCloser) Read(p []byte) (int, error) { - return rc.Reader.Read(p) -} - -func (rc *InlineReadCloser) Close() error { - return rc.Closer.Close() -} diff --git a/drivers/netease_music/util.go b/drivers/netease_music/util.go index 25efde77..2e78be14 100644 --- a/drivers/netease_music/util.go +++ b/drivers/netease_music/util.go @@ -2,8 +2,6 @@ package netease_music import ( "context" - "github.com/alist-org/alist/v3/internal/driver" - "github.com/alist-org/alist/v3/internal/stream" "net/http" "path" "regexp" @@ -12,6 +10,7 @@ import ( "time" "github.com/alist-org/alist/v3/drivers/base" + "github.com/alist-org/alist/v3/internal/driver" "github.com/alist-org/alist/v3/internal/errs" "github.com/alist-org/alist/v3/internal/model" "github.com/alist-org/alist/v3/pkg/utils" @@ -69,13 +68,10 @@ func (d *NeteaseMusic) request(url, method string, opt ReqOption) ([]byte, error opt.up = func(_ float64) {} } req.SetContentLength(true) - req.SetBody(&InlineReadCloser{ - Reader: &stream.ReaderUpdatingProgress{ - Reader: opt.stream, - UpdateProgress: opt.up, - }, - Closer: opt.stream, - }) + req.SetBody(driver.NewLimitedUploadStream(opt.ctx, &driver.ReaderUpdatingProgress{ + Reader: opt.stream, + UpdateProgress: opt.up, + })) } else { req.SetFormData(data) } diff --git a/drivers/onedrive/util.go b/drivers/onedrive/util.go index 95f92db6..9350a681 100644 --- a/drivers/onedrive/util.go +++ b/drivers/onedrive/util.go @@ -152,12 +152,8 @@ func (d *Onedrive) upSmall(ctx context.Context, dstDir model.Obj, stream model.F // 1. upload new file // ApiDoc: https://learn.microsoft.com/en-us/onedrive/developer/rest-api/api/driveitem_put_content?view=odsp-graph-online url := d.GetMetaUrl(false, filepath) + "/content" - data, err := io.ReadAll(stream) - if err != nil { - return err - } - _, err = d.Request(url, http.MethodPut, func(req *resty.Request) { - req.SetBody(data).SetContext(ctx) + _, err := d.Request(url, http.MethodPut, func(req *resty.Request) { + req.SetBody(driver.NewLimitedUploadStream(ctx, stream)).SetContext(ctx) }, nil) if err != nil { return fmt.Errorf("onedrive: Failed to upload new file(path=%v): %w", filepath, err) @@ -225,7 +221,7 @@ func (d *Onedrive) upBig(ctx context.Context, dstDir model.Obj, stream model.Fil if err != nil { return err } - req, err := http.NewRequest("PUT", uploadUrl, bytes.NewBuffer(byteData)) + req, err := http.NewRequest("PUT", uploadUrl, driver.NewLimitedUploadStream(ctx, bytes.NewBuffer(byteData))) if err != nil { return err } diff --git a/drivers/onedrive_app/util.go b/drivers/onedrive_app/util.go index d036e131..a6793520 100644 --- a/drivers/onedrive_app/util.go +++ b/drivers/onedrive_app/util.go @@ -140,12 +140,8 @@ func (d *OnedriveAPP) GetFile(path string) (*File, error) { func (d *OnedriveAPP) upSmall(ctx context.Context, dstDir model.Obj, stream model.FileStreamer) error { url := d.GetMetaUrl(false, stdpath.Join(dstDir.GetPath(), stream.GetName())) + "/content" - data, err := io.ReadAll(stream) - if err != nil { - return err - } - _, err = d.Request(url, http.MethodPut, func(req *resty.Request) { - req.SetBody(data).SetContext(ctx) + _, err := d.Request(url, http.MethodPut, func(req *resty.Request) { + req.SetBody(driver.NewLimitedUploadStream(ctx, stream)).SetContext(ctx) }, nil) return err } @@ -175,7 +171,7 @@ func (d *OnedriveAPP) upBig(ctx context.Context, dstDir model.Obj, stream model. if err != nil { return err } - req, err := http.NewRequest("PUT", uploadUrl, bytes.NewBuffer(byteData)) + req, err := http.NewRequest("PUT", uploadUrl, driver.NewLimitedUploadStream(ctx, bytes.NewBuffer(byteData))) if err != nil { return err } diff --git a/drivers/pikpak/util.go b/drivers/pikpak/util.go index eb96a42a..f2594e78 100644 --- a/drivers/pikpak/util.go +++ b/drivers/pikpak/util.go @@ -10,7 +10,6 @@ import ( "github.com/alist-org/alist/v3/internal/driver" "github.com/alist-org/alist/v3/internal/model" "github.com/alist-org/alist/v3/internal/op" - "github.com/alist-org/alist/v3/internal/stream" "github.com/alist-org/alist/v3/pkg/utils" "github.com/aliyun/aliyun-oss-go-sdk/oss" jsoniter "github.com/json-iterator/go" @@ -430,13 +429,10 @@ func (d *PikPak) UploadByOSS(ctx context.Context, params *S3Params, s model.File return err } - err = bucket.PutObject(params.Key, &stream.ReaderWithCtx{ - Reader: &stream.ReaderUpdatingProgress{ - Reader: s, - UpdateProgress: up, - }, - Ctx: ctx, - }, OssOption(params)...) + err = bucket.PutObject(params.Key, driver.NewLimitedUploadStream(ctx, &driver.ReaderUpdatingProgress{ + Reader: s, + UpdateProgress: up, + }), OssOption(params)...) if err != nil { return err } @@ -522,11 +518,8 @@ func (d *PikPak) UploadByMultipart(ctx context.Context, params *S3Params, fileSi continue } - b := bytes.NewBuffer(buf) - if part, err = bucket.UploadPart(imur, &stream.ReaderWithCtx{ - Reader: b, - Ctx: ctx, - }, chunk.Size, chunk.Number, OssOption(params)...); err == nil { + b := driver.NewLimitedUploadStream(ctx, bytes.NewBuffer(buf)) + if part, err = bucket.UploadPart(imur, b, chunk.Size, chunk.Number, OssOption(params)...); err == nil { break } } diff --git a/drivers/quark_uc/driver.go b/drivers/quark_uc/driver.go index 8674fbab..04757b1b 100644 --- a/drivers/quark_uc/driver.go +++ b/drivers/quark_uc/driver.go @@ -1,6 +1,7 @@ package quark import ( + "bytes" "context" "crypto/md5" "crypto/sha1" @@ -178,7 +179,7 @@ func (d *QuarkOrUC) Put(ctx context.Context, dstDir model.Obj, stream model.File } // part up partSize := pre.Metadata.PartSize - var bytes []byte + var part []byte md5s := make([]string, 0) defaultBytes := make([]byte, partSize) total := stream.GetSize() @@ -189,17 +190,18 @@ func (d *QuarkOrUC) Put(ctx context.Context, dstDir model.Obj, stream model.File return ctx.Err() } if left > int64(partSize) { - bytes = defaultBytes + part = defaultBytes } else { - bytes = make([]byte, left) + part = make([]byte, left) } - _, err := io.ReadFull(tempFile, bytes) + _, err := io.ReadFull(tempFile, part) if err != nil { return err } - left -= int64(len(bytes)) + left -= int64(len(part)) log.Debugf("left: %d", left) - m, err := d.upPart(ctx, pre, stream.GetMimetype(), partNumber, bytes) + reader := driver.NewLimitedUploadStream(ctx, bytes.NewReader(part)) + m, err := d.upPart(ctx, pre, stream.GetMimetype(), partNumber, reader) //m, err := driver.UpPart(pre, file.GetMIMEType(), partNumber, bytes, account, md5Str, sha1Str) if err != nil { return err diff --git a/drivers/quark_uc/util.go b/drivers/quark_uc/util.go index df27af67..9a3bdc1c 100644 --- a/drivers/quark_uc/util.go +++ b/drivers/quark_uc/util.go @@ -6,6 +6,7 @@ import ( "encoding/base64" "errors" "fmt" + "io" "net/http" "strconv" "strings" @@ -119,7 +120,7 @@ func (d *QuarkOrUC) upHash(md5, sha1, taskId string) (bool, error) { return resp.Data.Finish, err } -func (d *QuarkOrUC) upPart(ctx context.Context, pre UpPreResp, mineType string, partNumber int, bytes []byte) (string, error) { +func (d *QuarkOrUC) upPart(ctx context.Context, pre UpPreResp, mineType string, partNumber int, bytes io.Reader) (string, error) { //func (driver QuarkOrUC) UpPart(pre UpPreResp, mineType string, partNumber int, bytes []byte, account *model.Account, md5Str, sha1Str string) (string, error) { timeStr := time.Now().UTC().Format(http.TimeFormat) data := base.Json{ @@ -163,6 +164,9 @@ x-oss-user-agent:aliyun-sdk-js/6.6.1 Chrome 98.0.4758.80 on Windows 10 64-bit "partNumber": strconv.Itoa(partNumber), "uploadId": pre.Data.UploadId, }).SetBody(bytes).Put(u) + if err != nil { + return "", err + } if res.StatusCode() != 200 { return "", fmt.Errorf("up status: %d, error: %s", res.StatusCode(), res.String()) } @@ -230,6 +234,9 @@ x-oss-user-agent:aliyun-sdk-js/6.6.1 Chrome 98.0.4758.80 on Windows 10 64-bit SetQueryParams(map[string]string{ "uploadId": pre.Data.UploadId, }).SetBody(body).Post(u) + if err != nil { + return err + } if res.StatusCode() != 200 { return fmt.Errorf("up status: %d, error: %s", res.StatusCode(), res.String()) } diff --git a/drivers/quqi/driver.go b/drivers/quqi/driver.go index 2ab972ca..0fa64041 100644 --- a/drivers/quqi/driver.go +++ b/drivers/quqi/driver.go @@ -12,7 +12,6 @@ import ( "github.com/alist-org/alist/v3/internal/driver" "github.com/alist-org/alist/v3/internal/errs" "github.com/alist-org/alist/v3/internal/model" - istream "github.com/alist-org/alist/v3/internal/stream" "github.com/alist-org/alist/v3/pkg/utils" "github.com/alist-org/alist/v3/pkg/utils/random" "github.com/aws/aws-sdk-go/aws" @@ -387,8 +386,8 @@ func (d *Quqi) Put(ctx context.Context, dstDir model.Obj, stream model.FileStrea } uploader := s3manager.NewUploader(s) buf := make([]byte, 1024*1024*2) - fup := &istream.ReaderUpdatingProgress{ - Reader: &istream.SimpleReaderWithSize{ + fup := &driver.ReaderUpdatingProgress{ + Reader: &driver.SimpleReaderWithSize{ Reader: f, Size: int64(len(buf)), }, @@ -402,12 +401,19 @@ func (d *Quqi) Put(ctx context.Context, dstDir model.Obj, stream model.FileStrea } return nil, err } + reader := bytes.NewReader(buf[:n]) _, err = uploader.S3.UploadPartWithContext(ctx, &s3.UploadPartInput{ UploadId: &uploadInitResp.Data.UploadID, Key: &uploadInitResp.Data.Key, Bucket: &uploadInitResp.Data.Bucket, PartNumber: aws.Int64(partNumber), - Body: bytes.NewReader(buf[:n]), + Body: struct { + *driver.RateLimitReader + io.Seeker + }{ + RateLimitReader: driver.NewLimitedUploadStream(ctx, reader), + Seeker: reader, + }, }) if err != nil { return nil, err diff --git a/drivers/s3/driver.go b/drivers/s3/driver.go index a7e924e2..b7411489 100644 --- a/drivers/s3/driver.go +++ b/drivers/s3/driver.go @@ -4,18 +4,17 @@ import ( "bytes" "context" "fmt" - "github.com/alist-org/alist/v3/server/common" "io" "net/url" stdpath "path" "strings" "time" - "github.com/alist-org/alist/v3/internal/stream" - "github.com/alist-org/alist/v3/pkg/cron" - "github.com/alist-org/alist/v3/internal/driver" "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/internal/stream" + "github.com/alist-org/alist/v3/pkg/cron" + "github.com/alist-org/alist/v3/server/common" "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/s3" "github.com/aws/aws-sdk-go/service/s3/s3manager" @@ -174,10 +173,10 @@ func (d *S3) Put(ctx context.Context, dstDir model.Obj, s model.FileStreamer, up input := &s3manager.UploadInput{ Bucket: &d.Bucket, Key: &key, - Body: &stream.ReaderUpdatingProgress{ + Body: driver.NewLimitedUploadStream(ctx, &driver.ReaderUpdatingProgress{ Reader: s, UpdateProgress: up, - }, + }), ContentType: &contentType, } _, err := uploader.UploadWithContext(ctx, input) diff --git a/drivers/seafile/driver.go b/drivers/seafile/driver.go index f23038d1..239f57dd 100644 --- a/drivers/seafile/driver.go +++ b/drivers/seafile/driver.go @@ -3,7 +3,6 @@ package seafile import ( "context" "fmt" - "github.com/alist-org/alist/v3/internal/stream" "net/http" "strings" "time" @@ -215,10 +214,10 @@ func (d *Seafile) Put(ctx context.Context, dstDir model.Obj, s model.FileStreame u := string(res) u = u[1 : len(u)-1] // remove quotes _, err = d.request(http.MethodPost, u, func(req *resty.Request) { - r := &stream.ReaderUpdatingProgress{ + r := driver.NewLimitedUploadStream(ctx, &driver.ReaderUpdatingProgress{ Reader: s, UpdateProgress: up, - } + }) req.SetFileReader("file", s.GetName(), r). SetFormData(map[string]string{ "parent_dir": path, diff --git a/drivers/sftp/driver.go b/drivers/sftp/driver.go index 1f216598..7498ce39 100644 --- a/drivers/sftp/driver.go +++ b/drivers/sftp/driver.go @@ -111,7 +111,7 @@ func (d *SFTP) Put(ctx context.Context, dstDir model.Obj, stream model.FileStrea defer func() { _ = dstFile.Close() }() - err = utils.CopyWithCtx(ctx, dstFile, stream, stream.GetSize(), up) + err = utils.CopyWithCtx(ctx, dstFile, driver.NewLimitedUploadStream(ctx, stream), stream.GetSize(), up) return err } diff --git a/drivers/smb/driver.go b/drivers/smb/driver.go index 9632f24e..c292e92e 100644 --- a/drivers/smb/driver.go +++ b/drivers/smb/driver.go @@ -186,7 +186,7 @@ func (d *SMB) Put(ctx context.Context, dstDir model.Obj, stream model.FileStream _ = d.fs.Remove(fullPath) } }() - err = utils.CopyWithCtx(ctx, out, stream, stream.GetSize(), up) + err = utils.CopyWithCtx(ctx, out, driver.NewLimitedUploadStream(ctx, stream), stream.GetSize(), up) if err != nil { return err } diff --git a/drivers/teambition/driver.go b/drivers/teambition/driver.go index c75d2ac0..b37c324b 100644 --- a/drivers/teambition/driver.go +++ b/drivers/teambition/driver.go @@ -148,7 +148,7 @@ func (d *Teambition) Put(ctx context.Context, dstDir model.Obj, stream model.Fil var newFile *FileUpload if stream.GetSize() <= 20971520 { // post upload - newFile, err = d.upload(ctx, stream, token) + newFile, err = d.upload(ctx, stream, token, up) } else { // chunk upload //err = base.ErrNotImplement diff --git a/drivers/teambition/util.go b/drivers/teambition/util.go index 181cc58f..01c12cb1 100644 --- a/drivers/teambition/util.go +++ b/drivers/teambition/util.go @@ -1,6 +1,7 @@ package teambition import ( + "bytes" "context" "errors" "fmt" @@ -120,11 +121,15 @@ func (d *Teambition) getFiles(parentId string) ([]model.Obj, error) { return files, nil } -func (d *Teambition) upload(ctx context.Context, file model.FileStreamer, token string) (*FileUpload, error) { +func (d *Teambition) upload(ctx context.Context, file model.FileStreamer, token string, up driver.UpdateProgress) (*FileUpload, error) { prefix := "tcs" if d.isInternational() { prefix = "us-tcs" } + reader := driver.NewLimitedUploadStream(ctx, &driver.ReaderUpdatingProgress{ + Reader: file, + UpdateProgress: up, + }) var newFile FileUpload res, err := base.RestyClient.R(). SetContext(ctx). @@ -134,7 +139,8 @@ func (d *Teambition) upload(ctx context.Context, file model.FileStreamer, token "type": file.GetMimetype(), "size": strconv.FormatInt(file.GetSize(), 10), "lastModifiedDate": time.Now().Format("Mon Jan 02 2006 15:04:05 GMT+0800 (中国标准时间)"), - }).SetMultipartField("file", file.GetName(), file.GetMimetype(), file). + }). + SetMultipartField("file", file.GetName(), file.GetMimetype(), reader). Post(fmt.Sprintf("https://%s.teambition.net/upload", prefix)) if err != nil { return nil, err @@ -183,10 +189,9 @@ func (d *Teambition) chunkUpload(ctx context.Context, file model.FileStreamer, t "Authorization": token, "Content-Type": "application/octet-stream", "Referer": referer, - }).SetBody(chunkData).Post(u) - if err != nil { - return nil, err - } + }). + SetBody(driver.NewLimitedUploadStream(ctx, bytes.NewReader(chunkData))). + Post(u) if err != nil { return nil, err } @@ -252,7 +257,10 @@ func (d *Teambition) newUpload(ctx context.Context, dstDir model.Obj, stream mod Key: &uploadToken.Upload.Key, ContentDisposition: &uploadToken.Upload.ContentDisposition, ContentType: &uploadToken.Upload.ContentType, - Body: stream, + Body: driver.NewLimitedUploadStream(ctx, &driver.ReaderUpdatingProgress{ + Reader: stream, + UpdateProgress: up, + }), } _, err = uploader.UploadWithContext(ctx, input) if err != nil { diff --git a/drivers/terabox/driver.go b/drivers/terabox/driver.go index 362de69e..82962b81 100644 --- a/drivers/terabox/driver.go +++ b/drivers/terabox/driver.go @@ -228,7 +228,7 @@ func (d *Terabox) Put(ctx context.Context, dstDir model.Obj, stream model.FileSt res, err := base.RestyClient.R(). SetContext(ctx). SetQueryParams(params). - SetFileReader("file", stream.GetName(), bytes.NewReader(byteData)). + SetFileReader("file", stream.GetName(), driver.NewLimitedUploadStream(ctx, bytes.NewReader(byteData))). SetHeader("Cookie", d.Cookie). Post(u) if err != nil { diff --git a/drivers/thunder/driver.go b/drivers/thunder/driver.go index 1b7f0af6..7f41d003 100644 --- a/drivers/thunder/driver.go +++ b/drivers/thunder/driver.go @@ -3,7 +3,6 @@ package thunder import ( "context" "fmt" - "github.com/alist-org/alist/v3/internal/stream" "net/http" "strconv" "strings" @@ -383,10 +382,10 @@ func (xc *XunLeiCommon) Put(ctx context.Context, dstDir model.Obj, file model.Fi Bucket: aws.String(param.Bucket), Key: aws.String(param.Key), Expires: aws.Time(param.Expiration), - Body: &stream.ReaderUpdatingProgress{ + Body: driver.NewLimitedUploadStream(ctx, &driver.ReaderUpdatingProgress{ Reader: file, UpdateProgress: up, - }, + }), }) return err } diff --git a/drivers/thunder_browser/driver.go b/drivers/thunder_browser/driver.go index 96dd7e8e..7ce71f7d 100644 --- a/drivers/thunder_browser/driver.go +++ b/drivers/thunder_browser/driver.go @@ -508,7 +508,7 @@ func (xc *XunLeiBrowserCommon) Put(ctx context.Context, dstDir model.Obj, stream Bucket: aws.String(param.Bucket), Key: aws.String(param.Key), Expires: aws.Time(param.Expiration), - Body: io.TeeReader(stream, driver.NewProgress(stream.GetSize(), up)), + Body: driver.NewLimitedUploadStream(ctx, io.TeeReader(stream, driver.NewProgress(stream.GetSize(), up))), }) return err } diff --git a/drivers/thunderx/driver.go b/drivers/thunderx/driver.go index 93e07ca9..2194bdc6 100644 --- a/drivers/thunderx/driver.go +++ b/drivers/thunderx/driver.go @@ -8,7 +8,6 @@ import ( "github.com/alist-org/alist/v3/internal/errs" "github.com/alist-org/alist/v3/internal/model" "github.com/alist-org/alist/v3/internal/op" - "github.com/alist-org/alist/v3/internal/stream" "github.com/alist-org/alist/v3/pkg/utils" hash_extend "github.com/alist-org/alist/v3/pkg/utils/hash" "github.com/aws/aws-sdk-go/aws" @@ -414,10 +413,10 @@ func (xc *XunLeiXCommon) Put(ctx context.Context, dstDir model.Obj, file model.F Bucket: aws.String(param.Bucket), Key: aws.String(param.Key), Expires: aws.Time(param.Expiration), - Body: &stream.ReaderUpdatingProgress{ + Body: driver.NewLimitedUploadStream(ctx, &driver.ReaderUpdatingProgress{ Reader: file, UpdateProgress: up, - }, + }), }) return err } diff --git a/drivers/trainbit/driver.go b/drivers/trainbit/driver.go index 2b1815ed..f4f4bf3f 100644 --- a/drivers/trainbit/driver.go +++ b/drivers/trainbit/driver.go @@ -4,7 +4,6 @@ import ( "context" "encoding/json" "fmt" - "github.com/alist-org/alist/v3/internal/stream" "io" "net/http" "net/url" @@ -59,7 +58,7 @@ func (d *Trainbit) List(ctx context.Context, dir model.Obj, args model.ListArgs) return nil, err } var jsonData any - json.Unmarshal(data, &jsonData) + err = json.Unmarshal(data, &jsonData) if err != nil { return nil, err } @@ -122,10 +121,10 @@ func (d *Trainbit) Put(ctx context.Context, dstDir model.Obj, s model.FileStream query.Add("guid", guid) query.Add("name", url.QueryEscape(local2provider(s.GetName(), false)+".")) endpoint.RawQuery = query.Encode() - progressReader := &stream.ReaderUpdatingProgress{ + progressReader := driver.NewLimitedUploadStream(ctx, &driver.ReaderUpdatingProgress{ Reader: s, UpdateProgress: up, - } + }) req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint.String(), progressReader) if err != nil { return err diff --git a/drivers/url_tree/driver.go b/drivers/url_tree/driver.go index 569b3fba..f97d5cc5 100644 --- a/drivers/url_tree/driver.go +++ b/drivers/url_tree/driver.go @@ -3,7 +3,6 @@ package url_tree import ( "context" "errors" - "github.com/alist-org/alist/v3/internal/op" stdpath "path" "strings" "sync" @@ -11,6 +10,7 @@ import ( "github.com/alist-org/alist/v3/internal/driver" "github.com/alist-org/alist/v3/internal/errs" "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/internal/op" "github.com/alist-org/alist/v3/pkg/utils" log "github.com/sirupsen/logrus" ) diff --git a/drivers/uss/driver.go b/drivers/uss/driver.go index 3c54797c..2e219050 100644 --- a/drivers/uss/driver.go +++ b/drivers/uss/driver.go @@ -126,13 +126,10 @@ func (d *USS) Remove(ctx context.Context, obj model.Obj) error { func (d *USS) Put(ctx context.Context, dstDir model.Obj, s model.FileStreamer, up driver.UpdateProgress) error { return d.client.Put(&upyun.PutObjectConfig{ Path: getKey(path.Join(dstDir.GetPath(), s.GetName()), false), - Reader: &stream.ReaderWithCtx{ - Reader: &stream.ReaderUpdatingProgress{ - Reader: s, - UpdateProgress: up, - }, - Ctx: ctx, - }, + Reader: driver.NewLimitedUploadStream(ctx, &stream.ReaderUpdatingProgress{ + Reader: s, + UpdateProgress: up, + }), }) } diff --git a/drivers/vtencent/util.go b/drivers/vtencent/util.go index ba87f1ab..91db54b7 100644 --- a/drivers/vtencent/util.go +++ b/drivers/vtencent/util.go @@ -278,7 +278,8 @@ func (d *Vtencent) FileUpload(ctx context.Context, dstDir model.Obj, stream mode input := &s3manager.UploadInput{ Bucket: aws.String(fmt.Sprintf("%s-%d", params.StorageBucket, params.StorageAppID)), Key: ¶ms.Video.StoragePath, - Body: io.TeeReader(stream, io.MultiWriter(hash, driver.NewProgress(stream.GetSize(), up))), + Body: driver.NewLimitedUploadStream(ctx, + io.TeeReader(stream, io.MultiWriter(hash, driver.NewProgress(stream.GetSize(), up)))), } _, err = uploader.UploadWithContext(ctx, input) if err != nil { diff --git a/drivers/webdav/driver.go b/drivers/webdav/driver.go index 35240c49..45150fca 100644 --- a/drivers/webdav/driver.go +++ b/drivers/webdav/driver.go @@ -2,7 +2,6 @@ package webdav import ( "context" - "github.com/alist-org/alist/v3/internal/stream" "net/http" "os" "path" @@ -99,13 +98,11 @@ func (d *WebDav) Put(ctx context.Context, dstDir model.Obj, s model.FileStreamer r.Header.Set("Content-Type", s.GetMimetype()) r.ContentLength = s.GetSize() } - err := d.client.WriteStream(path.Join(dstDir.GetPath(), s.GetName()), &stream.ReaderWithCtx{ - Reader: &stream.ReaderUpdatingProgress{ - Reader: s, - UpdateProgress: up, - }, - Ctx: ctx, - }, 0644, callback) + reader := driver.NewLimitedUploadStream(ctx, &driver.ReaderUpdatingProgress{ + Reader: s, + UpdateProgress: up, + }) + err := d.client.WriteStream(path.Join(dstDir.GetPath(), s.GetName()), reader, 0644, callback) return err } diff --git a/drivers/weiyun/driver.go b/drivers/weiyun/driver.go index 59bd7237..90793d33 100644 --- a/drivers/weiyun/driver.go +++ b/drivers/weiyun/driver.go @@ -70,7 +70,7 @@ func (d *WeiYun) Init(ctx context.Context) error { if d.client.LoginType() == 1 { d.cron = cron.NewCron(time.Minute * 5) d.cron.Do(func() { - d.client.KeepAlive() + _ = d.client.KeepAlive() }) } @@ -364,12 +364,13 @@ func (d *WeiYun) Put(ctx context.Context, dstDir model.Obj, stream model.FileStr threadG.Go(func(ctx context.Context) error { for { channel.Len = int(math.Min(float64(stream.GetSize()-channel.Offset), float64(channel.Len))) + len64 := int64(channel.Len) upData, err := d.client.UploadFile(upCtx, channel, preData.UploadAuthData, - io.NewSectionReader(file, channel.Offset, int64(channel.Len))) + driver.NewLimitedUploadStream(ctx, io.NewSectionReader(file, channel.Offset, len64))) if err != nil { return err } - cur := total.Add(int64(channel.Len)) + cur := total.Add(len64) up(float64(cur) * 100.0 / float64(stream.GetSize())) // 上传完成 if upData.UploadState != 1 { diff --git a/drivers/wopan/driver.go b/drivers/wopan/driver.go index 86093fc1..82ec05a9 100644 --- a/drivers/wopan/driver.go +++ b/drivers/wopan/driver.go @@ -155,7 +155,7 @@ func (d *Wopan) Put(ctx context.Context, dstDir model.Obj, stream model.FileStre _, err := d.client.Upload2C(d.getSpaceType(), wopan.Upload2CFile{ Name: stream.GetName(), Size: stream.GetSize(), - Content: stream, + Content: driver.NewLimitedUploadStream(ctx, stream), ContentType: stream.GetMimetype(), }, dstDir.GetID(), d.FamilyID, wopan.Upload2COption{ OnProgress: func(current, total int64) { diff --git a/drivers/yandex_disk/driver.go b/drivers/yandex_disk/driver.go index fe858519..6e5ca05c 100644 --- a/drivers/yandex_disk/driver.go +++ b/drivers/yandex_disk/driver.go @@ -2,7 +2,6 @@ package yandex_disk import ( "context" - "github.com/alist-org/alist/v3/internal/stream" "net/http" "path" "strconv" @@ -118,10 +117,11 @@ func (d *YandexDisk) Put(ctx context.Context, dstDir model.Obj, s model.FileStre if err != nil { return err } - req, err := http.NewRequestWithContext(ctx, resp.Method, resp.Href, &stream.ReaderUpdatingProgress{ + reader := driver.NewLimitedUploadStream(ctx, &driver.ReaderUpdatingProgress{ Reader: s, UpdateProgress: up, }) + req, err := http.NewRequestWithContext(ctx, resp.Method, resp.Href, reader) if err != nil { return err } diff --git a/go.mod b/go.mod index 2bf4ba3e..7bf8a4bb 100644 --- a/go.mod +++ b/go.mod @@ -62,7 +62,7 @@ require ( github.com/u2takey/ffmpeg-go v0.5.0 github.com/upyun/go-sdk/v3 v3.0.4 github.com/winfsp/cgofuse v1.5.1-0.20230130140708-f87f5db493b5 - github.com/xhofe/tache v0.1.3 + github.com/xhofe/tache v0.1.5 github.com/xhofe/wopan-sdk-go v0.1.3 github.com/yeka/zip v0.0.0-20231116150916-03d6312748a9 github.com/zzzhr1990/go-common-entity v0.0.0-20221216044934-fd1c571e3a22 @@ -102,6 +102,7 @@ require ( github.com/jackc/puddle/v2 v2.2.1 // indirect github.com/klauspost/pgzip v1.2.6 // indirect github.com/kr/text v0.2.0 // indirect + github.com/matoous/go-nanoid/v2 v2.1.0 // indirect github.com/nwaples/rardecode/v2 v2.0.0-beta.4.0.20241112120701-034e449c6e78 // indirect github.com/sorairolake/lzip-go v0.3.5 // indirect github.com/taruti/bytepool v0.0.0-20160310082835-5e3a9ea56543 // indirect @@ -170,7 +171,6 @@ require ( github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect github.com/jackc/pgx/v5 v5.5.5 // indirect - github.com/jaevor/go-nanoid v1.3.0 // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect github.com/jmespath/go-jmespath v0.4.0 // indirect @@ -240,7 +240,7 @@ require ( github.com/yusufpapurcu/wmi v1.2.4 // indirect go.etcd.io/bbolt v1.3.8 // indirect golang.org/x/arch v0.8.0 // indirect - golang.org/x/sync v0.10.0 // indirect + golang.org/x/sync v0.10.0 golang.org/x/sys v0.28.0 // indirect golang.org/x/term v0.27.0 // indirect golang.org/x/text v0.21.0 diff --git a/go.sum b/go.sum index db58dea2..a51e0c6a 100644 --- a/go.sum +++ b/go.sum @@ -337,8 +337,6 @@ github.com/jackc/pgx/v5 v5.5.5 h1:amBjrZVmksIdNjxGW/IiIMzxMKZFelXbUoPNb+8sjQw= github.com/jackc/pgx/v5 v5.5.5/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A= github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= -github.com/jaevor/go-nanoid v1.3.0 h1:nD+iepesZS6pr3uOVf20vR9GdGgJW1HPaR46gtrxzkg= -github.com/jaevor/go-nanoid v1.3.0/go.mod h1:SI+jFaPuddYkqkVQoNGHs81navCtH388TcrH0RqFKgY= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= @@ -403,6 +401,8 @@ github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0 github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= github.com/maruel/natural v1.1.1 h1:Hja7XhhmvEFhcByqDoHz9QZbkWey+COd9xWfCfn1ioo= github.com/maruel/natural v1.1.1/go.mod h1:v+Rfd79xlw1AgVBjbO0BEQmptqb5HvL/k9GRHB7ZKEg= +github.com/matoous/go-nanoid/v2 v2.1.0 h1:P64+dmq21hhWdtvZfEAofnvJULaRR1Yib0+PnU669bE= +github.com/matoous/go-nanoid/v2 v2.1.0/go.mod h1:KlbGNQ+FhrUNIHUxZdL63t7tl4LaPkZNpUULS8H4uVM= github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= @@ -596,8 +596,8 @@ github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= github.com/xhofe/gsync v0.0.0-20230917091818-2111ceb38a25 h1:eDfebW/yfq9DtG9RO3KP7BT2dot2CvJGIvrB0NEoDXI= github.com/xhofe/gsync v0.0.0-20230917091818-2111ceb38a25/go.mod h1:fH4oNm5F9NfI5dLi0oIMtsLNKQOirUDbEMCIBb/7SU0= -github.com/xhofe/tache v0.1.3 h1:MipxzlljYX29E1YI/SLC7hVomVF+51iP1OUzlsuq1wE= -github.com/xhofe/tache v0.1.3/go.mod h1:iKumPFvywf30FRpAHHCt64G0JHLMzT0K+wyGedHsmTQ= +github.com/xhofe/tache v0.1.5 h1:ezDcgim7tj7KNMXliQsmf8BJQbaZtitfyQA9Nt+B4WM= +github.com/xhofe/tache v0.1.5/go.mod h1:PYt6I/XUKliSg1uHlgsk6ha+le/f6PAvjUtFZAVl3a8= github.com/xhofe/wopan-sdk-go v0.1.3 h1:J58X6v+n25ewBZjb05pKOr7AWGohb+Rdll4CThGh6+A= github.com/xhofe/wopan-sdk-go v0.1.3/go.mod h1:dcY9yA28fnaoZPnXZiVTFSkcd7GnIPTpTIIlfSI5z5Q= github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU= diff --git a/internal/bootstrap/data/setting.go b/internal/bootstrap/data/setting.go index 5e8a2be4..de3b8af9 100644 --- a/internal/bootstrap/data/setting.go +++ b/internal/bootstrap/data/setting.go @@ -11,6 +11,7 @@ import ( "github.com/alist-org/alist/v3/pkg/utils/random" "github.com/pkg/errors" "gorm.io/gorm" + "strconv" ) var initialSettingItems []model.SettingItem @@ -191,12 +192,12 @@ func InitialSettings() []model.SettingItem { {Key: conf.LdapDefaultPermission, Value: "0", Type: conf.TypeNumber, Group: model.LDAP, Flag: model.PRIVATE}, {Key: conf.LdapLoginTips, Value: "login with ldap", Type: conf.TypeString, Group: model.LDAP, Flag: model.PUBLIC}, - //s3 settings + // s3 settings {Key: conf.S3AccessKeyId, Value: "", Type: conf.TypeString, Group: model.S3, Flag: model.PRIVATE}, {Key: conf.S3SecretAccessKey, Value: "", Type: conf.TypeString, Group: model.S3, Flag: model.PRIVATE}, {Key: conf.S3Buckets, Value: "[]", Type: conf.TypeString, Group: model.S3, Flag: model.PRIVATE}, - //ftp settings + // ftp settings {Key: conf.FTPPublicHost, Value: "127.0.0.1", Type: conf.TypeString, Group: model.FTP, Flag: model.PRIVATE}, {Key: conf.FTPPasvPortMap, Value: "", Type: conf.TypeText, Group: model.FTP, Flag: model.PRIVATE}, {Key: conf.FTPProxyUserAgent, Value: "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) " + @@ -205,6 +206,18 @@ func InitialSettings() []model.SettingItem { {Key: conf.FTPImplicitTLS, Value: "false", Type: conf.TypeBool, Group: model.FTP, Flag: model.PRIVATE}, {Key: conf.FTPTLSPrivateKeyPath, Value: "", Type: conf.TypeString, Group: model.FTP, Flag: model.PRIVATE}, {Key: conf.FTPTLSPublicCertPath, Value: "", Type: conf.TypeString, Group: model.FTP, Flag: model.PRIVATE}, + + // traffic settings + {Key: conf.TaskOfflineDownloadThreadsNum, Value: strconv.Itoa(conf.Conf.Tasks.Download.Workers), Type: conf.TypeNumber, Group: model.TRAFFIC, Flag: model.PRIVATE}, + {Key: conf.TaskOfflineDownloadTransferThreadsNum, Value: strconv.Itoa(conf.Conf.Tasks.Transfer.Workers), Type: conf.TypeNumber, Group: model.TRAFFIC, Flag: model.PRIVATE}, + {Key: conf.TaskUploadThreadsNum, Value: strconv.Itoa(conf.Conf.Tasks.Upload.Workers), Type: conf.TypeNumber, Group: model.TRAFFIC, Flag: model.PRIVATE}, + {Key: conf.TaskCopyThreadsNum, Value: strconv.Itoa(conf.Conf.Tasks.Copy.Workers), Type: conf.TypeNumber, Group: model.TRAFFIC, Flag: model.PRIVATE}, + {Key: conf.TaskDecompressDownloadThreadsNum, Value: strconv.Itoa(conf.Conf.Tasks.Decompress.Workers), Type: conf.TypeNumber, Group: model.TRAFFIC, Flag: model.PRIVATE}, + {Key: conf.TaskDecompressUploadThreadsNum, Value: strconv.Itoa(conf.Conf.Tasks.DecompressUpload.Workers), Type: conf.TypeNumber, Group: model.TRAFFIC, Flag: model.PRIVATE}, + {Key: conf.StreamMaxClientDownloadSpeed, Value: "-1", Type: conf.TypeNumber, Group: model.TRAFFIC, Flag: model.PRIVATE}, + {Key: conf.StreamMaxClientUploadSpeed, Value: "-1", Type: conf.TypeNumber, Group: model.TRAFFIC, Flag: model.PRIVATE}, + {Key: conf.StreamMaxServerDownloadSpeed, Value: "-1", Type: conf.TypeNumber, Group: model.TRAFFIC, Flag: model.PRIVATE}, + {Key: conf.StreamMaxServerUploadSpeed, Value: "-1", Type: conf.TypeNumber, Group: model.TRAFFIC, Flag: model.PRIVATE}, } initialSettingItems = append(initialSettingItems, tool.Tools.Items()...) if flags.Dev { diff --git a/internal/bootstrap/stream_limit.go b/internal/bootstrap/stream_limit.go new file mode 100644 index 00000000..5ece71e4 --- /dev/null +++ b/internal/bootstrap/stream_limit.go @@ -0,0 +1,53 @@ +package bootstrap + +import ( + "context" + "github.com/alist-org/alist/v3/internal/conf" + "github.com/alist-org/alist/v3/internal/op" + "github.com/alist-org/alist/v3/internal/setting" + "github.com/alist-org/alist/v3/internal/stream" + "golang.org/x/time/rate" +) + +type blockBurstLimiter struct { + *rate.Limiter +} + +func (l blockBurstLimiter) WaitN(ctx context.Context, total int) error { + for total > 0 { + n := l.Burst() + if l.Limiter.Limit() == rate.Inf || n > total { + n = total + } + err := l.Limiter.WaitN(ctx, n) + if err != nil { + return err + } + total -= n + } + return nil +} + +func streamFilterNegative(limit int) (rate.Limit, int) { + if limit < 0 { + return rate.Inf, 0 + } + return rate.Limit(limit) * 1024.0, limit * 1024 +} + +func initLimiter(limiter *stream.Limiter, s string) { + clientDownLimit, burst := streamFilterNegative(setting.GetInt(s, -1)) + *limiter = blockBurstLimiter{Limiter: rate.NewLimiter(clientDownLimit, burst)} + op.RegisterSettingChangingCallback(func() { + newLimit, newBurst := streamFilterNegative(setting.GetInt(s, -1)) + (*limiter).SetLimit(newLimit) + (*limiter).SetBurst(newBurst) + }) +} + +func InitStreamLimit() { + initLimiter(&stream.ClientDownloadLimit, conf.StreamMaxClientDownloadSpeed) + initLimiter(&stream.ClientUploadLimit, conf.StreamMaxClientUploadSpeed) + initLimiter(&stream.ServerDownloadLimit, conf.StreamMaxServerDownloadSpeed) + initLimiter(&stream.ServerUploadLimit, conf.StreamMaxServerUploadSpeed) +} diff --git a/internal/bootstrap/task.go b/internal/bootstrap/task.go index 9c30c392..c67e3029 100644 --- a/internal/bootstrap/task.go +++ b/internal/bootstrap/task.go @@ -5,17 +5,44 @@ import ( "github.com/alist-org/alist/v3/internal/db" "github.com/alist-org/alist/v3/internal/fs" "github.com/alist-org/alist/v3/internal/offline_download/tool" + "github.com/alist-org/alist/v3/internal/op" + "github.com/alist-org/alist/v3/internal/setting" "github.com/xhofe/tache" ) +func taskFilterNegative(num int) int64 { + if num < 0 { + num = 0 + } + return int64(num) +} + func InitTaskManager() { - fs.UploadTaskManager = tache.NewManager[*fs.UploadTask](tache.WithWorks(conf.Conf.Tasks.Upload.Workers), tache.WithMaxRetry(conf.Conf.Tasks.Upload.MaxRetry)) //upload will not support persist - fs.CopyTaskManager = tache.NewManager[*fs.CopyTask](tache.WithWorks(conf.Conf.Tasks.Copy.Workers), tache.WithPersistFunction(db.GetTaskDataFunc("copy", conf.Conf.Tasks.Copy.TaskPersistant), db.UpdateTaskDataFunc("copy", conf.Conf.Tasks.Copy.TaskPersistant)), tache.WithMaxRetry(conf.Conf.Tasks.Copy.MaxRetry)) - tool.DownloadTaskManager = tache.NewManager[*tool.DownloadTask](tache.WithWorks(conf.Conf.Tasks.Download.Workers), tache.WithPersistFunction(db.GetTaskDataFunc("download", conf.Conf.Tasks.Download.TaskPersistant), db.UpdateTaskDataFunc("download", conf.Conf.Tasks.Download.TaskPersistant)), tache.WithMaxRetry(conf.Conf.Tasks.Download.MaxRetry)) - tool.TransferTaskManager = tache.NewManager[*tool.TransferTask](tache.WithWorks(conf.Conf.Tasks.Transfer.Workers), tache.WithPersistFunction(db.GetTaskDataFunc("transfer", conf.Conf.Tasks.Transfer.TaskPersistant), db.UpdateTaskDataFunc("transfer", conf.Conf.Tasks.Transfer.TaskPersistant)), tache.WithMaxRetry(conf.Conf.Tasks.Transfer.MaxRetry)) + fs.UploadTaskManager = tache.NewManager[*fs.UploadTask](tache.WithWorks(setting.GetInt(conf.TaskUploadThreadsNum, conf.Conf.Tasks.Upload.Workers)), tache.WithMaxRetry(conf.Conf.Tasks.Upload.MaxRetry)) //upload will not support persist + op.RegisterSettingChangingCallback(func() { + fs.UploadTaskManager.SetWorkersNumActive(taskFilterNegative(setting.GetInt(conf.TaskUploadThreadsNum, conf.Conf.Tasks.Upload.Workers))) + }) + fs.CopyTaskManager = tache.NewManager[*fs.CopyTask](tache.WithWorks(setting.GetInt(conf.TaskCopyThreadsNum, conf.Conf.Tasks.Copy.Workers)), tache.WithPersistFunction(db.GetTaskDataFunc("copy", conf.Conf.Tasks.Copy.TaskPersistant), db.UpdateTaskDataFunc("copy", conf.Conf.Tasks.Copy.TaskPersistant)), tache.WithMaxRetry(conf.Conf.Tasks.Copy.MaxRetry)) + op.RegisterSettingChangingCallback(func() { + fs.CopyTaskManager.SetWorkersNumActive(taskFilterNegative(setting.GetInt(conf.TaskCopyThreadsNum, conf.Conf.Tasks.Copy.Workers))) + }) + tool.DownloadTaskManager = tache.NewManager[*tool.DownloadTask](tache.WithWorks(setting.GetInt(conf.TaskOfflineDownloadThreadsNum, conf.Conf.Tasks.Download.Workers)), tache.WithPersistFunction(db.GetTaskDataFunc("download", conf.Conf.Tasks.Download.TaskPersistant), db.UpdateTaskDataFunc("download", conf.Conf.Tasks.Download.TaskPersistant)), tache.WithMaxRetry(conf.Conf.Tasks.Download.MaxRetry)) + op.RegisterSettingChangingCallback(func() { + tool.DownloadTaskManager.SetWorkersNumActive(taskFilterNegative(setting.GetInt(conf.TaskOfflineDownloadThreadsNum, conf.Conf.Tasks.Download.Workers))) + }) + tool.TransferTaskManager = tache.NewManager[*tool.TransferTask](tache.WithWorks(setting.GetInt(conf.TaskOfflineDownloadTransferThreadsNum, conf.Conf.Tasks.Transfer.Workers)), tache.WithPersistFunction(db.GetTaskDataFunc("transfer", conf.Conf.Tasks.Transfer.TaskPersistant), db.UpdateTaskDataFunc("transfer", conf.Conf.Tasks.Transfer.TaskPersistant)), tache.WithMaxRetry(conf.Conf.Tasks.Transfer.MaxRetry)) + op.RegisterSettingChangingCallback(func() { + tool.TransferTaskManager.SetWorkersNumActive(taskFilterNegative(setting.GetInt(conf.TaskOfflineDownloadTransferThreadsNum, conf.Conf.Tasks.Transfer.Workers))) + }) if len(tool.TransferTaskManager.GetAll()) == 0 { //prevent offline downloaded files from being deleted CleanTempDir() } - fs.ArchiveDownloadTaskManager = tache.NewManager[*fs.ArchiveDownloadTask](tache.WithWorks(conf.Conf.Tasks.Decompress.Workers), tache.WithPersistFunction(db.GetTaskDataFunc("decompress", conf.Conf.Tasks.Decompress.TaskPersistant), db.UpdateTaskDataFunc("decompress", conf.Conf.Tasks.Decompress.TaskPersistant)), tache.WithMaxRetry(conf.Conf.Tasks.Decompress.MaxRetry)) - fs.ArchiveContentUploadTaskManager.Manager = tache.NewManager[*fs.ArchiveContentUploadTask](tache.WithWorks(conf.Conf.Tasks.DecompressUpload.Workers), tache.WithMaxRetry(conf.Conf.Tasks.DecompressUpload.MaxRetry)) //decompress upload will not support persist + fs.ArchiveDownloadTaskManager = tache.NewManager[*fs.ArchiveDownloadTask](tache.WithWorks(setting.GetInt(conf.TaskDecompressDownloadThreadsNum, conf.Conf.Tasks.Decompress.Workers)), tache.WithPersistFunction(db.GetTaskDataFunc("decompress", conf.Conf.Tasks.Decompress.TaskPersistant), db.UpdateTaskDataFunc("decompress", conf.Conf.Tasks.Decompress.TaskPersistant)), tache.WithMaxRetry(conf.Conf.Tasks.Decompress.MaxRetry)) + op.RegisterSettingChangingCallback(func() { + fs.ArchiveDownloadTaskManager.SetWorkersNumActive(taskFilterNegative(setting.GetInt(conf.TaskDecompressDownloadThreadsNum, conf.Conf.Tasks.Decompress.Workers))) + }) + fs.ArchiveContentUploadTaskManager.Manager = tache.NewManager[*fs.ArchiveContentUploadTask](tache.WithWorks(setting.GetInt(conf.TaskDecompressUploadThreadsNum, conf.Conf.Tasks.DecompressUpload.Workers)), tache.WithMaxRetry(conf.Conf.Tasks.DecompressUpload.MaxRetry)) //decompress upload will not support persist + op.RegisterSettingChangingCallback(func() { + fs.ArchiveContentUploadTaskManager.SetWorkersNumActive(taskFilterNegative(setting.GetInt(conf.TaskDecompressUploadThreadsNum, conf.Conf.Tasks.DecompressUpload.Workers))) + }) } diff --git a/internal/conf/const.go b/internal/conf/const.go index 0e534350..fa286e46 100644 --- a/internal/conf/const.go +++ b/internal/conf/const.go @@ -115,6 +115,18 @@ const ( FTPImplicitTLS = "ftp_implicit_tls" FTPTLSPrivateKeyPath = "ftp_tls_private_key_path" FTPTLSPublicCertPath = "ftp_tls_public_cert_path" + + // traffic + TaskOfflineDownloadThreadsNum = "offline_download_task_threads_num" + TaskOfflineDownloadTransferThreadsNum = "offline_download_transfer_task_threads_num" + TaskUploadThreadsNum = "upload_task_threads_num" + TaskCopyThreadsNum = "copy_task_threads_num" + TaskDecompressDownloadThreadsNum = "decompress_download_task_threads_num" + TaskDecompressUploadThreadsNum = "decompress_upload_task_threads_num" + StreamMaxClientDownloadSpeed = "max_client_download_speed" + StreamMaxClientUploadSpeed = "max_client_upload_speed" + StreamMaxServerDownloadSpeed = "max_server_download_speed" + StreamMaxServerUploadSpeed = "max_server_upload_speed" ) const ( diff --git a/internal/driver/driver.go b/internal/driver/driver.go index 292f8e6a..05f0fe24 100644 --- a/internal/driver/driver.go +++ b/internal/driver/driver.go @@ -77,6 +77,29 @@ type Remove interface { } type Put interface { + // Put a file (provided as a FileStreamer) into the driver + // Besides the most basic upload functionality, the following features also need to be implemented: + // 1. Canceling (when `<-ctx.Done()` returns), by the following methods: + // (1) Use request methods that carry context, such as the following: + // a. http.NewRequestWithContext + // b. resty.Request.SetContext + // c. s3manager.Uploader.UploadWithContext + // d. utils.CopyWithCtx + // (2) Use a `driver.ReaderWithCtx` or a `driver.NewLimitedUploadStream` + // (3) Use `utils.IsCanceled` to check if the upload has been canceled during the upload process, + // this is typically applicable to chunked uploads. + // 2. Submit upload progress (via `up`) in real-time. There are three recommended ways as follows: + // (1) Use `utils.CopyWithCtx` + // (2) Use `driver.ReaderUpdatingProgress` + // (3) Use `driver.Progress` with `io.TeeReader` + // 3. Slow down upload speed (via `stream.ServerUploadLimit`). It requires you to wrap the read stream + // in a `driver.RateLimitReader` or a `driver.RateLimitFile` after calculating the file's hash and + // before uploading the file or file chunks. Or you can directly call `driver.ServerUploadLimitWaitN` + // if your file chunks are sufficiently small (less than about 50KB). + // NOTE that the network speed may be significantly slower than the stream's read speed. Therefore, if + // you use a `errgroup.Group` to upload each chunk in parallel, you should consider using a recursive + // mutex like `semaphore.Weighted` to limit the maximum number of upload threads, preventing excessive + // memory usage caused by buffering too many file chunks awaiting upload. Put(ctx context.Context, dstDir model.Obj, file model.FileStreamer, up UpdateProgress) error } @@ -113,6 +136,29 @@ type CopyResult interface { } type PutResult interface { + // Put a file (provided as a FileStreamer) into the driver and return the put obj + // Besides the most basic upload functionality, the following features also need to be implemented: + // 1. Canceling (when `<-ctx.Done()` returns), which can be supported by the following methods: + // (1) Use request methods that carry context, such as the following: + // a. http.NewRequestWithContext + // b. resty.Request.SetContext + // c. s3manager.Uploader.UploadWithContext + // d. utils.CopyWithCtx + // (2) Use a `driver.ReaderWithCtx` or `driver.NewLimitedUploadStream` + // (3) Use `utils.IsCanceled` to check if the upload has been canceled during the upload process, + // this is typically applicable to chunked uploads. + // 2. Submit upload progress (via `up`) in real-time. There are three recommended ways as follows: + // (1) Use `utils.CopyWithCtx` + // (2) Use `driver.ReaderUpdatingProgress` + // (3) Use `driver.Progress` with `io.TeeReader` + // 3. Slow down upload speed (via `stream.ServerUploadLimit`). It requires you to wrap the read stream + // in a `driver.RateLimitReader` or a `driver.RateLimitFile` after calculating the file's hash and + // before uploading the file or file chunks. Or you can directly call `driver.ServerUploadLimitWaitN` + // if your file chunks are sufficiently small (less than about 50KB). + // NOTE that the network speed may be significantly slower than the stream's read speed. Therefore, if + // you use a `errgroup.Group` to upload each chunk in parallel, you should consider using a recursive + // mutex like `semaphore.Weighted` to limit the maximum number of upload threads, preventing excessive + // memory usage caused by buffering too many file chunks awaiting upload. Put(ctx context.Context, dstDir model.Obj, file model.FileStreamer, up UpdateProgress) (model.Obj, error) } @@ -159,28 +205,6 @@ type ArchiveDecompressResult interface { ArchiveDecompress(ctx context.Context, srcObj, dstDir model.Obj, args model.ArchiveDecompressArgs) ([]model.Obj, error) } -type UpdateProgress = model.UpdateProgress - -type Progress struct { - Total int64 - Done int64 - up UpdateProgress -} - -func (p *Progress) Write(b []byte) (n int, err error) { - n = len(b) - p.Done += int64(n) - p.up(float64(p.Done) / float64(p.Total) * 100) - return -} - -func NewProgress(total int64, up UpdateProgress) *Progress { - return &Progress{ - Total: total, - up: up, - } -} - type Reference interface { InitReference(storage Driver) error } diff --git a/internal/driver/utils.go b/internal/driver/utils.go new file mode 100644 index 00000000..2af850ec --- /dev/null +++ b/internal/driver/utils.go @@ -0,0 +1,62 @@ +package driver + +import ( + "context" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/internal/stream" + "io" +) + +type UpdateProgress = model.UpdateProgress + +type Progress struct { + Total int64 + Done int64 + up UpdateProgress +} + +func (p *Progress) Write(b []byte) (n int, err error) { + n = len(b) + p.Done += int64(n) + p.up(float64(p.Done) / float64(p.Total) * 100) + return +} + +func NewProgress(total int64, up UpdateProgress) *Progress { + return &Progress{ + Total: total, + up: up, + } +} + +type RateLimitReader = stream.RateLimitReader + +type RateLimitWriter = stream.RateLimitWriter + +type RateLimitFile = stream.RateLimitFile + +func NewLimitedUploadStream(ctx context.Context, r io.Reader) *RateLimitReader { + return &RateLimitReader{ + Reader: r, + Limiter: stream.ServerUploadLimit, + Ctx: ctx, + } +} + +func NewLimitedUploadFile(ctx context.Context, f model.File) *RateLimitFile { + return &RateLimitFile{ + File: f, + Limiter: stream.ServerUploadLimit, + Ctx: ctx, + } +} + +func ServerUploadLimitWaitN(ctx context.Context, n int) error { + return stream.ServerUploadLimit.WaitN(ctx, n) +} + +type ReaderWithCtx = stream.ReaderWithCtx + +type ReaderUpdatingProgress = stream.ReaderUpdatingProgress + +type SimpleReaderWithSize = stream.SimpleReaderWithSize diff --git a/internal/model/setting.go b/internal/model/setting.go index 9b60d98a..93b81fe5 100644 --- a/internal/model/setting.go +++ b/internal/model/setting.go @@ -12,6 +12,7 @@ const ( LDAP S3 FTP + TRAFFIC ) const ( diff --git a/internal/net/serve.go b/internal/net/serve.go index 6216cd21..c75e611f 100644 --- a/internal/net/serve.go +++ b/internal/net/serve.go @@ -3,6 +3,7 @@ package net import ( "compress/gzip" "context" + "crypto/tls" "fmt" "io" "mime" @@ -14,7 +15,6 @@ import ( "sync" "time" - "github.com/alist-org/alist/v3/drivers/base" "github.com/alist-org/alist/v3/internal/conf" "github.com/alist-org/alist/v3/internal/model" "github.com/alist-org/alist/v3/pkg/http_range" @@ -264,7 +264,7 @@ var httpClient *http.Client func HttpClient() *http.Client { once.Do(func() { - httpClient = base.NewHttpClient() + httpClient = NewHttpClient() httpClient.CheckRedirect = func(req *http.Request, via []*http.Request) error { if len(via) >= 10 { return errors.New("stopped after 10 redirects") @@ -275,3 +275,13 @@ func HttpClient() *http.Client { }) return httpClient } + +func NewHttpClient() *http.Client { + return &http.Client{ + Timeout: time.Hour * 48, + Transport: &http.Transport{ + Proxy: http.ProxyFromEnvironment, + TLSClientConfig: &tls.Config{InsecureSkipVerify: conf.Conf.TlsInsecureSkipVerify}, + }, + } +} diff --git a/internal/op/setting.go b/internal/op/setting.go index 50eba3f7..36a792b0 100644 --- a/internal/op/setting.go +++ b/internal/op/setting.go @@ -26,9 +26,18 @@ var settingGroupCacheF = func(key string, item []model.SettingItem) { settingGroupCache.Set(key, item, cache.WithEx[[]model.SettingItem](time.Hour)) } +var settingChangingCallbacks = make([]func(), 0) + +func RegisterSettingChangingCallback(f func()) { + settingChangingCallbacks = append(settingChangingCallbacks, f) +} + func SettingCacheUpdate() { settingCache.Clear() settingGroupCache.Clear() + for _, cb := range settingChangingCallbacks { + cb() + } } func GetPublicSettingsMap() map[string]string { diff --git a/internal/stream/limit.go b/internal/stream/limit.go new file mode 100644 index 00000000..3b32a55f --- /dev/null +++ b/internal/stream/limit.go @@ -0,0 +1,152 @@ +package stream + +import ( + "context" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/pkg/http_range" + "github.com/alist-org/alist/v3/pkg/utils" + "golang.org/x/time/rate" + "io" + "time" +) + +type Limiter interface { + Limit() rate.Limit + Burst() int + TokensAt(time.Time) float64 + Tokens() float64 + Allow() bool + AllowN(time.Time, int) bool + Reserve() *rate.Reservation + ReserveN(time.Time, int) *rate.Reservation + Wait(context.Context) error + WaitN(context.Context, int) error + SetLimit(rate.Limit) + SetLimitAt(time.Time, rate.Limit) + SetBurst(int) + SetBurstAt(time.Time, int) +} + +var ( + ClientDownloadLimit Limiter + ClientUploadLimit Limiter + ServerDownloadLimit Limiter + ServerUploadLimit Limiter +) + +type RateLimitReader struct { + io.Reader + Limiter Limiter + Ctx context.Context +} + +func (r *RateLimitReader) Read(p []byte) (n int, err error) { + if r.Ctx != nil && utils.IsCanceled(r.Ctx) { + return 0, r.Ctx.Err() + } + n, err = r.Reader.Read(p) + if err != nil { + return + } + if r.Limiter != nil { + if r.Ctx == nil { + r.Ctx = context.Background() + } + err = r.Limiter.WaitN(r.Ctx, n) + } + return +} + +func (r *RateLimitReader) Close() error { + if c, ok := r.Reader.(io.Closer); ok { + return c.Close() + } + return nil +} + +type RateLimitWriter struct { + io.Writer + Limiter Limiter + Ctx context.Context +} + +func (w *RateLimitWriter) Write(p []byte) (n int, err error) { + if w.Ctx != nil && utils.IsCanceled(w.Ctx) { + return 0, w.Ctx.Err() + } + n, err = w.Writer.Write(p) + if err != nil { + return + } + if w.Limiter != nil { + if w.Ctx == nil { + w.Ctx = context.Background() + } + err = w.Limiter.WaitN(w.Ctx, n) + } + return +} + +func (w *RateLimitWriter) Close() error { + if c, ok := w.Writer.(io.Closer); ok { + return c.Close() + } + return nil +} + +type RateLimitFile struct { + model.File + Limiter Limiter + Ctx context.Context +} + +func (r *RateLimitFile) Read(p []byte) (n int, err error) { + if r.Ctx != nil && utils.IsCanceled(r.Ctx) { + return 0, r.Ctx.Err() + } + n, err = r.File.Read(p) + if err != nil { + return + } + if r.Limiter != nil { + if r.Ctx == nil { + r.Ctx = context.Background() + } + err = r.Limiter.WaitN(r.Ctx, n) + } + return +} + +func (r *RateLimitFile) ReadAt(p []byte, off int64) (n int, err error) { + if r.Ctx != nil && utils.IsCanceled(r.Ctx) { + return 0, r.Ctx.Err() + } + n, err = r.File.ReadAt(p, off) + if err != nil { + return + } + if r.Limiter != nil { + if r.Ctx == nil { + r.Ctx = context.Background() + } + err = r.Limiter.WaitN(r.Ctx, n) + } + return +} + +type RateLimitRangeReadCloser struct { + model.RangeReadCloserIF + Limiter Limiter +} + +func (rrc RateLimitRangeReadCloser) RangeRead(ctx context.Context, httpRange http_range.Range) (io.ReadCloser, error) { + rc, err := rrc.RangeReadCloserIF.RangeRead(ctx, httpRange) + if err != nil { + return nil, err + } + return &RateLimitReader{ + Reader: rc, + Limiter: rrc.Limiter, + Ctx: ctx, + }, nil +} diff --git a/internal/stream/stream.go b/internal/stream/stream.go index 74646bfb..5eb6bdc7 100644 --- a/internal/stream/stream.go +++ b/internal/stream/stream.go @@ -182,14 +182,24 @@ func NewSeekableStream(fs FileStream, link *model.Link) (*SeekableStream, error) } if ss.Link != nil { if ss.Link.MFile != nil { - ss.mFile = ss.Link.MFile - ss.Reader = ss.Link.MFile - ss.Closers.Add(ss.Link.MFile) + mFile := ss.Link.MFile + if _, ok := mFile.(*os.File); !ok { + mFile = &RateLimitFile{ + File: mFile, + Limiter: ServerDownloadLimit, + Ctx: fs.Ctx, + } + } + ss.mFile = mFile + ss.Reader = mFile + ss.Closers.Add(mFile) return &ss, nil } - if ss.Link.RangeReadCloser != nil { - ss.rangeReadCloser = ss.Link.RangeReadCloser + ss.rangeReadCloser = RateLimitRangeReadCloser{ + RangeReadCloserIF: ss.Link.RangeReadCloser, + Limiter: ServerDownloadLimit, + } ss.Add(ss.rangeReadCloser) return &ss, nil } @@ -198,6 +208,10 @@ func NewSeekableStream(fs FileStream, link *model.Link) (*SeekableStream, error) if err != nil { return nil, err } + rrc = RateLimitRangeReadCloser{ + RangeReadCloserIF: rrc, + Limiter: ServerDownloadLimit, + } ss.rangeReadCloser = rrc ss.Add(rrc) return &ss, nil @@ -259,7 +273,7 @@ func (ss *SeekableStream) CacheFullInTempFile() (model.File, error) { if ss.tmpFile != nil { return ss.tmpFile, nil } - if ss.mFile != nil { + if _, ok := ss.mFile.(*os.File); ok { return ss.mFile, nil } tmpF, err := utils.CreateTempFile(ss, ss.GetSize()) @@ -276,7 +290,7 @@ func (ss *SeekableStream) CacheFullInTempFileAndUpdateProgress(up model.UpdatePr if ss.tmpFile != nil { return ss.tmpFile, nil } - if ss.mFile != nil { + if _, ok := ss.mFile.(*os.File); ok { return ss.mFile, nil } tmpF, err := utils.CreateTempFile(&ReaderUpdatingProgress{ @@ -293,12 +307,13 @@ func (ss *SeekableStream) CacheFullInTempFileAndUpdateProgress(up model.UpdatePr } func (f *FileStream) SetTmpFile(r *os.File) { - f.Reader = r + f.Add(r) f.tmpFile = r + f.Reader = r } type ReaderWithSize interface { - io.Reader + io.ReadCloser GetSize() int64 } @@ -311,6 +326,13 @@ func (r *SimpleReaderWithSize) GetSize() int64 { return r.Size } +func (r *SimpleReaderWithSize) Close() error { + if c, ok := r.Reader.(io.Closer); ok { + return c.Close() + } + return nil +} + type ReaderUpdatingProgress struct { Reader ReaderWithSize model.UpdateProgress @@ -324,6 +346,10 @@ func (r *ReaderUpdatingProgress) Read(p []byte) (n int, err error) { return n, err } +func (r *ReaderUpdatingProgress) Close() error { + return r.Reader.Close() +} + type SStreamReadAtSeeker interface { model.File GetRawStream() *SeekableStream @@ -534,7 +560,7 @@ func (r *RangeReadReadAtSeeker) Read(p []byte) (n int, err error) { func (r *RangeReadReadAtSeeker) Close() error { if r.headCache != nil { - r.headCache.close() + _ = r.headCache.close() } return r.ss.Close() } @@ -562,17 +588,3 @@ func (f *FileReadAtSeeker) Seek(offset int64, whence int) (int64, error) { func (f *FileReadAtSeeker) Close() error { return f.ss.Close() } - -type ReaderWithCtx struct { - io.Reader - Ctx context.Context -} - -func (r *ReaderWithCtx) Read(p []byte) (n int, err error) { - select { - case <-r.Ctx.Done(): - return 0, r.Ctx.Err() - default: - return r.Reader.Read(p) - } -} diff --git a/internal/stream/util.go b/internal/stream/util.go index 16854c38..bb5019e0 100644 --- a/internal/stream/util.go +++ b/internal/stream/util.go @@ -3,6 +3,7 @@ package stream import ( "context" "fmt" + "github.com/alist-org/alist/v3/pkg/utils" "io" "net/http" @@ -76,3 +77,22 @@ func checkContentRange(header *http.Header, offset int64) bool { } return false } + +type ReaderWithCtx struct { + io.Reader + Ctx context.Context +} + +func (r *ReaderWithCtx) Read(p []byte) (n int, err error) { + if utils.IsCanceled(r.Ctx) { + return 0, r.Ctx.Err() + } + return r.Reader.Read(p) +} + +func (r *ReaderWithCtx) Close() error { + if c, ok := r.Reader.(io.Closer); ok { + return c.Close() + } + return nil +} diff --git a/server/common/proxy.go b/server/common/proxy.go index 2d828efd..66854976 100644 --- a/server/common/proxy.go +++ b/server/common/proxy.go @@ -6,6 +6,7 @@ import ( "io" "net/http" "net/url" + "os" "github.com/alist-org/alist/v3/internal/model" "github.com/alist-org/alist/v3/internal/net" @@ -23,11 +24,22 @@ func Proxy(w http.ResponseWriter, r *http.Request, link *model.Link, file model. if contentType != "" { w.Header().Set("Content-Type", contentType) } - http.ServeContent(w, r, file.GetName(), file.ModTime(), link.MFile) + mFile := link.MFile + if _, ok := mFile.(*os.File); !ok { + mFile = &stream.RateLimitFile{ + File: mFile, + Limiter: stream.ServerDownloadLimit, + Ctx: r.Context(), + } + } + http.ServeContent(w, r, file.GetName(), file.ModTime(), mFile) return nil } else if link.RangeReadCloser != nil { attachFileName(w, file) - net.ServeHTTP(w, r, file.GetName(), file.ModTime(), file.GetSize(), link.RangeReadCloser) + net.ServeHTTP(w, r, file.GetName(), file.ModTime(), file.GetSize(), &stream.RateLimitRangeReadCloser{ + RangeReadCloserIF: link.RangeReadCloser, + Limiter: stream.ServerDownloadLimit, + }) return nil } else if link.Concurrency != 0 || link.PartSize != 0 { attachFileName(w, file) @@ -47,7 +59,10 @@ func Proxy(w http.ResponseWriter, r *http.Request, link *model.Link, file model. rc, err := down.Download(ctx, req) return rc, err } - net.ServeHTTP(w, r, file.GetName(), file.ModTime(), file.GetSize(), &model.RangeReadCloser{RangeReader: rangeReader}) + net.ServeHTTP(w, r, file.GetName(), file.ModTime(), file.GetSize(), &stream.RateLimitRangeReadCloser{ + RangeReadCloserIF: &model.RangeReadCloser{RangeReader: rangeReader}, + Limiter: stream.ServerDownloadLimit, + }) return nil } else { //transparent proxy @@ -65,7 +80,11 @@ func Proxy(w http.ResponseWriter, r *http.Request, link *model.Link, file model. if r.Method == http.MethodHead { return nil } - _, err = utils.CopyWithBuffer(w, res.Body) + _, err = utils.CopyWithBuffer(w, &stream.RateLimitReader{ + Reader: res.Body, + Limiter: stream.ServerDownloadLimit, + Ctx: r.Context(), + }) if err != nil { return err } diff --git a/server/ftp/fsread.go b/server/ftp/fsread.go index f7e018e0..c051a19d 100644 --- a/server/ftp/fsread.go +++ b/server/ftp/fsread.go @@ -60,7 +60,12 @@ func OpenDownload(ctx context.Context, reqPath string, offset int64) (*FileDownl } func (f *FileDownloadProxy) Read(p []byte) (n int, err error) { - return f.reader.Read(p) + n, err = f.reader.Read(p) + if err != nil { + return + } + err = stream.ClientDownloadLimit.WaitN(f.reader.GetRawStream().Ctx, n) + return } func (f *FileDownloadProxy) Write(p []byte) (n int, err error) { diff --git a/server/ftp/fsup.go b/server/ftp/fsup.go index 4d626d0e..ee38b1bf 100644 --- a/server/ftp/fsup.go +++ b/server/ftp/fsup.go @@ -59,7 +59,12 @@ func (f *FileUploadProxy) Read(p []byte) (n int, err error) { } func (f *FileUploadProxy) Write(p []byte) (n int, err error) { - return f.buffer.Write(p) + n, err = f.buffer.Write(p) + if err != nil { + return + } + err = stream.ClientUploadLimit.WaitN(f.ctx, n) + return } func (f *FileUploadProxy) Seek(offset int64, whence int) (int64, error) { @@ -96,7 +101,6 @@ func (f *FileUploadProxy) Close() error { WebPutAsTask: true, } s.SetTmpFile(f.buffer) - s.Closers.Add(f.buffer) _, err = fs.PutAsTask(f.ctx, dir, s) return err } @@ -127,7 +131,7 @@ func (f *FileUploadWithLengthProxy) Read(p []byte) (n int, err error) { return 0, errs.NotSupport } -func (f *FileUploadWithLengthProxy) Write(p []byte) (n int, err error) { +func (f *FileUploadWithLengthProxy) write(p []byte) (n int, err error) { if f.pipeWriter != nil { select { case e := <-f.errChan: @@ -174,6 +178,15 @@ func (f *FileUploadWithLengthProxy) Write(p []byte) (n int, err error) { } } +func (f *FileUploadWithLengthProxy) Write(p []byte) (n int, err error) { + n, err = f.write(p) + if err != nil { + return + } + err = stream.ClientUploadLimit.WaitN(f.ctx, n) + return +} + func (f *FileUploadWithLengthProxy) Seek(offset int64, whence int) (int64, error) { return 0, errs.NotSupport } diff --git a/server/middlewares/limit.go b/server/middlewares/limit.go index 44c079b3..2ccee950 100644 --- a/server/middlewares/limit.go +++ b/server/middlewares/limit.go @@ -1,7 +1,9 @@ package middlewares import ( + "github.com/alist-org/alist/v3/internal/stream" "github.com/gin-gonic/gin" + "io" ) func MaxAllowed(n int) gin.HandlerFunc { @@ -14,3 +16,37 @@ func MaxAllowed(n int) gin.HandlerFunc { c.Next() } } + +func UploadRateLimiter(limiter stream.Limiter) gin.HandlerFunc { + return func(c *gin.Context) { + c.Request.Body = &stream.RateLimitReader{ + Reader: c.Request.Body, + Limiter: limiter, + Ctx: c, + } + c.Next() + } +} + +type ResponseWriterWrapper struct { + gin.ResponseWriter + WrapWriter io.Writer +} + +func (w *ResponseWriterWrapper) Write(p []byte) (n int, err error) { + return w.WrapWriter.Write(p) +} + +func DownloadRateLimiter(limiter stream.Limiter) gin.HandlerFunc { + return func(c *gin.Context) { + c.Writer = &ResponseWriterWrapper{ + ResponseWriter: c.Writer, + WrapWriter: &stream.RateLimitWriter{ + Writer: c.Writer, + Limiter: limiter, + Ctx: c, + }, + } + c.Next() + } +} diff --git a/server/router.go b/server/router.go index 63bad60f..830051d8 100644 --- a/server/router.go +++ b/server/router.go @@ -4,6 +4,7 @@ import ( "github.com/alist-org/alist/v3/cmd/flags" "github.com/alist-org/alist/v3/internal/conf" "github.com/alist-org/alist/v3/internal/message" + "github.com/alist-org/alist/v3/internal/stream" "github.com/alist-org/alist/v3/pkg/utils" "github.com/alist-org/alist/v3/server/common" "github.com/alist-org/alist/v3/server/handles" @@ -38,13 +39,14 @@ func Init(e *gin.Engine) { WebDav(g.Group("/dav")) S3(g.Group("/s3")) - g.GET("/d/*path", middlewares.Down, handles.Down) - g.GET("/p/*path", middlewares.Down, handles.Proxy) + downloadLimiter := middlewares.DownloadRateLimiter(stream.ClientDownloadLimit) + g.GET("/d/*path", middlewares.Down, downloadLimiter, handles.Down) + g.GET("/p/*path", middlewares.Down, downloadLimiter, handles.Proxy) g.HEAD("/d/*path", middlewares.Down, handles.Down) g.HEAD("/p/*path", middlewares.Down, handles.Proxy) - g.GET("/ad/*path", middlewares.Down, handles.ArchiveDown) - g.GET("/ap/*path", middlewares.Down, handles.ArchiveProxy) - g.GET("/ae/*path", middlewares.Down, handles.ArchiveInternalExtract) + g.GET("/ad/*path", middlewares.Down, downloadLimiter, handles.ArchiveDown) + g.GET("/ap/*path", middlewares.Down, downloadLimiter, handles.ArchiveProxy) + g.GET("/ae/*path", middlewares.Down, downloadLimiter, handles.ArchiveInternalExtract) g.HEAD("/ad/*path", middlewares.Down, handles.ArchiveDown) g.HEAD("/ap/*path", middlewares.Down, handles.ArchiveProxy) g.HEAD("/ae/*path", middlewares.Down, handles.ArchiveInternalExtract) @@ -173,8 +175,9 @@ func _fs(g *gin.RouterGroup) { g.POST("/copy", handles.FsCopy) g.POST("/remove", handles.FsRemove) g.POST("/remove_empty_directory", handles.FsRemoveEmptyDirectory) - g.PUT("/put", middlewares.FsUp, handles.FsStream) - g.PUT("/form", middlewares.FsUp, handles.FsForm) + uploadLimiter := middlewares.UploadRateLimiter(stream.ClientUploadLimit) + g.PUT("/put", middlewares.FsUp, uploadLimiter, handles.FsStream) + g.PUT("/form", middlewares.FsUp, uploadLimiter, handles.FsForm) g.POST("/link", middlewares.AuthAdmin, handles.Link) // g.POST("/add_aria2", handles.AddOfflineDownload) // g.POST("/add_qbit", handles.AddQbittorrent) diff --git a/server/webdav.go b/server/webdav.go index cdfdce7d..a735e285 100644 --- a/server/webdav.go +++ b/server/webdav.go @@ -3,6 +3,8 @@ package server import ( "context" "crypto/subtle" + "github.com/alist-org/alist/v3/internal/stream" + "github.com/alist-org/alist/v3/server/middlewares" "net/http" "path" "strings" @@ -27,8 +29,10 @@ func WebDav(dav *gin.RouterGroup) { }, } dav.Use(WebDAVAuth) - dav.Any("/*path", ServeWebDAV) - dav.Any("", ServeWebDAV) + uploadLimiter := middlewares.UploadRateLimiter(stream.ClientUploadLimit) + downloadLimiter := middlewares.DownloadRateLimiter(stream.ClientDownloadLimit) + dav.Any("/*path", uploadLimiter, downloadLimiter, ServeWebDAV) + dav.Any("", uploadLimiter, downloadLimiter, ServeWebDAV) dav.Handle("PROPFIND", "/*path", ServeWebDAV) dav.Handle("PROPFIND", "", ServeWebDAV) dav.Handle("MKCOL", "/*path", ServeWebDAV)