perf: optimize IO read/write usage (#8243)

* perf: optimize IO read/write usage

* Update drivers/139/driver.go

Co-authored-by: MadDogOwner <xiaoran@xrgzs.top>

---------

Co-authored-by: MadDogOwner <xiaoran@xrgzs.top>
j2rong4cn 2025-04-12 16:55:31 +08:00 committed by GitHub
parent 3375c26c41
commit ddffacf07b
29 changed files with 427 additions and 341 deletions


@ -405,7 +405,7 @@ func (d *Pan115) UploadByMultipart(ctx context.Context, params *driver115.Upload
if _, err = tmpF.ReadAt(buf, chunk.Offset); err != nil && !errors.Is(err, io.EOF) {
continue
}
if part, err = bucket.UploadPart(imur, driver.NewLimitedUploadStream(ctx, bytes.NewBuffer(buf)),
if part, err = bucket.UploadPart(imur, driver.NewLimitedUploadStream(ctx, bytes.NewReader(buf)),
chunk.Size, chunk.Number, driver115.OssOption(params, ossToken)...); err == nil {
break
}
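
The bytes.NewBuffer → bytes.NewReader swap here recurs throughout the PR (Cloudreve, Onedrive, OnedriveAPP, and PikPak below). A bytes.Reader wraps the slice without copying and implements Seek and ReadAt, so the same chunk can be re-sent after a failed attempt; a bytes.Buffer is drained by the first read. A minimal standalone sketch of the difference (illustrative code, not from this repo):

package main

import (
	"bytes"
	"fmt"
	"io"
)

func main() {
	data := []byte("part payload")

	// bytes.Buffer is one-shot: after the first attempt drains it,
	// a retry has nothing left to send.
	buf := bytes.NewBuffer(data)
	_, _ = io.Copy(io.Discard, buf)          // first attempt
	n, _ := io.Copy(io.Discard, buf)         // retry
	fmt.Println("buffer bytes on retry:", n) // 0

	// bytes.Reader is seekable: rewind and re-send the same slice.
	r := bytes.NewReader(data)
	_, _ = io.Copy(io.Discard, r)            // first attempt
	_, _ = r.Seek(0, io.SeekStart)           // rewind for retry
	n, _ = io.Copy(io.Discard, r)
	fmt.Println("reader bytes on retry:", n) // 12
}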


@ -2,11 +2,8 @@ package _123
import (
"context"
"crypto/md5"
"encoding/base64"
"encoding/hex"
"fmt"
"io"
"net/http"
"net/url"
"sync"
@ -18,6 +15,7 @@ import (
"github.com/alist-org/alist/v3/internal/driver"
"github.com/alist-org/alist/v3/internal/errs"
"github.com/alist-org/alist/v3/internal/model"
"github.com/alist-org/alist/v3/internal/stream"
"github.com/alist-org/alist/v3/pkg/utils"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
@ -187,25 +185,12 @@ func (d *Pan123) Remove(ctx context.Context, obj model.Obj) error {
func (d *Pan123) Put(ctx context.Context, dstDir model.Obj, file model.FileStreamer, up driver.UpdateProgress) error {
etag := file.GetHash().GetHash(utils.MD5)
var err error
if len(etag) < utils.MD5.Width {
// const DEFAULT int64 = 10485760
h := md5.New()
// need to calculate md5 of the full content
tempFile, err := file.CacheFullInTempFile()
_, etag, err = stream.CacheFullInTempFileAndHash(file, utils.MD5)
if err != nil {
return err
}
defer func() {
_ = tempFile.Close()
}()
if _, err = utils.CopyWithBuffer(h, tempFile); err != nil {
return err
}
_, err = tempFile.Seek(0, io.SeekStart)
if err != nil {
return err
}
etag = hex.EncodeToString(h.Sum(nil))
}
data := base.Json{
"driveId": 0,


@ -4,7 +4,6 @@ import (
"context"
"fmt"
"io"
"math"
"net/http"
"strconv"
@ -70,27 +69,33 @@ func (d *Pan123) completeS3(ctx context.Context, upReq *UploadResp, file model.F
}
func (d *Pan123) newUpload(ctx context.Context, upReq *UploadResp, file model.FileStreamer, up driver.UpdateProgress) error {
chunkSize := int64(1024 * 1024 * 16)
tmpF, err := file.CacheFullInTempFile()
if err != nil {
return err
}
// fetch s3 pre signed urls
chunkCount := int(math.Ceil(float64(file.GetSize()) / float64(chunkSize)))
size := file.GetSize()
chunkSize := min(size, 16*utils.MB)
chunkCount := int(size / chunkSize)
lastChunkSize := size % chunkSize
if lastChunkSize > 0 {
chunkCount++
} else {
lastChunkSize = chunkSize
}
// only 1 batch is allowed
isMultipart := chunkCount > 1
batchSize := 1
getS3UploadUrl := d.getS3Auth
if isMultipart {
if chunkCount > 1 {
batchSize = 10
getS3UploadUrl = d.getS3PreSignedUrls
}
limited := driver.NewLimitedUploadStream(ctx, file)
for i := 1; i <= chunkCount; i += batchSize {
if utils.IsCanceled(ctx) {
return ctx.Err()
}
start := i
end := i + batchSize
if end > chunkCount+1 {
end = chunkCount + 1
}
end := min(i+batchSize, chunkCount+1)
s3PreSignedUrls, err := getS3UploadUrl(ctx, upReq, start, end)
if err != nil {
return err
@ -102,9 +107,9 @@ func (d *Pan123) newUpload(ctx context.Context, upReq *UploadResp, file model.Fi
}
curSize := chunkSize
if j == chunkCount {
curSize = file.GetSize() - (int64(chunkCount)-1)*chunkSize
curSize = lastChunkSize
}
err = d.uploadS3Chunk(ctx, upReq, s3PreSignedUrls, j, end, io.LimitReader(limited, chunkSize), curSize, false, getS3UploadUrl)
err = d.uploadS3Chunk(ctx, upReq, s3PreSignedUrls, j, end, io.NewSectionReader(tmpF, chunkSize*int64(j-1), curSize), curSize, false, getS3UploadUrl)
if err != nil {
return err
}
@ -115,12 +120,12 @@ func (d *Pan123) newUpload(ctx context.Context, upReq *UploadResp, file model.Fi
return d.completeS3(ctx, upReq, file, chunkCount > 1)
}
func (d *Pan123) uploadS3Chunk(ctx context.Context, upReq *UploadResp, s3PreSignedUrls *S3PreSignedURLs, cur, end int, reader io.Reader, curSize int64, retry bool, getS3UploadUrl func(ctx context.Context, upReq *UploadResp, start int, end int) (*S3PreSignedURLs, error)) error {
func (d *Pan123) uploadS3Chunk(ctx context.Context, upReq *UploadResp, s3PreSignedUrls *S3PreSignedURLs, cur, end int, reader *io.SectionReader, curSize int64, retry bool, getS3UploadUrl func(ctx context.Context, upReq *UploadResp, start int, end int) (*S3PreSignedURLs, error)) error {
uploadUrl := s3PreSignedUrls.Data.PreSignedUrls[strconv.Itoa(cur)]
if uploadUrl == "" {
return fmt.Errorf("upload url is empty, s3PreSignedUrls: %+v", s3PreSignedUrls)
}
req, err := http.NewRequest("PUT", uploadUrl, reader)
req, err := http.NewRequest("PUT", uploadUrl, driver.NewLimitedUploadStream(ctx, reader))
if err != nil {
return err
}
@ -143,6 +148,7 @@ func (d *Pan123) uploadS3Chunk(ctx context.Context, upReq *UploadResp, s3PreSign
}
s3PreSignedUrls.Data.PreSignedUrls = newS3PreSignedUrls.Data.PreSignedUrls
// retry
reader.Seek(0, io.SeekStart)
return d.uploadS3Chunk(ctx, upReq, s3PreSignedUrls, cur, end, reader, curSize, true, getS3UploadUrl)
}
if res.StatusCode != http.StatusOK {
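
Two things change in this hunk: the chunk count is computed with integer division instead of math.Ceil over float64 (picking up the last chunk's size in the same pass), and each chunk is an io.SectionReader over the cached file rather than an io.LimitReader over the live stream, which is what makes the reader.Seek(0, io.SeekStart) on the retry path possible. A standalone sketch of the counting pattern, with illustrative values:

package main

import "fmt"

// chunkCount splits size into chunkSize pieces without float rounding.
func chunkCount(size, chunkSize int64) (count, lastChunkSize int64) {
	count = size / chunkSize
	lastChunkSize = size % chunkSize
	if lastChunkSize > 0 {
		count++ // a short trailing chunk
	} else {
		lastChunkSize = chunkSize // size is an exact multiple
	}
	return count, lastChunkSize
}

func main() {
	const MB int64 = 1024 * 1024
	count, last := chunkCount(35*MB, 16*MB)
	fmt.Println(count, last/MB) // 3 3  (16MB + 16MB + 3MB)
}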


@ -2,20 +2,19 @@ package _139
import (
"context"
"encoding/base64"
"encoding/xml"
"fmt"
"io"
"net/http"
"path"
"strconv"
"strings"
"time"
"github.com/alist-org/alist/v3/drivers/base"
"github.com/alist-org/alist/v3/internal/driver"
"github.com/alist-org/alist/v3/internal/errs"
"github.com/alist-org/alist/v3/internal/model"
streamPkg "github.com/alist-org/alist/v3/internal/stream"
"github.com/alist-org/alist/v3/pkg/cron"
"github.com/alist-org/alist/v3/pkg/utils"
"github.com/alist-org/alist/v3/pkg/utils/random"
@ -72,28 +71,29 @@ func (d *Yun139) Init(ctx context.Context) error {
default:
return errs.NotImplement
}
if d.ref != nil {
return nil
}
decode, err := base64.StdEncoding.DecodeString(d.Authorization)
if err != nil {
return err
}
decodeStr := string(decode)
splits := strings.Split(decodeStr, ":")
if len(splits) < 2 {
return fmt.Errorf("authorization is invalid, splits < 2")
}
d.Account = splits[1]
_, err = d.post("/orchestration/personalCloud/user/v1.0/qryUserExternInfo", base.Json{
"qryUserExternInfoReq": base.Json{
"commonAccountInfo": base.Json{
"account": d.getAccount(),
"accountType": 1,
},
},
}, nil)
return err
// if d.ref != nil {
// return nil
// }
// decode, err := base64.StdEncoding.DecodeString(d.Authorization)
// if err != nil {
// return err
// }
// decodeStr := string(decode)
// splits := strings.Split(decodeStr, ":")
// if len(splits) < 2 {
// return fmt.Errorf("authorization is invalid, splits < 2")
// }
// d.Account = splits[1]
// _, err = d.post("/orchestration/personalCloud/user/v1.0/qryUserExternInfo", base.Json{
// "qryUserExternInfoReq": base.Json{
// "commonAccountInfo": base.Json{
// "account": d.getAccount(),
// "accountType": 1,
// },
// },
// }, nil)
// return err
return nil
}
func (d *Yun139) InitReference(storage driver.Driver) error {
@ -503,23 +503,15 @@ func (d *Yun139) Remove(ctx context.Context, obj model.Obj) error {
}
}
const (
_ = iota //ignore first value by assigning to blank identifier
KB = 1 << (10 * iota)
MB
GB
TB
)
func (d *Yun139) getPartSize(size int64) int64 {
if d.CustomUploadPartSize != 0 {
return d.CustomUploadPartSize
}
// The provider imposes an upper limit on the number of upload parts
if size/GB > 30 {
return 512 * MB
if size/utils.GB > 30 {
return 512 * utils.MB
}
return 100 * MB
return 100 * utils.MB
}
func (d *Yun139) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error {
@ -527,29 +519,28 @@ func (d *Yun139) Put(ctx context.Context, dstDir model.Obj, stream model.FileStr
case MetaPersonalNew:
var err error
fullHash := stream.GetHash().GetHash(utils.SHA256)
if len(fullHash) <= 0 {
tmpF, err := stream.CacheFullInTempFile()
if err != nil {
return err
}
fullHash, err = utils.HashFile(utils.SHA256, tmpF)
if len(fullHash) != utils.SHA256.Width {
_, fullHash, err = streamPkg.CacheFullInTempFileAndHash(stream, utils.SHA256)
if err != nil {
return err
}
}
partInfos := []PartInfo{}
var partSize = d.getPartSize(stream.GetSize())
part := (stream.GetSize() + partSize - 1) / partSize
if part == 0 {
size := stream.GetSize()
var partSize = d.getPartSize(size)
part := size / partSize
if size%partSize > 0 {
part++
} else if part == 0 {
part = 1
}
partInfos := make([]PartInfo, 0, part)
for i := int64(0); i < part; i++ {
if utils.IsCanceled(ctx) {
return ctx.Err()
}
start := i * partSize
byteSize := stream.GetSize() - start
byteSize := size - start
if byteSize > partSize {
byteSize = partSize
}
@ -577,7 +568,7 @@ func (d *Yun139) Put(ctx context.Context, dstDir model.Obj, stream model.FileStr
"contentType": "application/octet-stream",
"parallelUpload": false,
"partInfos": firstPartInfos,
"size": stream.GetSize(),
"size": size,
"parentFileId": dstDir.GetID(),
"name": stream.GetName(),
"type": "file",
@ -630,7 +621,7 @@ func (d *Yun139) Put(ctx context.Context, dstDir model.Obj, stream model.FileStr
}
// Progress
p := driver.NewProgress(stream.GetSize(), up)
p := driver.NewProgress(size, up)
rateLimited := driver.NewLimitedUploadStream(ctx, stream)
// Upload all parts
@ -790,12 +781,14 @@ func (d *Yun139) Put(ctx context.Context, dstDir model.Obj, stream model.FileStr
return fmt.Errorf("get file upload url failed with result code: %s, message: %s", resp.Data.Result.ResultCode, resp.Data.Result.ResultDesc)
}
size := stream.GetSize()
// Progress
p := driver.NewProgress(stream.GetSize(), up)
var partSize = d.getPartSize(stream.GetSize())
part := (stream.GetSize() + partSize - 1) / partSize
if part == 0 {
p := driver.NewProgress(size, up)
var partSize = d.getPartSize(size)
part := size / partSize
if size%partSize > 0 {
part++
} else if part == 0 {
part = 1
}
rateLimited := driver.NewLimitedUploadStream(ctx, stream)
@ -805,10 +798,7 @@ func (d *Yun139) Put(ctx context.Context, dstDir model.Obj, stream model.FileStr
}
start := i * partSize
byteSize := stream.GetSize() - start
if byteSize > partSize {
byteSize = partSize
}
byteSize := min(size-start, partSize)
limitReader := io.LimitReader(rateLimited, byteSize)
// Update Progress
@ -820,7 +810,7 @@ func (d *Yun139) Put(ctx context.Context, dstDir model.Obj, stream model.FileStr
req = req.WithContext(ctx)
req.Header.Set("Content-Type", "text/plain;name="+unicode(stream.GetName()))
req.Header.Set("contentSize", strconv.FormatInt(stream.GetSize(), 10))
req.Header.Set("contentSize", strconv.FormatInt(size, 10))
req.Header.Set("range", fmt.Sprintf("bytes=%d-%d", start, start+byteSize-1))
req.Header.Set("uploadtaskID", resp.Data.UploadResult.UploadTaskID)
req.Header.Set("rangeType", "0")


@ -67,6 +67,7 @@ func (d *Yun139) refreshToken() error {
if len(splits) < 3 {
return fmt.Errorf("authorization is invalid, splits < 3")
}
d.Account = splits[1]
strs := strings.Split(splits[2], "|")
if len(strs) < 4 {
return fmt.Errorf("authorization is invalid, strs < 4")


@ -3,16 +3,15 @@ package _189pc
import (
"bytes"
"context"
"crypto/md5"
"encoding/base64"
"encoding/hex"
"encoding/xml"
"fmt"
"io"
"math"
"net/http"
"net/http/cookiejar"
"net/url"
"os"
"regexp"
"sort"
"strconv"
@ -28,6 +27,7 @@ import (
"github.com/alist-org/alist/v3/internal/model"
"github.com/alist-org/alist/v3/internal/op"
"github.com/alist-org/alist/v3/internal/setting"
"github.com/alist-org/alist/v3/internal/stream"
"github.com/alist-org/alist/v3/pkg/errgroup"
"github.com/alist-org/alist/v3/pkg/utils"
@ -473,12 +473,8 @@ func (y *Cloud189PC) refreshSession() (err error) {
// Normal upload
// Zero-size files cannot be uploaded
func (y *Cloud189PC) StreamUpload(ctx context.Context, dstDir model.Obj, file model.FileStreamer, up driver.UpdateProgress, isFamily bool, overwrite bool) (model.Obj, error) {
var sliceSize = partSize(file.GetSize())
count := int(math.Ceil(float64(file.GetSize()) / float64(sliceSize)))
lastPartSize := file.GetSize() % sliceSize
if file.GetSize() > 0 && lastPartSize == 0 {
lastPartSize = sliceSize
}
size := file.GetSize()
sliceSize := partSize(size)
params := Params{
"parentFolderId": dstDir.GetID(),
@ -512,22 +508,29 @@ func (y *Cloud189PC) StreamUpload(ctx context.Context, dstDir model.Obj, file mo
retry.DelayType(retry.BackOffDelay))
sem := semaphore.NewWeighted(3)
fileMd5 := md5.New()
silceMd5 := md5.New()
count := int(size / sliceSize)
lastPartSize := size % sliceSize
if lastPartSize > 0 {
count++
} else {
lastPartSize = sliceSize
}
fileMd5 := utils.MD5.NewFunc()
silceMd5 := utils.MD5.NewFunc()
silceMd5Hexs := make([]string, 0, count)
teeReader := io.TeeReader(file, io.MultiWriter(fileMd5, silceMd5))
byteSize := sliceSize
for i := 1; i <= count; i++ {
if utils.IsCanceled(upCtx) {
break
}
byteData := make([]byte, sliceSize)
if i == count {
byteData = byteData[:lastPartSize]
byteSize = lastPartSize
}
byteData := make([]byte, byteSize)
// Read one chunk
silceMd5.Reset()
if _, err := io.ReadFull(io.TeeReader(file, io.MultiWriter(fileMd5, silceMd5)), byteData); err != io.EOF && err != nil {
if _, err := io.ReadFull(teeReader, byteData); err != io.EOF && err != nil {
sem.Release(1)
return nil, err
}
@ -607,24 +610,43 @@ func (y *Cloud189PC) RapidUpload(ctx context.Context, dstDir model.Obj, stream m
// Rapid upload
func (y *Cloud189PC) FastUpload(ctx context.Context, dstDir model.Obj, file model.FileStreamer, up driver.UpdateProgress, isFamily bool, overwrite bool) (model.Obj, error) {
tempFile, err := file.CacheFullInTempFile()
if err != nil {
return nil, err
var (
cache = file.GetFile()
tmpF *os.File
err error
)
size := file.GetSize()
if _, ok := cache.(io.ReaderAt); !ok && size > 0 {
tmpF, err = os.CreateTemp(conf.Conf.TempDir, "file-*")
if err != nil {
return nil, err
}
defer func() {
_ = tmpF.Close()
_ = os.Remove(tmpF.Name())
}()
cache = tmpF
}
var sliceSize = partSize(file.GetSize())
count := int(math.Ceil(float64(file.GetSize()) / float64(sliceSize)))
lastSliceSize := file.GetSize() % sliceSize
if file.GetSize() > 0 && lastSliceSize == 0 {
sliceSize := partSize(size)
count := int(size / sliceSize)
lastSliceSize := size % sliceSize
if lastSliceSize > 0 {
count++
} else {
lastSliceSize = sliceSize
}
// step.1 Compute the required information first
byteSize := sliceSize
fileMd5 := md5.New()
silceMd5 := md5.New()
silceMd5Hexs := make([]string, 0, count)
fileMd5 := utils.MD5.NewFunc()
sliceMd5 := utils.MD5.NewFunc()
sliceMd5Hexs := make([]string, 0, count)
partInfos := make([]string, 0, count)
writers := []io.Writer{fileMd5, sliceMd5}
if tmpF != nil {
writers = append(writers, tmpF)
}
written := int64(0)
for i := 1; i <= count; i++ {
if utils.IsCanceled(ctx) {
return nil, ctx.Err()
@ -634,19 +656,31 @@ func (y *Cloud189PC) FastUpload(ctx context.Context, dstDir model.Obj, file mode
byteSize = lastSliceSize
}
silceMd5.Reset()
if _, err := utils.CopyWithBufferN(io.MultiWriter(fileMd5, silceMd5), tempFile, byteSize); err != nil && err != io.EOF {
n, err := utils.CopyWithBufferN(io.MultiWriter(writers...), file, byteSize)
written += n
if err != nil && err != io.EOF {
return nil, err
}
md5Byte := silceMd5.Sum(nil)
silceMd5Hexs = append(silceMd5Hexs, strings.ToUpper(hex.EncodeToString(md5Byte)))
md5Byte := sliceMd5.Sum(nil)
sliceMd5Hexs = append(sliceMd5Hexs, strings.ToUpper(hex.EncodeToString(md5Byte)))
partInfos = append(partInfos, fmt.Sprint(i, "-", base64.StdEncoding.EncodeToString(md5Byte)))
sliceMd5.Reset()
}
if tmpF != nil {
if size > 0 && written != size {
return nil, errs.NewErr(err, "CreateTempFile failed, incoming stream actual size= %d, expect = %d ", written, size)
}
_, err = tmpF.Seek(0, io.SeekStart)
if err != nil {
return nil, errs.NewErr(err, "CreateTempFile failed, can't seek to 0 ")
}
}
fileMd5Hex := strings.ToUpper(hex.EncodeToString(fileMd5.Sum(nil)))
sliceMd5Hex := fileMd5Hex
if file.GetSize() > sliceSize {
sliceMd5Hex = strings.ToUpper(utils.GetMD5EncodeStr(strings.Join(silceMd5Hexs, "\n")))
if size > sliceSize {
sliceMd5Hex = strings.ToUpper(utils.GetMD5EncodeStr(strings.Join(sliceMd5Hexs, "\n")))
}
fullUrl := UPLOAD_URL
@ -712,7 +746,7 @@ func (y *Cloud189PC) FastUpload(ctx context.Context, dstDir model.Obj, file mode
}
// step.4 Upload the slices
_, err = y.put(ctx, uploadUrl.RequestURL, uploadUrl.Headers, false, io.NewSectionReader(tempFile, offset, byteSize), isFamily)
_, err = y.put(ctx, uploadUrl.RequestURL, uploadUrl.Headers, false, io.NewSectionReader(cache, offset, byteSize), isFamily)
if err != nil {
return err
}
@ -794,11 +828,7 @@ func (y *Cloud189PC) GetMultiUploadUrls(ctx context.Context, isFamily bool, uplo
// Legacy upload; the family cloud does not support overwrite
func (y *Cloud189PC) OldUpload(ctx context.Context, dstDir model.Obj, file model.FileStreamer, up driver.UpdateProgress, isFamily bool, overwrite bool) (model.Obj, error) {
tempFile, err := file.CacheFullInTempFile()
if err != nil {
return nil, err
}
fileMd5, err := utils.HashFile(utils.MD5, tempFile)
tempFile, fileMd5, err := stream.CacheFullInTempFileAndHash(file, utils.MD5)
if err != nil {
return nil, err
}
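
FastUpload no longer caches unconditionally: when the stream already exposes io.ReaderAt through GetFile() it hashes straight off the stream, and only otherwise does it create a temp file, appended to the io.MultiWriter next to the hashers so caching and hashing share one read pass; the written-bytes check afterwards catches truncated streams. The same pattern appears in the Baidu Netdisk and Baidu Photo hunks below. A condensed standalone sketch (hypothetical names; the caller owns cleanup of the returned temp file):

package main

import (
	"crypto/md5"
	"encoding/hex"
	"fmt"
	"io"
	"os"
	"strings"
)

// hashAndMaybeCache hashes src in one pass; if src is not random-access
// (io.ReaderAt), the same pass also spools it to a temp file.
func hashAndMaybeCache(src io.Reader, size int64) (io.ReaderAt, string, error) {
	cache, _ := src.(io.ReaderAt)
	h := md5.New()
	writers := []io.Writer{h}
	var tmp *os.File
	if cache == nil {
		var err error
		if tmp, err = os.CreateTemp("", "upload-*"); err != nil {
			return nil, "", err
		}
		writers = append(writers, tmp)
	}
	written, err := io.Copy(io.MultiWriter(writers...), src)
	if err != nil {
		return nil, "", err
	}
	if written != size {
		return nil, "", fmt.Errorf("short stream: got %d, want %d", written, size)
	}
	if tmp != nil {
		cache = tmp // the temp file now backs the per-slice SectionReaders
	}
	return cache, hex.EncodeToString(h.Sum(nil)), nil
}

func main() {
	r := strings.NewReader("payload") // strings.Reader is already a ReaderAt
	_, sum, err := hashAndMaybeCache(r, int64(r.Len()))
	fmt.Println(sum, err)
}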


@ -1,7 +1,6 @@
package aliyundrive_open
import (
"bytes"
"context"
"encoding/base64"
"fmt"
@ -15,6 +14,7 @@ import (
"github.com/alist-org/alist/v3/drivers/base"
"github.com/alist-org/alist/v3/internal/driver"
"github.com/alist-org/alist/v3/internal/model"
streamPkg "github.com/alist-org/alist/v3/internal/stream"
"github.com/alist-org/alist/v3/pkg/http_range"
"github.com/alist-org/alist/v3/pkg/utils"
"github.com/avast/retry-go"
@ -131,16 +131,19 @@ func (d *AliyundriveOpen) calProofCode(stream model.FileStreamer) (string, error
return "", err
}
length := proofRange.End - proofRange.Start
buf := bytes.NewBuffer(make([]byte, 0, length))
reader, err := stream.RangeRead(http_range.Range{Start: proofRange.Start, Length: length})
if err != nil {
return "", err
}
_, err = utils.CopyWithBufferN(buf, reader, length)
buf := make([]byte, length)
n, err := io.ReadFull(reader, buf)
if err == io.ErrUnexpectedEOF {
return "", fmt.Errorf("can't read data, expected=%d, got=%d", len(buf), n)
}
if err != nil {
return "", err
}
return base64.StdEncoding.EncodeToString(buf.Bytes()), nil
return base64.StdEncoding.EncodeToString(buf), nil
}
func (d *AliyundriveOpen) upload(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) (model.Obj, error) {
@ -183,25 +186,18 @@ func (d *AliyundriveOpen) upload(ctx context.Context, dstDir model.Obj, stream m
_, err, e := d.requestReturnErrResp("/adrive/v1.0/openFile/create", http.MethodPost, func(req *resty.Request) {
req.SetBody(createData).SetResult(&createResp)
})
var tmpF model.File
if err != nil {
if e.Code != "PreHashMatched" || !rapidUpload {
return nil, err
}
log.Debugf("[aliyundrive_open] pre_hash matched, start rapid upload")
hi := stream.GetHash()
hash := hi.GetHash(utils.SHA1)
if len(hash) <= 0 {
tmpF, err = stream.CacheFullInTempFile()
hash := stream.GetHash().GetHash(utils.SHA1)
if len(hash) != utils.SHA1.Width {
_, hash, err = streamPkg.CacheFullInTempFileAndHash(stream, utils.SHA1)
if err != nil {
return nil, err
}
hash, err = utils.HashFile(utils.SHA1, tmpF)
if err != nil {
return nil, err
}
}
delete(createData, "pre_hash")
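
calProofCode now reads the proof range into an exactly-sized []byte with io.ReadFull instead of copying through a bytes.Buffer: one allocation of the right size, and a short read surfaces as io.ErrUnexpectedEOF with a precise count. A standalone sketch of the check:

package main

import (
	"fmt"
	"io"
	"strings"
)

// readExact reads exactly length bytes, turning a short read into an error
// that reports how much data actually arrived.
func readExact(r io.Reader, length int64) ([]byte, error) {
	buf := make([]byte, length)
	n, err := io.ReadFull(r, buf)
	if err == io.ErrUnexpectedEOF {
		return nil, fmt.Errorf("can't read data, expected=%d, got=%d", length, n)
	}
	if err != nil {
		return nil, err
	}
	return buf, nil
}

func main() {
	_, err := readExact(strings.NewReader("abc"), 8)
	fmt.Println(err) // can't read data, expected=8, got=3
}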


@ -6,8 +6,8 @@ import (
"encoding/hex"
"errors"
"io"
"math"
"net/url"
"os"
stdpath "path"
"strconv"
"time"
@ -15,6 +15,7 @@ import (
"golang.org/x/sync/semaphore"
"github.com/alist-org/alist/v3/drivers/base"
"github.com/alist-org/alist/v3/internal/conf"
"github.com/alist-org/alist/v3/internal/driver"
"github.com/alist-org/alist/v3/internal/errs"
"github.com/alist-org/alist/v3/internal/model"
@ -185,16 +186,30 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F
return newObj, nil
}
tempFile, err := stream.CacheFullInTempFile()
if err != nil {
return nil, err
var (
cache = stream.GetFile()
tmpF *os.File
err error
)
if _, ok := cache.(io.ReaderAt); !ok {
tmpF, err = os.CreateTemp(conf.Conf.TempDir, "file-*")
if err != nil {
return nil, err
}
defer func() {
_ = tmpF.Close()
_ = os.Remove(tmpF.Name())
}()
cache = tmpF
}
streamSize := stream.GetSize()
sliceSize := d.getSliceSize(streamSize)
count := int(math.Max(math.Ceil(float64(streamSize)/float64(sliceSize)), 1))
count := int(streamSize / sliceSize)
lastBlockSize := streamSize % sliceSize
if streamSize > 0 && lastBlockSize == 0 {
if lastBlockSize > 0 {
count++
} else {
lastBlockSize = sliceSize
}
@ -207,6 +222,11 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F
sliceMd5H := md5.New()
sliceMd5H2 := md5.New()
slicemd5H2Write := utils.LimitWriter(sliceMd5H2, SliceSize)
writers := []io.Writer{fileMd5H, sliceMd5H, slicemd5H2Write}
if tmpF != nil {
writers = append(writers, tmpF)
}
written := int64(0)
for i := 1; i <= count; i++ {
if utils.IsCanceled(ctx) {
@ -215,13 +235,23 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F
if i == count {
byteSize = lastBlockSize
}
_, err := utils.CopyWithBufferN(io.MultiWriter(fileMd5H, sliceMd5H, slicemd5H2Write), tempFile, byteSize)
n, err := utils.CopyWithBufferN(io.MultiWriter(writers...), stream, byteSize)
written += n
if err != nil && err != io.EOF {
return nil, err
}
blockList = append(blockList, hex.EncodeToString(sliceMd5H.Sum(nil)))
sliceMd5H.Reset()
}
if tmpF != nil {
if written != streamSize {
return nil, errs.NewErr(err, "CreateTempFile failed, incoming stream actual size= %d, expect = %d ", written, streamSize)
}
_, err = tmpF.Seek(0, io.SeekStart)
if err != nil {
return nil, errs.NewErr(err, "CreateTempFile failed, can't seek to 0 ")
}
}
contentMd5 := hex.EncodeToString(fileMd5H.Sum(nil))
sliceMd5 := hex.EncodeToString(sliceMd5H2.Sum(nil))
blockListStr, _ := utils.Json.MarshalToString(blockList)
@ -291,7 +321,7 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F
"partseq": strconv.Itoa(partseq),
}
err := d.uploadSlice(ctx, params, stream.GetName(),
driver.NewLimitedUploadStream(ctx, io.NewSectionReader(tempFile, offset, byteSize)))
driver.NewLimitedUploadStream(ctx, io.NewSectionReader(cache, offset, byteSize)))
if err != nil {
return err
}


@ -7,7 +7,7 @@ import (
"errors"
"fmt"
"io"
"math"
"os"
"regexp"
"strconv"
"strings"
@ -16,6 +16,7 @@ import (
"golang.org/x/sync/semaphore"
"github.com/alist-org/alist/v3/drivers/base"
"github.com/alist-org/alist/v3/internal/conf"
"github.com/alist-org/alist/v3/internal/driver"
"github.com/alist-org/alist/v3/internal/errs"
"github.com/alist-org/alist/v3/internal/model"
@ -241,11 +242,21 @@ func (d *BaiduPhoto) Put(ctx context.Context, dstDir model.Obj, stream model.Fil
// TODO:
// No rapid-upload method has been found yet
// Requires the full-file MD5, so io.Seek must be supported
tempFile, err := stream.CacheFullInTempFile()
if err != nil {
return nil, err
var (
cache = stream.GetFile()
tmpF *os.File
err error
)
if _, ok := cache.(io.ReaderAt); !ok {
tmpF, err = os.CreateTemp(conf.Conf.TempDir, "file-*")
if err != nil {
return nil, err
}
defer func() {
_ = tmpF.Close()
_ = os.Remove(tmpF.Name())
}()
cache = tmpF
}
const DEFAULT int64 = 1 << 22
@ -253,9 +264,11 @@ func (d *BaiduPhoto) Put(ctx context.Context, dstDir model.Obj, stream model.Fil
// Compute the required data
streamSize := stream.GetSize()
count := int(math.Ceil(float64(streamSize) / float64(DEFAULT)))
count := int(streamSize / DEFAULT)
lastBlockSize := streamSize % DEFAULT
if lastBlockSize == 0 {
if lastBlockSize > 0 {
count++
} else {
lastBlockSize = DEFAULT
}
@ -266,6 +279,11 @@ func (d *BaiduPhoto) Put(ctx context.Context, dstDir model.Obj, stream model.Fil
sliceMd5H := md5.New()
sliceMd5H2 := md5.New()
slicemd5H2Write := utils.LimitWriter(sliceMd5H2, SliceSize)
writers := []io.Writer{fileMd5H, sliceMd5H, slicemd5H2Write}
if tmpF != nil {
writers = append(writers, tmpF)
}
written := int64(0)
for i := 1; i <= count; i++ {
if utils.IsCanceled(ctx) {
return nil, ctx.Err()
@ -273,13 +291,23 @@ func (d *BaiduPhoto) Put(ctx context.Context, dstDir model.Obj, stream model.Fil
if i == count {
byteSize = lastBlockSize
}
_, err := utils.CopyWithBufferN(io.MultiWriter(fileMd5H, sliceMd5H, slicemd5H2Write), tempFile, byteSize)
n, err := utils.CopyWithBufferN(io.MultiWriter(writers...), stream, byteSize)
written += n
if err != nil && err != io.EOF {
return nil, err
}
sliceMD5List = append(sliceMD5List, hex.EncodeToString(sliceMd5H.Sum(nil)))
sliceMd5H.Reset()
}
if tmpF != nil {
if written != streamSize {
return nil, errs.NewErr(err, "CreateTempFile failed, incoming stream actual size= %d, expect = %d ", written, streamSize)
}
_, err = tmpF.Seek(0, io.SeekStart)
if err != nil {
return nil, errs.NewErr(err, "CreateTempFile failed, can't seek to 0 ")
}
}
contentMd5 := hex.EncodeToString(fileMd5H.Sum(nil))
sliceMd5 := hex.EncodeToString(sliceMd5H2.Sum(nil))
blockListStr, _ := utils.Json.MarshalToString(sliceMD5List)
@ -291,7 +319,7 @@ func (d *BaiduPhoto) Put(ctx context.Context, dstDir model.Obj, stream model.Fil
"rtype": "1",
"ctype": "11",
"path": fmt.Sprintf("/%s", stream.GetName()),
"size": fmt.Sprint(stream.GetSize()),
"size": fmt.Sprint(streamSize),
"slice-md5": sliceMd5,
"content-md5": contentMd5,
"block_list": blockListStr,
@ -343,7 +371,7 @@ func (d *BaiduPhoto) Put(ctx context.Context, dstDir model.Obj, stream model.Fil
r.SetContext(ctx)
r.SetQueryParams(uploadParams)
r.SetFileReader("file", stream.GetName(),
driver.NewLimitedUploadStream(ctx, io.NewSectionReader(tempFile, offset, byteSize)))
driver.NewLimitedUploadStream(ctx, io.NewSectionReader(cache, offset, byteSize)))
}, nil)
if err != nil {
return err


@ -204,7 +204,7 @@ func (d *Cloudreve) upLocal(ctx context.Context, stream model.FileStreamer, u Up
req.SetContentLength(true)
req.SetHeader("Content-Length", strconv.FormatInt(byteSize, 10))
req.SetHeader("User-Agent", d.getUA())
req.SetBody(driver.NewLimitedUploadStream(ctx, bytes.NewBuffer(byteData)))
req.SetBody(driver.NewLimitedUploadStream(ctx, bytes.NewReader(byteData)))
}, nil)
if err != nil {
break
@ -239,7 +239,7 @@ func (d *Cloudreve) upRemote(ctx context.Context, stream model.FileStreamer, u U
return err
}
req, err := http.NewRequest("POST", uploadUrl+"?chunk="+strconv.Itoa(chunk),
driver.NewLimitedUploadStream(ctx, bytes.NewBuffer(byteData)))
driver.NewLimitedUploadStream(ctx, bytes.NewReader(byteData)))
if err != nil {
return err
}
@ -280,7 +280,7 @@ func (d *Cloudreve) upOneDrive(ctx context.Context, stream model.FileStreamer, u
if err != nil {
return err
}
req, err := http.NewRequest("PUT", uploadUrl, driver.NewLimitedUploadStream(ctx, bytes.NewBuffer(byteData)))
req, err := http.NewRequest("PUT", uploadUrl, driver.NewLimitedUploadStream(ctx, bytes.NewReader(byteData)))
if err != nil {
return err
}


@ -5,7 +5,6 @@ import (
"context"
"errors"
"fmt"
"io"
"strings"
"text/template"
"time"
@ -159,7 +158,7 @@ func signCommit(m *map[string]interface{}, entity *openpgp.Entity) (string, erro
if err != nil {
return "", err
}
if _, err = io.Copy(armorWriter, &sigBuffer); err != nil {
if _, err = utils.CopyWithBuffer(armorWriter, &sigBuffer); err != nil {
return "", err
}
_ = armorWriter.Close()


@ -2,7 +2,6 @@ package template
import (
"context"
"crypto/md5"
"encoding/base64"
"encoding/hex"
"fmt"
@ -17,6 +16,7 @@ import (
"github.com/alist-org/alist/v3/internal/driver"
"github.com/alist-org/alist/v3/internal/errs"
"github.com/alist-org/alist/v3/internal/model"
"github.com/alist-org/alist/v3/internal/stream"
"github.com/alist-org/alist/v3/pkg/utils"
"github.com/foxxorcat/mopan-sdk-go"
"github.com/go-resty/resty/v2"
@ -273,23 +273,14 @@ func (d *ILanZou) Remove(ctx context.Context, obj model.Obj) error {
const DefaultPartSize = 1024 * 1024 * 8
func (d *ILanZou) Put(ctx context.Context, dstDir model.Obj, s model.FileStreamer, up driver.UpdateProgress) (model.Obj, error) {
h := md5.New()
// need to calculate md5 of the full content
tempFile, err := s.CacheFullInTempFile()
if err != nil {
return nil, err
etag := s.GetHash().GetHash(utils.MD5)
var err error
if len(etag) != utils.MD5.Width {
_, etag, err = stream.CacheFullInTempFileAndHash(s, utils.MD5)
if err != nil {
return nil, err
}
}
defer func() {
_ = tempFile.Close()
}()
if _, err = utils.CopyWithBuffer(h, tempFile); err != nil {
return nil, err
}
_, err = tempFile.Seek(0, io.SeekStart)
if err != nil {
return nil, err
}
etag := hex.EncodeToString(h.Sum(nil))
// get upToken
res, err := d.proved("/7n/getUpToken", http.MethodPost, func(req *resty.Request) {
req.SetBody(base.Json{
@ -309,7 +300,7 @@ func (d *ILanZou) Put(ctx context.Context, dstDir model.Obj, s model.FileStreame
key := fmt.Sprintf("disk/%d/%d/%d/%s/%016d", now.Year(), now.Month(), now.Day(), d.account, now.UnixMilli())
reader := driver.NewLimitedUploadStream(ctx, &driver.ReaderUpdatingProgress{
Reader: &driver.SimpleReaderWithSize{
Reader: tempFile,
Reader: s,
Size: s.GetSize(),
},
UpdateProgress: up,


@ -269,9 +269,6 @@ func (d *MoPan) Put(ctx context.Context, dstDir model.Obj, stream model.FileStre
if err != nil {
return nil, err
}
defer func() {
_ = file.Close()
}()
// step.1
uploadPartData, err := mopan.InitUploadPartData(ctx, mopan.UpdloadFileParam{


@ -227,7 +227,6 @@ func (d *NeteaseMusic) putSongStream(ctx context.Context, stream model.FileStrea
if err != nil {
return err
}
defer tmp.Close()
u := uploader{driver: d, file: tmp}


@ -220,7 +220,7 @@ func (d *Onedrive) upBig(ctx context.Context, dstDir model.Obj, stream model.Fil
if err != nil {
return err
}
req, err := http.NewRequest("PUT", uploadUrl, driver.NewLimitedUploadStream(ctx, bytes.NewBuffer(byteData)))
req, err := http.NewRequest("PUT", uploadUrl, driver.NewLimitedUploadStream(ctx, bytes.NewReader(byteData)))
if err != nil {
return err
}


@ -170,7 +170,7 @@ func (d *OnedriveAPP) upBig(ctx context.Context, dstDir model.Obj, stream model.
if err != nil {
return err
}
req, err := http.NewRequest("PUT", uploadUrl, driver.NewLimitedUploadStream(ctx, bytes.NewBuffer(byteData)))
req, err := http.NewRequest("PUT", uploadUrl, driver.NewLimitedUploadStream(ctx, bytes.NewReader(byteData)))
if err != nil {
return err
}


@ -7,13 +7,6 @@ import (
"crypto/sha1"
"encoding/hex"
"fmt"
"github.com/alist-org/alist/v3/internal/driver"
"github.com/alist-org/alist/v3/internal/model"
"github.com/alist-org/alist/v3/internal/op"
"github.com/alist-org/alist/v3/pkg/utils"
"github.com/aliyun/aliyun-oss-go-sdk/oss"
jsoniter "github.com/json-iterator/go"
"github.com/pkg/errors"
"io"
"net/http"
"path/filepath"
@ -24,7 +17,14 @@ import (
"time"
"github.com/alist-org/alist/v3/drivers/base"
"github.com/alist-org/alist/v3/internal/driver"
"github.com/alist-org/alist/v3/internal/model"
"github.com/alist-org/alist/v3/internal/op"
"github.com/alist-org/alist/v3/pkg/utils"
"github.com/aliyun/aliyun-oss-go-sdk/oss"
"github.com/go-resty/resty/v2"
jsoniter "github.com/json-iterator/go"
"github.com/pkg/errors"
)
var AndroidAlgorithms = []string{
@ -516,7 +516,7 @@ func (d *PikPak) UploadByMultipart(ctx context.Context, params *S3Params, fileSi
continue
}
b := driver.NewLimitedUploadStream(ctx, bytes.NewBuffer(buf))
b := driver.NewLimitedUploadStream(ctx, bytes.NewReader(buf))
if part, err = bucket.UploadPart(imur, b, chunk.Size, chunk.Number, OssOption(params)...); err == nil {
break
}


@ -3,9 +3,8 @@ package quark
import (
"bytes"
"context"
"crypto/md5"
"crypto/sha1"
"encoding/hex"
"hash"
"io"
"net/http"
"time"
@ -14,6 +13,7 @@ import (
"github.com/alist-org/alist/v3/internal/driver"
"github.com/alist-org/alist/v3/internal/errs"
"github.com/alist-org/alist/v3/internal/model"
streamPkg "github.com/alist-org/alist/v3/internal/stream"
"github.com/alist-org/alist/v3/pkg/utils"
"github.com/go-resty/resty/v2"
log "github.com/sirupsen/logrus"
@ -136,33 +136,33 @@ func (d *QuarkOrUC) Remove(ctx context.Context, obj model.Obj) error {
}
func (d *QuarkOrUC) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error {
tempFile, err := stream.CacheFullInTempFile()
if err != nil {
return err
md5Str, sha1Str := stream.GetHash().GetHash(utils.MD5), stream.GetHash().GetHash(utils.SHA1)
var (
md5 hash.Hash
sha1 hash.Hash
)
writers := []io.Writer{}
if len(md5Str) != utils.MD5.Width {
md5 = utils.MD5.NewFunc()
writers = append(writers, md5)
}
defer func() {
_ = tempFile.Close()
}()
m := md5.New()
_, err = utils.CopyWithBuffer(m, tempFile)
if err != nil {
return err
if len(sha1Str) != utils.SHA1.Width {
sha1 = utils.SHA1.NewFunc()
writers = append(writers, sha1)
}
_, err = tempFile.Seek(0, io.SeekStart)
if err != nil {
return err
if len(writers) > 0 {
_, err := streamPkg.CacheFullInTempFileAndWriter(stream, io.MultiWriter(writers...))
if err != nil {
return err
}
if md5 != nil {
md5Str = hex.EncodeToString(md5.Sum(nil))
}
if sha1 != nil {
sha1Str = hex.EncodeToString(sha1.Sum(nil))
}
}
md5Str := hex.EncodeToString(m.Sum(nil))
s := sha1.New()
_, err = utils.CopyWithBuffer(s, tempFile)
if err != nil {
return err
}
_, err = tempFile.Seek(0, io.SeekStart)
if err != nil {
return err
}
sha1Str := hex.EncodeToString(s.Sum(nil))
// pre
pre, err := d.upPre(stream, dstDir.GetID())
if err != nil {
@ -178,27 +178,28 @@ func (d *QuarkOrUC) Put(ctx context.Context, dstDir model.Obj, stream model.File
return nil
}
// part up
partSize := pre.Metadata.PartSize
var part []byte
md5s := make([]string, 0)
defaultBytes := make([]byte, partSize)
total := stream.GetSize()
left := total
partSize := int64(pre.Metadata.PartSize)
part := make([]byte, partSize)
count := int(total / partSize)
if total%partSize > 0 {
count++
}
md5s := make([]string, 0, count)
partNumber := 1
for left > 0 {
if utils.IsCanceled(ctx) {
return ctx.Err()
}
if left > int64(partSize) {
part = defaultBytes
} else {
part = make([]byte, left)
if left < partSize {
part = part[:left]
}
_, err := io.ReadFull(tempFile, part)
n, err := io.ReadFull(stream, part)
if err != nil {
return err
}
left -= int64(len(part))
left -= int64(n)
log.Debugf("left: %d", left)
reader := driver.NewLimitedUploadStream(ctx, bytes.NewReader(part))
m, err := d.upPart(ctx, pre, stream.GetMimetype(), partNumber, reader)
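
The Quark/UC Put now computes MD5 and SHA1 only for whichever hash the stream does not already carry, funneling the needed hashers through one CacheFullInTempFileAndWriter pass, and the part loop allocates a single partSize buffer that is resliced to part[:left] for the final short part instead of allocating per iteration. A standalone sketch of the reslicing loop:

package main

import (
	"fmt"
	"io"
	"strings"
)

func main() {
	src := strings.NewReader("abcdefghij") // 10 bytes
	const partSize int64 = 4
	part := make([]byte, partSize) // one buffer reused for every part
	left := int64(src.Len())
	for left > 0 {
		if left < partSize {
			part = part[:left] // shrink the view; no new allocation
		}
		n, err := io.ReadFull(src, part)
		if err != nil {
			panic(err)
		}
		left -= int64(n)
		fmt.Printf("part %q uploaded, %d bytes left\n", part[:n], left)
	}
}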


@ -12,6 +12,7 @@ import (
"github.com/alist-org/alist/v3/internal/errs"
"github.com/alist-org/alist/v3/internal/model"
"github.com/alist-org/alist/v3/internal/op"
"github.com/alist-org/alist/v3/internal/stream"
"github.com/alist-org/alist/v3/pkg/utils"
hash_extend "github.com/alist-org/alist/v3/pkg/utils/hash"
"github.com/aws/aws-sdk-go/aws"
@ -333,22 +334,17 @@ func (xc *XunLeiCommon) Remove(ctx context.Context, obj model.Obj) error {
}
func (xc *XunLeiCommon) Put(ctx context.Context, dstDir model.Obj, file model.FileStreamer, up driver.UpdateProgress) error {
hi := file.GetHash()
gcid := hi.GetHash(hash_extend.GCID)
gcid := file.GetHash().GetHash(hash_extend.GCID)
var err error
if len(gcid) < hash_extend.GCID.Width {
tFile, err := file.CacheFullInTempFile()
if err != nil {
return err
}
gcid, err = utils.HashFile(hash_extend.GCID, tFile, file.GetSize())
_, gcid, err = stream.CacheFullInTempFileAndHash(file, hash_extend.GCID, file.GetSize())
if err != nil {
return err
}
}
var resp UploadTaskResponse
_, err := xc.Request(FILE_API_URL, http.MethodPost, func(r *resty.Request) {
_, err = xc.Request(FILE_API_URL, http.MethodPost, func(r *resty.Request) {
r.SetContext(ctx)
r.SetBody(&base.Json{
"kind": FILE,


@ -4,10 +4,15 @@ import (
"context"
"errors"
"fmt"
"io"
"net/http"
"strings"
"github.com/alist-org/alist/v3/drivers/base"
"github.com/alist-org/alist/v3/internal/driver"
"github.com/alist-org/alist/v3/internal/model"
"github.com/alist-org/alist/v3/internal/op"
streamPkg "github.com/alist-org/alist/v3/internal/stream"
"github.com/alist-org/alist/v3/pkg/utils"
hash_extend "github.com/alist-org/alist/v3/pkg/utils/hash"
"github.com/aws/aws-sdk-go/aws"
@ -15,9 +20,6 @@ import (
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/s3/s3manager"
"github.com/go-resty/resty/v2"
"io"
"net/http"
"strings"
)
type ThunderBrowser struct {
@ -456,15 +458,10 @@ func (xc *XunLeiBrowserCommon) Remove(ctx context.Context, obj model.Obj) error
}
func (xc *XunLeiBrowserCommon) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error {
hi := stream.GetHash()
gcid := hi.GetHash(hash_extend.GCID)
gcid := stream.GetHash().GetHash(hash_extend.GCID)
var err error
if len(gcid) < hash_extend.GCID.Width {
tFile, err := stream.CacheFullInTempFile()
if err != nil {
return err
}
gcid, err = utils.HashFile(hash_extend.GCID, tFile, stream.GetSize())
_, gcid, err = streamPkg.CacheFullInTempFileAndHash(stream, hash_extend.GCID, stream.GetSize())
if err != nil {
return err
}
@ -481,7 +478,7 @@ func (xc *XunLeiBrowserCommon) Put(ctx context.Context, dstDir model.Obj, stream
}
var resp UploadTaskResponse
_, err := xc.Request(FILE_API_URL, http.MethodPost, func(r *resty.Request) {
_, err = xc.Request(FILE_API_URL, http.MethodPost, func(r *resty.Request) {
r.SetContext(ctx)
r.SetBody(&js)
}, &resp)


@ -3,11 +3,15 @@ package thunderx
import (
"context"
"fmt"
"net/http"
"strings"
"github.com/alist-org/alist/v3/drivers/base"
"github.com/alist-org/alist/v3/internal/driver"
"github.com/alist-org/alist/v3/internal/errs"
"github.com/alist-org/alist/v3/internal/model"
"github.com/alist-org/alist/v3/internal/op"
"github.com/alist-org/alist/v3/internal/stream"
"github.com/alist-org/alist/v3/pkg/utils"
hash_extend "github.com/alist-org/alist/v3/pkg/utils/hash"
"github.com/aws/aws-sdk-go/aws"
@ -15,8 +19,6 @@ import (
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/s3/s3manager"
"github.com/go-resty/resty/v2"
"net/http"
"strings"
)
type ThunderX struct {
@ -364,22 +366,17 @@ func (xc *XunLeiXCommon) Remove(ctx context.Context, obj model.Obj) error {
}
func (xc *XunLeiXCommon) Put(ctx context.Context, dstDir model.Obj, file model.FileStreamer, up driver.UpdateProgress) error {
hi := file.GetHash()
gcid := hi.GetHash(hash_extend.GCID)
gcid := file.GetHash().GetHash(hash_extend.GCID)
var err error
if len(gcid) < hash_extend.GCID.Width {
tFile, err := file.CacheFullInTempFile()
if err != nil {
return err
}
gcid, err = utils.HashFile(hash_extend.GCID, tFile, file.GetSize())
_, gcid, err = stream.CacheFullInTempFileAndHash(file, hash_extend.GCID, file.GetSize())
if err != nil {
return err
}
}
var resp UploadTaskResponse
_, err := xc.Request(FILE_API_URL, http.MethodPost, func(r *resty.Request) {
_, err = xc.Request(FILE_API_URL, http.MethodPost, func(r *resty.Request) {
r.SetContext(ctx)
r.SetBody(&base.Json{
"kind": FILE,


@ -10,6 +10,7 @@ import (
"github.com/alist-org/alist/v3/internal/errs"
"github.com/alist-org/alist/v3/internal/model"
"github.com/alist-org/alist/v3/internal/stream"
"github.com/alist-org/alist/v3/pkg/utils"
"github.com/mholt/archives"
)
@ -73,7 +74,7 @@ func decompress(fsys fs2.FS, filePath, targetPath string, up model.UpdateProgres
return err
}
defer f.Close()
_, err = io.Copy(f, &stream.ReaderUpdatingProgress{
_, err = utils.CopyWithBuffer(f, &stream.ReaderUpdatingProgress{
Reader: &stream.SimpleReaderWithSize{
Reader: rc,
Size: stat.Size(),


@ -1,14 +1,15 @@
package iso9660
import (
"github.com/alist-org/alist/v3/internal/errs"
"github.com/alist-org/alist/v3/internal/model"
"github.com/alist-org/alist/v3/internal/stream"
"github.com/kdomanski/iso9660"
"io"
"os"
stdpath "path"
"strings"
"github.com/alist-org/alist/v3/internal/errs"
"github.com/alist-org/alist/v3/internal/model"
"github.com/alist-org/alist/v3/internal/stream"
"github.com/alist-org/alist/v3/pkg/utils"
"github.com/kdomanski/iso9660"
)
func getImage(ss *stream.SeekableStream) (*iso9660.Image, error) {
@ -66,7 +67,7 @@ func decompress(f *iso9660.File, path string, up model.UpdateProgress) error {
return err
}
defer file.Close()
_, err = io.Copy(file, &stream.ReaderUpdatingProgress{
_, err = utils.CopyWithBuffer(file, &stream.ReaderUpdatingProgress{
Reader: &stream.SimpleReaderWithSize{
Reader: f.Reader(),
Size: f.Size(),


@ -90,9 +90,11 @@ func (t *ArchiveDownloadTask) RunWithoutPushUploadTask() (*ArchiveContentUploadT
t.SetTotalBytes(total)
t.status = "getting src object"
for _, s := range ss {
_, err = s.CacheFullInTempFileAndUpdateProgress(func(p float64) {
t.SetProgress((float64(cur) + float64(s.GetSize())*p/100.0) / float64(total))
})
if s.GetFile() == nil {
_, err = stream.CacheFullInTempFileAndUpdateProgress(s, func(p float64) {
t.SetProgress((float64(cur) + float64(s.GetSize())*p/100.0) / float64(total))
})
}
cur += s.GetSize()
if err != nil {
return nil, err


@ -2,6 +2,7 @@ package model
import (
"io"
"os"
"sort"
"strings"
"time"
@ -48,7 +49,8 @@ type FileStreamer interface {
RangeRead(http_range.Range) (io.Reader, error)
//for a non-seekable Stream, if Read is called, this function won't work
CacheFullInTempFile() (File, error)
CacheFullInTempFileAndUpdateProgress(up UpdateProgress) (File, error)
SetTmpFile(r *os.File)
GetFile() File
}
type UpdateProgress func(percentage float64)
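
model.FileStreamer drops CacheFullInTempFileAndUpdateProgress from the interface (it moves to a plain function in internal/stream, below) and gains GetFile/SetTmpFile. GetFile returns the already-materialized cache, or nil while the stream is still purely sequential, which is the contract the 189pc and Baidu hunks above branch on before deciding whether to spool to a temp file. A minimal standalone sketch of a caller using that contract (hypothetical stand-in types):

package main

import (
	"fmt"
	"io"
	"os"
	"strings"
)

// File stands in for model.File: random-access, seekable content.
type File interface {
	io.Reader
	io.ReaderAt
	io.Seeker
}

// streamer stands in for the relevant slice of model.FileStreamer.
type streamer interface {
	io.Reader
	GetFile() File // nil => no random-access cache exists yet
}

// ensureReaderAt hands back something the upload code can ReadAt from,
// spooling to a temp file only when the stream has no cache.
func ensureReaderAt(s streamer) (File, func(), error) {
	if f := s.GetFile(); f != nil {
		return f, func() {}, nil // fast path: already random-access
	}
	tmp, err := os.CreateTemp("", "spool-*")
	if err != nil {
		return nil, nil, err
	}
	cleanup := func() { tmp.Close(); os.Remove(tmp.Name()) }
	if _, err = io.Copy(tmp, s); err == nil {
		_, err = tmp.Seek(0, io.SeekStart)
	}
	if err != nil {
		cleanup()
		return nil, nil, err
	}
	return tmp, cleanup, nil
}

type seqStream struct{ io.Reader }

func (seqStream) GetFile() File { return nil }

func main() {
	f, cleanup, err := ensureReaderAt(seqStream{strings.NewReader("data")})
	if err != nil {
		panic(err)
	}
	defer cleanup()
	buf := make([]byte, 2)
	_, _ = f.ReadAt(buf, 2)
	fmt.Println(string(buf)) // "ta"
}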


@ -248,8 +248,9 @@ func (d *downloader) sendChunkTask(newConcurrency bool) error {
size: finalSize,
id: d.nextChunk,
buf: buf,
newConcurrency: newConcurrency,
}
ch.newConcurrency = newConcurrency
d.pos += finalSize
d.nextChunk++
d.chunkChannel <- ch


@ -94,27 +94,17 @@ func (f *FileStream) CacheFullInTempFile() (model.File, error) {
f.Add(tmpF)
f.tmpFile = tmpF
f.Reader = tmpF
return f.tmpFile, nil
return tmpF, nil
}
func (f *FileStream) CacheFullInTempFileAndUpdateProgress(up model.UpdateProgress) (model.File, error) {
func (f *FileStream) GetFile() model.File {
if f.tmpFile != nil {
return f.tmpFile, nil
return f.tmpFile
}
if file, ok := f.Reader.(model.File); ok {
return file, nil
return file
}
tmpF, err := utils.CreateTempFile(&ReaderUpdatingProgress{
Reader: f,
UpdateProgress: up,
}, f.GetSize())
if err != nil {
return nil, err
}
f.Add(tmpF)
f.tmpFile = tmpF
f.Reader = tmpF
return f.tmpFile, nil
return nil
}
const InMemoryBufMaxSize = 10 // Megabytes
@ -127,31 +117,36 @@ func (f *FileStream) RangeRead(httpRange http_range.Range) (io.Reader, error) {
// See internal/net/request.go
httpRange.Length = f.GetSize() - httpRange.Start
}
if f.peekBuff != nil && httpRange.Start < int64(f.peekBuff.Len()) && httpRange.Start+httpRange.Length-1 < int64(f.peekBuff.Len()) {
size := httpRange.Start + httpRange.Length
if f.peekBuff != nil && size <= int64(f.peekBuff.Len()) {
return io.NewSectionReader(f.peekBuff, httpRange.Start, httpRange.Length), nil
}
if f.tmpFile == nil {
if httpRange.Start == 0 && httpRange.Length <= InMemoryBufMaxSizeBytes && f.peekBuff == nil {
bufSize := utils.Min(httpRange.Length, f.GetSize())
newBuf := bytes.NewBuffer(make([]byte, 0, bufSize))
n, err := utils.CopyWithBufferN(newBuf, f.Reader, bufSize)
var cache io.ReaderAt = f.GetFile()
if cache == nil {
if size <= InMemoryBufMaxSizeBytes {
bufSize := min(size, f.GetSize())
// If bytes.Buffer were the destination of io.CopyBuffer, CopyBuffer would call Buffer.ReadFrom,
// which grows the buffer even when the amount written equals Buffer.Cap
buf := make([]byte, bufSize)
n, err := io.ReadFull(f.Reader, buf)
if err != nil {
return nil, err
}
if n != bufSize {
if n != int(bufSize) {
return nil, fmt.Errorf("stream RangeRead did not get all data in peek, expect =%d ,actual =%d", bufSize, n)
}
f.peekBuff = bytes.NewReader(newBuf.Bytes())
f.peekBuff = bytes.NewReader(buf)
f.Reader = io.MultiReader(f.peekBuff, f.Reader)
return io.NewSectionReader(f.peekBuff, httpRange.Start, httpRange.Length), nil
cache = f.peekBuff
} else {
_, err := f.CacheFullInTempFile()
var err error
cache, err = f.CacheFullInTempFile()
if err != nil {
return nil, err
}
}
}
return io.NewSectionReader(f.tmpFile, httpRange.Start, httpRange.Length), nil
return io.NewSectionReader(cache, httpRange.Start, httpRange.Length), nil
}
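
RangeRead now keys everything off size = Start + Length. A range whose end fits inside the existing peekBuff is served from it directly; otherwise a small enough prefix is read once into an exactly-sized slice (the comment above explains why bytes.Buffer was avoided: its ReadFrom grows the backing array even when the capacity already matches) and chained back in front of the remaining reader with io.MultiReader, so later sequential reads still see the whole stream; only larger ranges force a full temp-file cache. A standalone sketch of the peek-and-rechain idea:

package main

import (
	"bytes"
	"fmt"
	"io"
	"strings"
)

func main() {
	var src io.Reader = strings.NewReader("0123456789")

	// Peek the first 4 bytes into an exactly-sized slice.
	peek := make([]byte, 4)
	if _, err := io.ReadFull(src, peek); err != nil {
		panic(err)
	}
	peekBuff := bytes.NewReader(peek)

	// Serve the range request from the peek buffer via ReadAt
	// (which does not advance peekBuff's read offset).
	section := io.NewSectionReader(peekBuff, 1, 2)
	b, _ := io.ReadAll(section)
	fmt.Println(string(b)) // "12"

	// Rechain so a later sequential read still sees the whole stream.
	src = io.MultiReader(peekBuff, src)
	all, _ := io.ReadAll(src)
	fmt.Println(string(all)) // "0123456789"
}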
var _ model.FileStreamer = (*SeekableStream)(nil)
@ -176,13 +171,13 @@ func NewSeekableStream(fs FileStream, link *model.Link) (*SeekableStream, error)
if len(fs.Mimetype) == 0 {
fs.Mimetype = utils.GetMimeType(fs.Obj.GetName())
}
ss := SeekableStream{FileStream: fs, Link: link}
ss := &SeekableStream{FileStream: fs, Link: link}
if ss.Reader != nil {
result, ok := ss.Reader.(model.File)
if ok {
ss.mFile = result
ss.Closers.Add(result)
return &ss, nil
return ss, nil
}
}
if ss.Link != nil {
@ -198,7 +193,7 @@ func NewSeekableStream(fs FileStream, link *model.Link) (*SeekableStream, error)
ss.mFile = mFile
ss.Reader = mFile
ss.Closers.Add(mFile)
return &ss, nil
return ss, nil
}
if ss.Link.RangeReadCloser != nil {
ss.rangeReadCloser = &RateLimitRangeReadCloser{
@ -206,7 +201,7 @@ func NewSeekableStream(fs FileStream, link *model.Link) (*SeekableStream, error)
Limiter: ServerDownloadLimit,
}
ss.Add(ss.rangeReadCloser)
return &ss, nil
return ss, nil
}
if len(ss.Link.URL) > 0 {
rrc, err := GetRangeReadCloserFromLink(ss.GetSize(), link)
@ -219,10 +214,12 @@ func NewSeekableStream(fs FileStream, link *model.Link) (*SeekableStream, error)
}
ss.rangeReadCloser = rrc
ss.Add(rrc)
return &ss, nil
return ss, nil
}
}
if fs.Reader != nil {
return ss, nil
}
return nil, fmt.Errorf("illegal seekableStream")
}
@ -248,7 +245,7 @@ func (ss *SeekableStream) RangeRead(httpRange http_range.Range) (io.Reader, erro
}
return rc, nil
}
return nil, fmt.Errorf("can't find mFile or rangeReadCloser")
return ss.FileStream.RangeRead(httpRange)
}
//func (f *FileStream) GetReader() io.Reader {
@ -278,7 +275,7 @@ func (ss *SeekableStream) CacheFullInTempFile() (model.File, error) {
if ss.tmpFile != nil {
return ss.tmpFile, nil
}
if _, ok := ss.mFile.(*os.File); ok {
if ss.mFile != nil {
return ss.mFile, nil
}
tmpF, err := utils.CreateTempFile(ss, ss.GetSize())
@ -288,27 +285,17 @@ func (ss *SeekableStream) CacheFullInTempFile() (model.File, error) {
ss.Add(tmpF)
ss.tmpFile = tmpF
ss.Reader = tmpF
return ss.tmpFile, nil
return tmpF, nil
}
func (ss *SeekableStream) CacheFullInTempFileAndUpdateProgress(up model.UpdateProgress) (model.File, error) {
func (ss *SeekableStream) GetFile() model.File {
if ss.tmpFile != nil {
return ss.tmpFile, nil
return ss.tmpFile
}
if _, ok := ss.mFile.(*os.File); ok {
return ss.mFile, nil
if ss.mFile != nil {
return ss.mFile
}
tmpF, err := utils.CreateTempFile(&ReaderUpdatingProgress{
Reader: ss,
UpdateProgress: up,
}, ss.GetSize())
if err != nil {
return nil, err
}
ss.Add(tmpF)
ss.tmpFile = tmpF
ss.Reader = tmpF
return ss.tmpFile, nil
return nil
}
func (f *FileStream) SetTmpFile(r *os.File) {


@ -2,6 +2,7 @@ package stream
import (
"context"
"encoding/hex"
"fmt"
"io"
"net/http"
@ -96,3 +97,45 @@ func (r *ReaderWithCtx) Close() error {
}
return nil
}
func CacheFullInTempFileAndUpdateProgress(stream model.FileStreamer, up model.UpdateProgress) (model.File, error) {
if cache := stream.GetFile(); cache != nil {
up(100)
return cache, nil
}
tmpF, err := utils.CreateTempFile(&ReaderUpdatingProgress{
Reader: stream,
UpdateProgress: up,
}, stream.GetSize())
if err == nil {
stream.SetTmpFile(tmpF)
}
return tmpF, err
}
func CacheFullInTempFileAndWriter(stream model.FileStreamer, w io.Writer) (model.File, error) {
if cache := stream.GetFile(); cache != nil {
_, err := cache.Seek(0, io.SeekStart)
if err == nil {
_, err = utils.CopyWithBuffer(w, cache)
if err == nil {
_, err = cache.Seek(0, io.SeekStart)
}
}
return cache, err
}
tmpF, err := utils.CreateTempFile(io.TeeReader(stream, w), stream.GetSize())
if err == nil {
stream.SetTmpFile(tmpF)
}
return tmpF, err
}
func CacheFullInTempFileAndHash(stream model.FileStreamer, hashType *utils.HashType, params ...any) (model.File, string, error) {
h := hashType.NewFunc(params...)
tmpF, err := CacheFullInTempFileAndWriter(stream, h)
if err != nil {
return nil, "", err
}
return tmpF, hex.EncodeToString(h.Sum(nil)), err
}
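
These three helpers centralize the cache-then-derive pattern the drivers above kept reimplementing: CacheFullInTempFileAndUpdateProgress reports 100% immediately when a cache already exists, CacheFullInTempFileAndWriter tees the stream into an arbitrary io.Writer while spooling (rewinding an existing cache before and after the copy), and CacheFullInTempFileAndHash builds on it to return the hex digest. How a driver consumes the hash variant, condensed from the ILanZou hunk above (a fragment of driver code, not a standalone program):

// Inside a driver's Put, when the stream may or may not carry an MD5:
etag := s.GetHash().GetHash(utils.MD5)
var err error
if len(etag) != utils.MD5.Width {
	// One pass: spools s to a temp file (unless cached) while hashing.
	_, etag, err = stream.CacheFullInTempFileAndHash(s, utils.MD5)
	if err != nil {
		return nil, err
	}
}
// The helper attached the temp file via SetTmpFile, so s.GetFile() != nil
// and later section reads over the cache never touch the source stream.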


@ -1,8 +1,6 @@
package handles
import (
"github.com/alist-org/alist/v3/internal/task"
"github.com/alist-org/alist/v3/pkg/utils"
"io"
"net/url"
stdpath "path"
@ -12,6 +10,8 @@ import (
"github.com/alist-org/alist/v3/internal/fs"
"github.com/alist-org/alist/v3/internal/model"
"github.com/alist-org/alist/v3/internal/stream"
"github.com/alist-org/alist/v3/internal/task"
"github.com/alist-org/alist/v3/pkg/utils"
"github.com/alist-org/alist/v3/server/common"
"github.com/gin-gonic/gin"
)
@ -44,7 +44,7 @@ func FsStream(c *gin.Context) {
}
if !overwrite {
if res, _ := fs.Get(c, path, &fs.GetArgs{NoLog: true}); res != nil {
_, _ = io.Copy(io.Discard, c.Request.Body)
_, _ = utils.CopyWithBuffer(io.Discard, c.Request.Body)
common.ErrorStrResp(c, "file exists", 403)
return
}
@ -66,6 +66,10 @@ func FsStream(c *gin.Context) {
if sha256 := c.GetHeader("X-File-Sha256"); sha256 != "" {
h[utils.SHA256] = sha256
}
mimetype := c.GetHeader("Content-Type")
if len(mimetype) == 0 {
mimetype = utils.GetMimeType(name)
}
s := &stream.FileStream{
Obj: &model.Object{
Name: name,
@ -74,7 +78,7 @@ func FsStream(c *gin.Context) {
HashInfo: utils.NewHashInfoByMap(h),
},
Reader: c.Request.Body,
Mimetype: c.GetHeader("Content-Type"),
Mimetype: mimetype,
WebPutAsTask: asTask,
}
var t task.TaskExtensionInfo
@ -89,6 +93,9 @@ func FsStream(c *gin.Context) {
return
}
if t == nil {
if n, _ := io.ReadFull(c.Request.Body, []byte{0}); n == 1 {
_, _ = utils.CopyWithBuffer(io.Discard, c.Request.Body)
}
common.SuccessResp(c)
return
}
@ -114,7 +121,7 @@ func FsForm(c *gin.Context) {
}
if !overwrite {
if res, _ := fs.Get(c, path, &fs.GetArgs{NoLog: true}); res != nil {
_, _ = io.Copy(io.Discard, c.Request.Body)
_, _ = utils.CopyWithBuffer(io.Discard, c.Request.Body)
common.ErrorStrResp(c, "file exists", 403)
return
}
@ -150,6 +157,10 @@ func FsForm(c *gin.Context) {
if sha256 := c.GetHeader("X-File-Sha256"); sha256 != "" {
h[utils.SHA256] = sha256
}
mimetype := file.Header.Get("Content-Type")
if len(mimetype) == 0 {
mimetype = utils.GetMimeType(name)
}
s := stream.FileStream{
Obj: &model.Object{
Name: name,
@ -158,7 +169,7 @@ func FsForm(c *gin.Context) {
HashInfo: utils.NewHashInfoByMap(h),
},
Reader: f,
Mimetype: file.Header.Get("Content-Type"),
Mimetype: mimetype,
WebPutAsTask: asTask,
}
var t task.TaskExtensionInfo
@ -168,12 +179,7 @@ func FsForm(c *gin.Context) {
}{f}
t, err = fs.PutAsTask(c, dir, &s)
} else {
ss, err := stream.NewSeekableStream(s, nil)
if err != nil {
common.ErrorResp(c, err, 500)
return
}
err = fs.PutDirectly(c, dir, ss, true)
err = fs.PutDirectly(c, dir, &s, true)
}
if err != nil {
common.ErrorResp(c, err, 500)
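
The handler changes follow the same theme: io.Copy on discarded request bodies becomes the pooled utils.CopyWithBuffer, the Mimetype falls back to utils.GetMimeType(name) when the client sent no Content-Type, and when no task was spawned the body is probed with a one-byte io.ReadFull so an already-consumed body is not pointlessly drained. A standalone sketch of the probe-then-drain idea:

package main

import (
	"fmt"
	"io"
	"strings"
)

// drainIfLeftover discards whatever the handler did not consume so the HTTP
// connection stays reusable, skipping the copy if the body already hit EOF.
func drainIfLeftover(body io.Reader) {
	if n, _ := io.ReadFull(body, []byte{0}); n == 1 {
		rest, _ := io.Copy(io.Discard, body)
		fmt.Println("drained", rest+1, "leftover bytes")
		return
	}
	fmt.Println("body already fully consumed")
}

func main() {
	drainIfLeftover(strings.NewReader(""))         // already consumed
	drainIfLeftover(strings.NewReader("leftover")) // drains 8 bytes
}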