perf: optimize IO read/write usage (#8243)

* perf: optimize IO read/write usage

* .

* Update drivers/139/driver.go

Co-authored-by: MadDogOwner <xiaoran@xrgzs.top>

---------

Co-authored-by: MadDogOwner <xiaoran@xrgzs.top>
pull/8357/head
j2rong4cn 2025-04-12 16:55:31 +08:00 committed by GitHub
parent 3375c26c41
commit ddffacf07b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
29 changed files with 427 additions and 341 deletions

View File

@ -405,7 +405,7 @@ func (d *Pan115) UploadByMultipart(ctx context.Context, params *driver115.Upload
if _, err = tmpF.ReadAt(buf, chunk.Offset); err != nil && !errors.Is(err, io.EOF) { if _, err = tmpF.ReadAt(buf, chunk.Offset); err != nil && !errors.Is(err, io.EOF) {
continue continue
} }
if part, err = bucket.UploadPart(imur, driver.NewLimitedUploadStream(ctx, bytes.NewBuffer(buf)), if part, err = bucket.UploadPart(imur, driver.NewLimitedUploadStream(ctx, bytes.NewReader(buf)),
chunk.Size, chunk.Number, driver115.OssOption(params, ossToken)...); err == nil { chunk.Size, chunk.Number, driver115.OssOption(params, ossToken)...); err == nil {
break break
} }

View File

@ -2,11 +2,8 @@ package _123
import ( import (
"context" "context"
"crypto/md5"
"encoding/base64" "encoding/base64"
"encoding/hex"
"fmt" "fmt"
"io"
"net/http" "net/http"
"net/url" "net/url"
"sync" "sync"
@ -18,6 +15,7 @@ import (
"github.com/alist-org/alist/v3/internal/driver" "github.com/alist-org/alist/v3/internal/driver"
"github.com/alist-org/alist/v3/internal/errs" "github.com/alist-org/alist/v3/internal/errs"
"github.com/alist-org/alist/v3/internal/model" "github.com/alist-org/alist/v3/internal/model"
"github.com/alist-org/alist/v3/internal/stream"
"github.com/alist-org/alist/v3/pkg/utils" "github.com/alist-org/alist/v3/pkg/utils"
"github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials" "github.com/aws/aws-sdk-go/aws/credentials"
@ -187,25 +185,12 @@ func (d *Pan123) Remove(ctx context.Context, obj model.Obj) error {
func (d *Pan123) Put(ctx context.Context, dstDir model.Obj, file model.FileStreamer, up driver.UpdateProgress) error { func (d *Pan123) Put(ctx context.Context, dstDir model.Obj, file model.FileStreamer, up driver.UpdateProgress) error {
etag := file.GetHash().GetHash(utils.MD5) etag := file.GetHash().GetHash(utils.MD5)
var err error
if len(etag) < utils.MD5.Width { if len(etag) < utils.MD5.Width {
// const DEFAULT int64 = 10485760 _, etag, err = stream.CacheFullInTempFileAndHash(file, utils.MD5)
h := md5.New()
// need to calculate md5 of the full content
tempFile, err := file.CacheFullInTempFile()
if err != nil { if err != nil {
return err return err
} }
defer func() {
_ = tempFile.Close()
}()
if _, err = utils.CopyWithBuffer(h, tempFile); err != nil {
return err
}
_, err = tempFile.Seek(0, io.SeekStart)
if err != nil {
return err
}
etag = hex.EncodeToString(h.Sum(nil))
} }
data := base.Json{ data := base.Json{
"driveId": 0, "driveId": 0,

View File

@ -4,7 +4,6 @@ import (
"context" "context"
"fmt" "fmt"
"io" "io"
"math"
"net/http" "net/http"
"strconv" "strconv"
@ -70,27 +69,33 @@ func (d *Pan123) completeS3(ctx context.Context, upReq *UploadResp, file model.F
} }
func (d *Pan123) newUpload(ctx context.Context, upReq *UploadResp, file model.FileStreamer, up driver.UpdateProgress) error { func (d *Pan123) newUpload(ctx context.Context, upReq *UploadResp, file model.FileStreamer, up driver.UpdateProgress) error {
chunkSize := int64(1024 * 1024 * 16) tmpF, err := file.CacheFullInTempFile()
if err != nil {
return err
}
// fetch s3 pre signed urls // fetch s3 pre signed urls
chunkCount := int(math.Ceil(float64(file.GetSize()) / float64(chunkSize))) size := file.GetSize()
chunkSize := min(size, 16*utils.MB)
chunkCount := int(size / chunkSize)
lastChunkSize := size % chunkSize
if lastChunkSize > 0 {
chunkCount++
} else {
lastChunkSize = chunkSize
}
// only 1 batch is allowed // only 1 batch is allowed
isMultipart := chunkCount > 1
batchSize := 1 batchSize := 1
getS3UploadUrl := d.getS3Auth getS3UploadUrl := d.getS3Auth
if isMultipart { if chunkCount > 1 {
batchSize = 10 batchSize = 10
getS3UploadUrl = d.getS3PreSignedUrls getS3UploadUrl = d.getS3PreSignedUrls
} }
limited := driver.NewLimitedUploadStream(ctx, file)
for i := 1; i <= chunkCount; i += batchSize { for i := 1; i <= chunkCount; i += batchSize {
if utils.IsCanceled(ctx) { if utils.IsCanceled(ctx) {
return ctx.Err() return ctx.Err()
} }
start := i start := i
end := i + batchSize end := min(i+batchSize, chunkCount+1)
if end > chunkCount+1 {
end = chunkCount + 1
}
s3PreSignedUrls, err := getS3UploadUrl(ctx, upReq, start, end) s3PreSignedUrls, err := getS3UploadUrl(ctx, upReq, start, end)
if err != nil { if err != nil {
return err return err
@ -102,9 +107,9 @@ func (d *Pan123) newUpload(ctx context.Context, upReq *UploadResp, file model.Fi
} }
curSize := chunkSize curSize := chunkSize
if j == chunkCount { if j == chunkCount {
curSize = file.GetSize() - (int64(chunkCount)-1)*chunkSize curSize = lastChunkSize
} }
err = d.uploadS3Chunk(ctx, upReq, s3PreSignedUrls, j, end, io.LimitReader(limited, chunkSize), curSize, false, getS3UploadUrl) err = d.uploadS3Chunk(ctx, upReq, s3PreSignedUrls, j, end, io.NewSectionReader(tmpF, chunkSize*int64(j-1), curSize), curSize, false, getS3UploadUrl)
if err != nil { if err != nil {
return err return err
} }
@ -115,12 +120,12 @@ func (d *Pan123) newUpload(ctx context.Context, upReq *UploadResp, file model.Fi
return d.completeS3(ctx, upReq, file, chunkCount > 1) return d.completeS3(ctx, upReq, file, chunkCount > 1)
} }
func (d *Pan123) uploadS3Chunk(ctx context.Context, upReq *UploadResp, s3PreSignedUrls *S3PreSignedURLs, cur, end int, reader io.Reader, curSize int64, retry bool, getS3UploadUrl func(ctx context.Context, upReq *UploadResp, start int, end int) (*S3PreSignedURLs, error)) error { func (d *Pan123) uploadS3Chunk(ctx context.Context, upReq *UploadResp, s3PreSignedUrls *S3PreSignedURLs, cur, end int, reader *io.SectionReader, curSize int64, retry bool, getS3UploadUrl func(ctx context.Context, upReq *UploadResp, start int, end int) (*S3PreSignedURLs, error)) error {
uploadUrl := s3PreSignedUrls.Data.PreSignedUrls[strconv.Itoa(cur)] uploadUrl := s3PreSignedUrls.Data.PreSignedUrls[strconv.Itoa(cur)]
if uploadUrl == "" { if uploadUrl == "" {
return fmt.Errorf("upload url is empty, s3PreSignedUrls: %+v", s3PreSignedUrls) return fmt.Errorf("upload url is empty, s3PreSignedUrls: %+v", s3PreSignedUrls)
} }
req, err := http.NewRequest("PUT", uploadUrl, reader) req, err := http.NewRequest("PUT", uploadUrl, driver.NewLimitedUploadStream(ctx, reader))
if err != nil { if err != nil {
return err return err
} }
@ -143,6 +148,7 @@ func (d *Pan123) uploadS3Chunk(ctx context.Context, upReq *UploadResp, s3PreSign
} }
s3PreSignedUrls.Data.PreSignedUrls = newS3PreSignedUrls.Data.PreSignedUrls s3PreSignedUrls.Data.PreSignedUrls = newS3PreSignedUrls.Data.PreSignedUrls
// retry // retry
reader.Seek(0, io.SeekStart)
return d.uploadS3Chunk(ctx, upReq, s3PreSignedUrls, cur, end, reader, curSize, true, getS3UploadUrl) return d.uploadS3Chunk(ctx, upReq, s3PreSignedUrls, cur, end, reader, curSize, true, getS3UploadUrl)
} }
if res.StatusCode != http.StatusOK { if res.StatusCode != http.StatusOK {

View File

@ -2,20 +2,19 @@ package _139
import ( import (
"context" "context"
"encoding/base64"
"encoding/xml" "encoding/xml"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"path" "path"
"strconv" "strconv"
"strings"
"time" "time"
"github.com/alist-org/alist/v3/drivers/base" "github.com/alist-org/alist/v3/drivers/base"
"github.com/alist-org/alist/v3/internal/driver" "github.com/alist-org/alist/v3/internal/driver"
"github.com/alist-org/alist/v3/internal/errs" "github.com/alist-org/alist/v3/internal/errs"
"github.com/alist-org/alist/v3/internal/model" "github.com/alist-org/alist/v3/internal/model"
streamPkg "github.com/alist-org/alist/v3/internal/stream"
"github.com/alist-org/alist/v3/pkg/cron" "github.com/alist-org/alist/v3/pkg/cron"
"github.com/alist-org/alist/v3/pkg/utils" "github.com/alist-org/alist/v3/pkg/utils"
"github.com/alist-org/alist/v3/pkg/utils/random" "github.com/alist-org/alist/v3/pkg/utils/random"
@ -72,28 +71,29 @@ func (d *Yun139) Init(ctx context.Context) error {
default: default:
return errs.NotImplement return errs.NotImplement
} }
if d.ref != nil { // if d.ref != nil {
return nil // return nil
} // }
decode, err := base64.StdEncoding.DecodeString(d.Authorization) // decode, err := base64.StdEncoding.DecodeString(d.Authorization)
if err != nil { // if err != nil {
return err // return err
} // }
decodeStr := string(decode) // decodeStr := string(decode)
splits := strings.Split(decodeStr, ":") // splits := strings.Split(decodeStr, ":")
if len(splits) < 2 { // if len(splits) < 2 {
return fmt.Errorf("authorization is invalid, splits < 2") // return fmt.Errorf("authorization is invalid, splits < 2")
} // }
d.Account = splits[1] // d.Account = splits[1]
_, err = d.post("/orchestration/personalCloud/user/v1.0/qryUserExternInfo", base.Json{ // _, err = d.post("/orchestration/personalCloud/user/v1.0/qryUserExternInfo", base.Json{
"qryUserExternInfoReq": base.Json{ // "qryUserExternInfoReq": base.Json{
"commonAccountInfo": base.Json{ // "commonAccountInfo": base.Json{
"account": d.getAccount(), // "account": d.getAccount(),
"accountType": 1, // "accountType": 1,
}, // },
}, // },
}, nil) // }, nil)
return err // return err
return nil
} }
func (d *Yun139) InitReference(storage driver.Driver) error { func (d *Yun139) InitReference(storage driver.Driver) error {
@ -503,23 +503,15 @@ func (d *Yun139) Remove(ctx context.Context, obj model.Obj) error {
} }
} }
const (
_ = iota //ignore first value by assigning to blank identifier
KB = 1 << (10 * iota)
MB
GB
TB
)
func (d *Yun139) getPartSize(size int64) int64 { func (d *Yun139) getPartSize(size int64) int64 {
if d.CustomUploadPartSize != 0 { if d.CustomUploadPartSize != 0 {
return d.CustomUploadPartSize return d.CustomUploadPartSize
} }
// 网盘对于分片数量存在上限 // 网盘对于分片数量存在上限
if size/GB > 30 { if size/utils.GB > 30 {
return 512 * MB return 512 * utils.MB
} }
return 100 * MB return 100 * utils.MB
} }
func (d *Yun139) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error { func (d *Yun139) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error {
@ -527,29 +519,28 @@ func (d *Yun139) Put(ctx context.Context, dstDir model.Obj, stream model.FileStr
case MetaPersonalNew: case MetaPersonalNew:
var err error var err error
fullHash := stream.GetHash().GetHash(utils.SHA256) fullHash := stream.GetHash().GetHash(utils.SHA256)
if len(fullHash) <= 0 { if len(fullHash) != utils.SHA256.Width {
tmpF, err := stream.CacheFullInTempFile() _, fullHash, err = streamPkg.CacheFullInTempFileAndHash(stream, utils.SHA256)
if err != nil {
return err
}
fullHash, err = utils.HashFile(utils.SHA256, tmpF)
if err != nil { if err != nil {
return err return err
} }
} }
partInfos := []PartInfo{} size := stream.GetSize()
var partSize = d.getPartSize(stream.GetSize()) var partSize = d.getPartSize(size)
part := (stream.GetSize() + partSize - 1) / partSize part := size / partSize
if part == 0 { if size%partSize > 0 {
part++
} else if part == 0 {
part = 1 part = 1
} }
partInfos := make([]PartInfo, 0, part)
for i := int64(0); i < part; i++ { for i := int64(0); i < part; i++ {
if utils.IsCanceled(ctx) { if utils.IsCanceled(ctx) {
return ctx.Err() return ctx.Err()
} }
start := i * partSize start := i * partSize
byteSize := stream.GetSize() - start byteSize := size - start
if byteSize > partSize { if byteSize > partSize {
byteSize = partSize byteSize = partSize
} }
@ -577,7 +568,7 @@ func (d *Yun139) Put(ctx context.Context, dstDir model.Obj, stream model.FileStr
"contentType": "application/octet-stream", "contentType": "application/octet-stream",
"parallelUpload": false, "parallelUpload": false,
"partInfos": firstPartInfos, "partInfos": firstPartInfos,
"size": stream.GetSize(), "size": size,
"parentFileId": dstDir.GetID(), "parentFileId": dstDir.GetID(),
"name": stream.GetName(), "name": stream.GetName(),
"type": "file", "type": "file",
@ -630,7 +621,7 @@ func (d *Yun139) Put(ctx context.Context, dstDir model.Obj, stream model.FileStr
} }
// Progress // Progress
p := driver.NewProgress(stream.GetSize(), up) p := driver.NewProgress(size, up)
rateLimited := driver.NewLimitedUploadStream(ctx, stream) rateLimited := driver.NewLimitedUploadStream(ctx, stream)
// 上传所有分片 // 上传所有分片
@ -790,12 +781,14 @@ func (d *Yun139) Put(ctx context.Context, dstDir model.Obj, stream model.FileStr
return fmt.Errorf("get file upload url failed with result code: %s, message: %s", resp.Data.Result.ResultCode, resp.Data.Result.ResultDesc) return fmt.Errorf("get file upload url failed with result code: %s, message: %s", resp.Data.Result.ResultCode, resp.Data.Result.ResultDesc)
} }
size := stream.GetSize()
// Progress // Progress
p := driver.NewProgress(stream.GetSize(), up) p := driver.NewProgress(size, up)
var partSize = d.getPartSize(size)
var partSize = d.getPartSize(stream.GetSize()) part := size / partSize
part := (stream.GetSize() + partSize - 1) / partSize if size%partSize > 0 {
if part == 0 { part++
} else if part == 0 {
part = 1 part = 1
} }
rateLimited := driver.NewLimitedUploadStream(ctx, stream) rateLimited := driver.NewLimitedUploadStream(ctx, stream)
@ -805,10 +798,7 @@ func (d *Yun139) Put(ctx context.Context, dstDir model.Obj, stream model.FileStr
} }
start := i * partSize start := i * partSize
byteSize := stream.GetSize() - start byteSize := min(size-start, partSize)
if byteSize > partSize {
byteSize = partSize
}
limitReader := io.LimitReader(rateLimited, byteSize) limitReader := io.LimitReader(rateLimited, byteSize)
// Update Progress // Update Progress
@ -820,7 +810,7 @@ func (d *Yun139) Put(ctx context.Context, dstDir model.Obj, stream model.FileStr
req = req.WithContext(ctx) req = req.WithContext(ctx)
req.Header.Set("Content-Type", "text/plain;name="+unicode(stream.GetName())) req.Header.Set("Content-Type", "text/plain;name="+unicode(stream.GetName()))
req.Header.Set("contentSize", strconv.FormatInt(stream.GetSize(), 10)) req.Header.Set("contentSize", strconv.FormatInt(size, 10))
req.Header.Set("range", fmt.Sprintf("bytes=%d-%d", start, start+byteSize-1)) req.Header.Set("range", fmt.Sprintf("bytes=%d-%d", start, start+byteSize-1))
req.Header.Set("uploadtaskID", resp.Data.UploadResult.UploadTaskID) req.Header.Set("uploadtaskID", resp.Data.UploadResult.UploadTaskID)
req.Header.Set("rangeType", "0") req.Header.Set("rangeType", "0")

View File

@ -67,6 +67,7 @@ func (d *Yun139) refreshToken() error {
if len(splits) < 3 { if len(splits) < 3 {
return fmt.Errorf("authorization is invalid, splits < 3") return fmt.Errorf("authorization is invalid, splits < 3")
} }
d.Account = splits[1]
strs := strings.Split(splits[2], "|") strs := strings.Split(splits[2], "|")
if len(strs) < 4 { if len(strs) < 4 {
return fmt.Errorf("authorization is invalid, strs < 4") return fmt.Errorf("authorization is invalid, strs < 4")

View File

@ -3,16 +3,15 @@ package _189pc
import ( import (
"bytes" "bytes"
"context" "context"
"crypto/md5"
"encoding/base64" "encoding/base64"
"encoding/hex" "encoding/hex"
"encoding/xml" "encoding/xml"
"fmt" "fmt"
"io" "io"
"math"
"net/http" "net/http"
"net/http/cookiejar" "net/http/cookiejar"
"net/url" "net/url"
"os"
"regexp" "regexp"
"sort" "sort"
"strconv" "strconv"
@ -28,6 +27,7 @@ import (
"github.com/alist-org/alist/v3/internal/model" "github.com/alist-org/alist/v3/internal/model"
"github.com/alist-org/alist/v3/internal/op" "github.com/alist-org/alist/v3/internal/op"
"github.com/alist-org/alist/v3/internal/setting" "github.com/alist-org/alist/v3/internal/setting"
"github.com/alist-org/alist/v3/internal/stream"
"github.com/alist-org/alist/v3/pkg/errgroup" "github.com/alist-org/alist/v3/pkg/errgroup"
"github.com/alist-org/alist/v3/pkg/utils" "github.com/alist-org/alist/v3/pkg/utils"
@ -473,12 +473,8 @@ func (y *Cloud189PC) refreshSession() (err error) {
// 普通上传 // 普通上传
// 无法上传大小为0的文件 // 无法上传大小为0的文件
func (y *Cloud189PC) StreamUpload(ctx context.Context, dstDir model.Obj, file model.FileStreamer, up driver.UpdateProgress, isFamily bool, overwrite bool) (model.Obj, error) { func (y *Cloud189PC) StreamUpload(ctx context.Context, dstDir model.Obj, file model.FileStreamer, up driver.UpdateProgress, isFamily bool, overwrite bool) (model.Obj, error) {
var sliceSize = partSize(file.GetSize()) size := file.GetSize()
count := int(math.Ceil(float64(file.GetSize()) / float64(sliceSize))) sliceSize := partSize(size)
lastPartSize := file.GetSize() % sliceSize
if file.GetSize() > 0 && lastPartSize == 0 {
lastPartSize = sliceSize
}
params := Params{ params := Params{
"parentFolderId": dstDir.GetID(), "parentFolderId": dstDir.GetID(),
@ -512,22 +508,29 @@ func (y *Cloud189PC) StreamUpload(ctx context.Context, dstDir model.Obj, file mo
retry.DelayType(retry.BackOffDelay)) retry.DelayType(retry.BackOffDelay))
sem := semaphore.NewWeighted(3) sem := semaphore.NewWeighted(3)
fileMd5 := md5.New() count := int(size / sliceSize)
silceMd5 := md5.New() lastPartSize := size % sliceSize
if lastPartSize > 0 {
count++
} else {
lastPartSize = sliceSize
}
fileMd5 := utils.MD5.NewFunc()
silceMd5 := utils.MD5.NewFunc()
silceMd5Hexs := make([]string, 0, count) silceMd5Hexs := make([]string, 0, count)
teeReader := io.TeeReader(file, io.MultiWriter(fileMd5, silceMd5))
byteSize := sliceSize
for i := 1; i <= count; i++ { for i := 1; i <= count; i++ {
if utils.IsCanceled(upCtx) { if utils.IsCanceled(upCtx) {
break break
} }
byteData := make([]byte, sliceSize)
if i == count { if i == count {
byteData = byteData[:lastPartSize] byteSize = lastPartSize
} }
byteData := make([]byte, byteSize)
// 读取块 // 读取块
silceMd5.Reset() silceMd5.Reset()
if _, err := io.ReadFull(io.TeeReader(file, io.MultiWriter(fileMd5, silceMd5)), byteData); err != io.EOF && err != nil { if _, err := io.ReadFull(teeReader, byteData); err != io.EOF && err != nil {
sem.Release(1) sem.Release(1)
return nil, err return nil, err
} }
@ -607,24 +610,43 @@ func (y *Cloud189PC) RapidUpload(ctx context.Context, dstDir model.Obj, stream m
// 快传 // 快传
func (y *Cloud189PC) FastUpload(ctx context.Context, dstDir model.Obj, file model.FileStreamer, up driver.UpdateProgress, isFamily bool, overwrite bool) (model.Obj, error) { func (y *Cloud189PC) FastUpload(ctx context.Context, dstDir model.Obj, file model.FileStreamer, up driver.UpdateProgress, isFamily bool, overwrite bool) (model.Obj, error) {
tempFile, err := file.CacheFullInTempFile() var (
if err != nil { cache = file.GetFile()
return nil, err tmpF *os.File
err error
)
size := file.GetSize()
if _, ok := cache.(io.ReaderAt); !ok && size > 0 {
tmpF, err = os.CreateTemp(conf.Conf.TempDir, "file-*")
if err != nil {
return nil, err
}
defer func() {
_ = tmpF.Close()
_ = os.Remove(tmpF.Name())
}()
cache = tmpF
} }
sliceSize := partSize(size)
var sliceSize = partSize(file.GetSize()) count := int(size / sliceSize)
count := int(math.Ceil(float64(file.GetSize()) / float64(sliceSize))) lastSliceSize := size % sliceSize
lastSliceSize := file.GetSize() % sliceSize if lastSliceSize > 0 {
if file.GetSize() > 0 && lastSliceSize == 0 { count++
} else {
lastSliceSize = sliceSize lastSliceSize = sliceSize
} }
//step.1 优先计算所需信息 //step.1 优先计算所需信息
byteSize := sliceSize byteSize := sliceSize
fileMd5 := md5.New() fileMd5 := utils.MD5.NewFunc()
silceMd5 := md5.New() sliceMd5 := utils.MD5.NewFunc()
silceMd5Hexs := make([]string, 0, count) sliceMd5Hexs := make([]string, 0, count)
partInfos := make([]string, 0, count) partInfos := make([]string, 0, count)
writers := []io.Writer{fileMd5, sliceMd5}
if tmpF != nil {
writers = append(writers, tmpF)
}
written := int64(0)
for i := 1; i <= count; i++ { for i := 1; i <= count; i++ {
if utils.IsCanceled(ctx) { if utils.IsCanceled(ctx) {
return nil, ctx.Err() return nil, ctx.Err()
@ -634,19 +656,31 @@ func (y *Cloud189PC) FastUpload(ctx context.Context, dstDir model.Obj, file mode
byteSize = lastSliceSize byteSize = lastSliceSize
} }
silceMd5.Reset() n, err := utils.CopyWithBufferN(io.MultiWriter(writers...), file, byteSize)
if _, err := utils.CopyWithBufferN(io.MultiWriter(fileMd5, silceMd5), tempFile, byteSize); err != nil && err != io.EOF { written += n
if err != nil && err != io.EOF {
return nil, err return nil, err
} }
md5Byte := silceMd5.Sum(nil) md5Byte := sliceMd5.Sum(nil)
silceMd5Hexs = append(silceMd5Hexs, strings.ToUpper(hex.EncodeToString(md5Byte))) sliceMd5Hexs = append(sliceMd5Hexs, strings.ToUpper(hex.EncodeToString(md5Byte)))
partInfos = append(partInfos, fmt.Sprint(i, "-", base64.StdEncoding.EncodeToString(md5Byte))) partInfos = append(partInfos, fmt.Sprint(i, "-", base64.StdEncoding.EncodeToString(md5Byte)))
sliceMd5.Reset()
}
if tmpF != nil {
if size > 0 && written != size {
return nil, errs.NewErr(err, "CreateTempFile failed, incoming stream actual size= %d, expect = %d ", written, size)
}
_, err = tmpF.Seek(0, io.SeekStart)
if err != nil {
return nil, errs.NewErr(err, "CreateTempFile failed, can't seek to 0 ")
}
} }
fileMd5Hex := strings.ToUpper(hex.EncodeToString(fileMd5.Sum(nil))) fileMd5Hex := strings.ToUpper(hex.EncodeToString(fileMd5.Sum(nil)))
sliceMd5Hex := fileMd5Hex sliceMd5Hex := fileMd5Hex
if file.GetSize() > sliceSize { if size > sliceSize {
sliceMd5Hex = strings.ToUpper(utils.GetMD5EncodeStr(strings.Join(silceMd5Hexs, "\n"))) sliceMd5Hex = strings.ToUpper(utils.GetMD5EncodeStr(strings.Join(sliceMd5Hexs, "\n")))
} }
fullUrl := UPLOAD_URL fullUrl := UPLOAD_URL
@ -712,7 +746,7 @@ func (y *Cloud189PC) FastUpload(ctx context.Context, dstDir model.Obj, file mode
} }
// step.4 上传切片 // step.4 上传切片
_, err = y.put(ctx, uploadUrl.RequestURL, uploadUrl.Headers, false, io.NewSectionReader(tempFile, offset, byteSize), isFamily) _, err = y.put(ctx, uploadUrl.RequestURL, uploadUrl.Headers, false, io.NewSectionReader(cache, offset, byteSize), isFamily)
if err != nil { if err != nil {
return err return err
} }
@ -794,11 +828,7 @@ func (y *Cloud189PC) GetMultiUploadUrls(ctx context.Context, isFamily bool, uplo
// 旧版本上传,家庭云不支持覆盖 // 旧版本上传,家庭云不支持覆盖
func (y *Cloud189PC) OldUpload(ctx context.Context, dstDir model.Obj, file model.FileStreamer, up driver.UpdateProgress, isFamily bool, overwrite bool) (model.Obj, error) { func (y *Cloud189PC) OldUpload(ctx context.Context, dstDir model.Obj, file model.FileStreamer, up driver.UpdateProgress, isFamily bool, overwrite bool) (model.Obj, error) {
tempFile, err := file.CacheFullInTempFile() tempFile, fileMd5, err := stream.CacheFullInTempFileAndHash(file, utils.MD5)
if err != nil {
return nil, err
}
fileMd5, err := utils.HashFile(utils.MD5, tempFile)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -1,7 +1,6 @@
package aliyundrive_open package aliyundrive_open
import ( import (
"bytes"
"context" "context"
"encoding/base64" "encoding/base64"
"fmt" "fmt"
@ -15,6 +14,7 @@ import (
"github.com/alist-org/alist/v3/drivers/base" "github.com/alist-org/alist/v3/drivers/base"
"github.com/alist-org/alist/v3/internal/driver" "github.com/alist-org/alist/v3/internal/driver"
"github.com/alist-org/alist/v3/internal/model" "github.com/alist-org/alist/v3/internal/model"
streamPkg "github.com/alist-org/alist/v3/internal/stream"
"github.com/alist-org/alist/v3/pkg/http_range" "github.com/alist-org/alist/v3/pkg/http_range"
"github.com/alist-org/alist/v3/pkg/utils" "github.com/alist-org/alist/v3/pkg/utils"
"github.com/avast/retry-go" "github.com/avast/retry-go"
@ -131,16 +131,19 @@ func (d *AliyundriveOpen) calProofCode(stream model.FileStreamer) (string, error
return "", err return "", err
} }
length := proofRange.End - proofRange.Start length := proofRange.End - proofRange.Start
buf := bytes.NewBuffer(make([]byte, 0, length))
reader, err := stream.RangeRead(http_range.Range{Start: proofRange.Start, Length: length}) reader, err := stream.RangeRead(http_range.Range{Start: proofRange.Start, Length: length})
if err != nil { if err != nil {
return "", err return "", err
} }
_, err = utils.CopyWithBufferN(buf, reader, length) buf := make([]byte, length)
n, err := io.ReadFull(reader, buf)
if err == io.ErrUnexpectedEOF {
return "", fmt.Errorf("can't read data, expected=%d, got=%d", len(buf), n)
}
if err != nil { if err != nil {
return "", err return "", err
} }
return base64.StdEncoding.EncodeToString(buf.Bytes()), nil return base64.StdEncoding.EncodeToString(buf), nil
} }
func (d *AliyundriveOpen) upload(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) (model.Obj, error) { func (d *AliyundriveOpen) upload(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) (model.Obj, error) {
@ -183,25 +186,18 @@ func (d *AliyundriveOpen) upload(ctx context.Context, dstDir model.Obj, stream m
_, err, e := d.requestReturnErrResp("/adrive/v1.0/openFile/create", http.MethodPost, func(req *resty.Request) { _, err, e := d.requestReturnErrResp("/adrive/v1.0/openFile/create", http.MethodPost, func(req *resty.Request) {
req.SetBody(createData).SetResult(&createResp) req.SetBody(createData).SetResult(&createResp)
}) })
var tmpF model.File
if err != nil { if err != nil {
if e.Code != "PreHashMatched" || !rapidUpload { if e.Code != "PreHashMatched" || !rapidUpload {
return nil, err return nil, err
} }
log.Debugf("[aliyundrive_open] pre_hash matched, start rapid upload") log.Debugf("[aliyundrive_open] pre_hash matched, start rapid upload")
hi := stream.GetHash() hash := stream.GetHash().GetHash(utils.SHA1)
hash := hi.GetHash(utils.SHA1) if len(hash) != utils.SHA1.Width {
if len(hash) <= 0 { _, hash, err = streamPkg.CacheFullInTempFileAndHash(stream, utils.SHA1)
tmpF, err = stream.CacheFullInTempFile()
if err != nil { if err != nil {
return nil, err return nil, err
} }
hash, err = utils.HashFile(utils.SHA1, tmpF)
if err != nil {
return nil, err
}
} }
delete(createData, "pre_hash") delete(createData, "pre_hash")

View File

@ -6,8 +6,8 @@ import (
"encoding/hex" "encoding/hex"
"errors" "errors"
"io" "io"
"math"
"net/url" "net/url"
"os"
stdpath "path" stdpath "path"
"strconv" "strconv"
"time" "time"
@ -15,6 +15,7 @@ import (
"golang.org/x/sync/semaphore" "golang.org/x/sync/semaphore"
"github.com/alist-org/alist/v3/drivers/base" "github.com/alist-org/alist/v3/drivers/base"
"github.com/alist-org/alist/v3/internal/conf"
"github.com/alist-org/alist/v3/internal/driver" "github.com/alist-org/alist/v3/internal/driver"
"github.com/alist-org/alist/v3/internal/errs" "github.com/alist-org/alist/v3/internal/errs"
"github.com/alist-org/alist/v3/internal/model" "github.com/alist-org/alist/v3/internal/model"
@ -185,16 +186,30 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F
return newObj, nil return newObj, nil
} }
tempFile, err := stream.CacheFullInTempFile() var (
if err != nil { cache = stream.GetFile()
return nil, err tmpF *os.File
err error
)
if _, ok := cache.(io.ReaderAt); !ok {
tmpF, err = os.CreateTemp(conf.Conf.TempDir, "file-*")
if err != nil {
return nil, err
}
defer func() {
_ = tmpF.Close()
_ = os.Remove(tmpF.Name())
}()
cache = tmpF
} }
streamSize := stream.GetSize() streamSize := stream.GetSize()
sliceSize := d.getSliceSize(streamSize) sliceSize := d.getSliceSize(streamSize)
count := int(math.Max(math.Ceil(float64(streamSize)/float64(sliceSize)), 1)) count := int(streamSize / sliceSize)
lastBlockSize := streamSize % sliceSize lastBlockSize := streamSize % sliceSize
if streamSize > 0 && lastBlockSize == 0 { if lastBlockSize > 0 {
count++
} else {
lastBlockSize = sliceSize lastBlockSize = sliceSize
} }
@ -207,6 +222,11 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F
sliceMd5H := md5.New() sliceMd5H := md5.New()
sliceMd5H2 := md5.New() sliceMd5H2 := md5.New()
slicemd5H2Write := utils.LimitWriter(sliceMd5H2, SliceSize) slicemd5H2Write := utils.LimitWriter(sliceMd5H2, SliceSize)
writers := []io.Writer{fileMd5H, sliceMd5H, slicemd5H2Write}
if tmpF != nil {
writers = append(writers, tmpF)
}
written := int64(0)
for i := 1; i <= count; i++ { for i := 1; i <= count; i++ {
if utils.IsCanceled(ctx) { if utils.IsCanceled(ctx) {
@ -215,13 +235,23 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F
if i == count { if i == count {
byteSize = lastBlockSize byteSize = lastBlockSize
} }
_, err := utils.CopyWithBufferN(io.MultiWriter(fileMd5H, sliceMd5H, slicemd5H2Write), tempFile, byteSize) n, err := utils.CopyWithBufferN(io.MultiWriter(writers...), stream, byteSize)
written += n
if err != nil && err != io.EOF { if err != nil && err != io.EOF {
return nil, err return nil, err
} }
blockList = append(blockList, hex.EncodeToString(sliceMd5H.Sum(nil))) blockList = append(blockList, hex.EncodeToString(sliceMd5H.Sum(nil)))
sliceMd5H.Reset() sliceMd5H.Reset()
} }
if tmpF != nil {
if written != streamSize {
return nil, errs.NewErr(err, "CreateTempFile failed, incoming stream actual size= %d, expect = %d ", written, streamSize)
}
_, err = tmpF.Seek(0, io.SeekStart)
if err != nil {
return nil, errs.NewErr(err, "CreateTempFile failed, can't seek to 0 ")
}
}
contentMd5 := hex.EncodeToString(fileMd5H.Sum(nil)) contentMd5 := hex.EncodeToString(fileMd5H.Sum(nil))
sliceMd5 := hex.EncodeToString(sliceMd5H2.Sum(nil)) sliceMd5 := hex.EncodeToString(sliceMd5H2.Sum(nil))
blockListStr, _ := utils.Json.MarshalToString(blockList) blockListStr, _ := utils.Json.MarshalToString(blockList)
@ -291,7 +321,7 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F
"partseq": strconv.Itoa(partseq), "partseq": strconv.Itoa(partseq),
} }
err := d.uploadSlice(ctx, params, stream.GetName(), err := d.uploadSlice(ctx, params, stream.GetName(),
driver.NewLimitedUploadStream(ctx, io.NewSectionReader(tempFile, offset, byteSize))) driver.NewLimitedUploadStream(ctx, io.NewSectionReader(cache, offset, byteSize)))
if err != nil { if err != nil {
return err return err
} }

View File

@ -7,7 +7,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"math" "os"
"regexp" "regexp"
"strconv" "strconv"
"strings" "strings"
@ -16,6 +16,7 @@ import (
"golang.org/x/sync/semaphore" "golang.org/x/sync/semaphore"
"github.com/alist-org/alist/v3/drivers/base" "github.com/alist-org/alist/v3/drivers/base"
"github.com/alist-org/alist/v3/internal/conf"
"github.com/alist-org/alist/v3/internal/driver" "github.com/alist-org/alist/v3/internal/driver"
"github.com/alist-org/alist/v3/internal/errs" "github.com/alist-org/alist/v3/internal/errs"
"github.com/alist-org/alist/v3/internal/model" "github.com/alist-org/alist/v3/internal/model"
@ -241,11 +242,21 @@ func (d *BaiduPhoto) Put(ctx context.Context, dstDir model.Obj, stream model.Fil
// TODO: // TODO:
// 暂时没有找到妙传方式 // 暂时没有找到妙传方式
var (
// 需要获取完整文件md5,必须支持 io.Seek cache = stream.GetFile()
tempFile, err := stream.CacheFullInTempFile() tmpF *os.File
if err != nil { err error
return nil, err )
if _, ok := cache.(io.ReaderAt); !ok {
tmpF, err = os.CreateTemp(conf.Conf.TempDir, "file-*")
if err != nil {
return nil, err
}
defer func() {
_ = tmpF.Close()
_ = os.Remove(tmpF.Name())
}()
cache = tmpF
} }
const DEFAULT int64 = 1 << 22 const DEFAULT int64 = 1 << 22
@ -253,9 +264,11 @@ func (d *BaiduPhoto) Put(ctx context.Context, dstDir model.Obj, stream model.Fil
// 计算需要的数据 // 计算需要的数据
streamSize := stream.GetSize() streamSize := stream.GetSize()
count := int(math.Ceil(float64(streamSize) / float64(DEFAULT))) count := int(streamSize / DEFAULT)
lastBlockSize := streamSize % DEFAULT lastBlockSize := streamSize % DEFAULT
if lastBlockSize == 0 { if lastBlockSize > 0 {
count++
} else {
lastBlockSize = DEFAULT lastBlockSize = DEFAULT
} }
@ -266,6 +279,11 @@ func (d *BaiduPhoto) Put(ctx context.Context, dstDir model.Obj, stream model.Fil
sliceMd5H := md5.New() sliceMd5H := md5.New()
sliceMd5H2 := md5.New() sliceMd5H2 := md5.New()
slicemd5H2Write := utils.LimitWriter(sliceMd5H2, SliceSize) slicemd5H2Write := utils.LimitWriter(sliceMd5H2, SliceSize)
writers := []io.Writer{fileMd5H, sliceMd5H, slicemd5H2Write}
if tmpF != nil {
writers = append(writers, tmpF)
}
written := int64(0)
for i := 1; i <= count; i++ { for i := 1; i <= count; i++ {
if utils.IsCanceled(ctx) { if utils.IsCanceled(ctx) {
return nil, ctx.Err() return nil, ctx.Err()
@ -273,13 +291,23 @@ func (d *BaiduPhoto) Put(ctx context.Context, dstDir model.Obj, stream model.Fil
if i == count { if i == count {
byteSize = lastBlockSize byteSize = lastBlockSize
} }
_, err := utils.CopyWithBufferN(io.MultiWriter(fileMd5H, sliceMd5H, slicemd5H2Write), tempFile, byteSize) n, err := utils.CopyWithBufferN(io.MultiWriter(writers...), stream, byteSize)
written += n
if err != nil && err != io.EOF { if err != nil && err != io.EOF {
return nil, err return nil, err
} }
sliceMD5List = append(sliceMD5List, hex.EncodeToString(sliceMd5H.Sum(nil))) sliceMD5List = append(sliceMD5List, hex.EncodeToString(sliceMd5H.Sum(nil)))
sliceMd5H.Reset() sliceMd5H.Reset()
} }
if tmpF != nil {
if written != streamSize {
return nil, errs.NewErr(err, "CreateTempFile failed, incoming stream actual size= %d, expect = %d ", written, streamSize)
}
_, err = tmpF.Seek(0, io.SeekStart)
if err != nil {
return nil, errs.NewErr(err, "CreateTempFile failed, can't seek to 0 ")
}
}
contentMd5 := hex.EncodeToString(fileMd5H.Sum(nil)) contentMd5 := hex.EncodeToString(fileMd5H.Sum(nil))
sliceMd5 := hex.EncodeToString(sliceMd5H2.Sum(nil)) sliceMd5 := hex.EncodeToString(sliceMd5H2.Sum(nil))
blockListStr, _ := utils.Json.MarshalToString(sliceMD5List) blockListStr, _ := utils.Json.MarshalToString(sliceMD5List)
@ -291,7 +319,7 @@ func (d *BaiduPhoto) Put(ctx context.Context, dstDir model.Obj, stream model.Fil
"rtype": "1", "rtype": "1",
"ctype": "11", "ctype": "11",
"path": fmt.Sprintf("/%s", stream.GetName()), "path": fmt.Sprintf("/%s", stream.GetName()),
"size": fmt.Sprint(stream.GetSize()), "size": fmt.Sprint(streamSize),
"slice-md5": sliceMd5, "slice-md5": sliceMd5,
"content-md5": contentMd5, "content-md5": contentMd5,
"block_list": blockListStr, "block_list": blockListStr,
@ -343,7 +371,7 @@ func (d *BaiduPhoto) Put(ctx context.Context, dstDir model.Obj, stream model.Fil
r.SetContext(ctx) r.SetContext(ctx)
r.SetQueryParams(uploadParams) r.SetQueryParams(uploadParams)
r.SetFileReader("file", stream.GetName(), r.SetFileReader("file", stream.GetName(),
driver.NewLimitedUploadStream(ctx, io.NewSectionReader(tempFile, offset, byteSize))) driver.NewLimitedUploadStream(ctx, io.NewSectionReader(cache, offset, byteSize)))
}, nil) }, nil)
if err != nil { if err != nil {
return err return err

View File

@ -204,7 +204,7 @@ func (d *Cloudreve) upLocal(ctx context.Context, stream model.FileStreamer, u Up
req.SetContentLength(true) req.SetContentLength(true)
req.SetHeader("Content-Length", strconv.FormatInt(byteSize, 10)) req.SetHeader("Content-Length", strconv.FormatInt(byteSize, 10))
req.SetHeader("User-Agent", d.getUA()) req.SetHeader("User-Agent", d.getUA())
req.SetBody(driver.NewLimitedUploadStream(ctx, bytes.NewBuffer(byteData))) req.SetBody(driver.NewLimitedUploadStream(ctx, bytes.NewReader(byteData)))
}, nil) }, nil)
if err != nil { if err != nil {
break break
@ -239,7 +239,7 @@ func (d *Cloudreve) upRemote(ctx context.Context, stream model.FileStreamer, u U
return err return err
} }
req, err := http.NewRequest("POST", uploadUrl+"?chunk="+strconv.Itoa(chunk), req, err := http.NewRequest("POST", uploadUrl+"?chunk="+strconv.Itoa(chunk),
driver.NewLimitedUploadStream(ctx, bytes.NewBuffer(byteData))) driver.NewLimitedUploadStream(ctx, bytes.NewReader(byteData)))
if err != nil { if err != nil {
return err return err
} }
@ -280,7 +280,7 @@ func (d *Cloudreve) upOneDrive(ctx context.Context, stream model.FileStreamer, u
if err != nil { if err != nil {
return err return err
} }
req, err := http.NewRequest("PUT", uploadUrl, driver.NewLimitedUploadStream(ctx, bytes.NewBuffer(byteData))) req, err := http.NewRequest("PUT", uploadUrl, driver.NewLimitedUploadStream(ctx, bytes.NewReader(byteData)))
if err != nil { if err != nil {
return err return err
} }

View File

@ -5,7 +5,6 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"io"
"strings" "strings"
"text/template" "text/template"
"time" "time"
@ -159,7 +158,7 @@ func signCommit(m *map[string]interface{}, entity *openpgp.Entity) (string, erro
if err != nil { if err != nil {
return "", err return "", err
} }
if _, err = io.Copy(armorWriter, &sigBuffer); err != nil { if _, err = utils.CopyWithBuffer(armorWriter, &sigBuffer); err != nil {
return "", err return "", err
} }
_ = armorWriter.Close() _ = armorWriter.Close()

View File

@ -2,7 +2,6 @@ package template
import ( import (
"context" "context"
"crypto/md5"
"encoding/base64" "encoding/base64"
"encoding/hex" "encoding/hex"
"fmt" "fmt"
@ -17,6 +16,7 @@ import (
"github.com/alist-org/alist/v3/internal/driver" "github.com/alist-org/alist/v3/internal/driver"
"github.com/alist-org/alist/v3/internal/errs" "github.com/alist-org/alist/v3/internal/errs"
"github.com/alist-org/alist/v3/internal/model" "github.com/alist-org/alist/v3/internal/model"
"github.com/alist-org/alist/v3/internal/stream"
"github.com/alist-org/alist/v3/pkg/utils" "github.com/alist-org/alist/v3/pkg/utils"
"github.com/foxxorcat/mopan-sdk-go" "github.com/foxxorcat/mopan-sdk-go"
"github.com/go-resty/resty/v2" "github.com/go-resty/resty/v2"
@ -273,23 +273,14 @@ func (d *ILanZou) Remove(ctx context.Context, obj model.Obj) error {
const DefaultPartSize = 1024 * 1024 * 8 const DefaultPartSize = 1024 * 1024 * 8
func (d *ILanZou) Put(ctx context.Context, dstDir model.Obj, s model.FileStreamer, up driver.UpdateProgress) (model.Obj, error) { func (d *ILanZou) Put(ctx context.Context, dstDir model.Obj, s model.FileStreamer, up driver.UpdateProgress) (model.Obj, error) {
h := md5.New() etag := s.GetHash().GetHash(utils.MD5)
// need to calculate md5 of the full content var err error
tempFile, err := s.CacheFullInTempFile() if len(etag) != utils.MD5.Width {
if err != nil { _, etag, err = stream.CacheFullInTempFileAndHash(s, utils.MD5)
return nil, err if err != nil {
return nil, err
}
} }
defer func() {
_ = tempFile.Close()
}()
if _, err = utils.CopyWithBuffer(h, tempFile); err != nil {
return nil, err
}
_, err = tempFile.Seek(0, io.SeekStart)
if err != nil {
return nil, err
}
etag := hex.EncodeToString(h.Sum(nil))
// get upToken // get upToken
res, err := d.proved("/7n/getUpToken", http.MethodPost, func(req *resty.Request) { res, err := d.proved("/7n/getUpToken", http.MethodPost, func(req *resty.Request) {
req.SetBody(base.Json{ req.SetBody(base.Json{
@ -309,7 +300,7 @@ func (d *ILanZou) Put(ctx context.Context, dstDir model.Obj, s model.FileStreame
key := fmt.Sprintf("disk/%d/%d/%d/%s/%016d", now.Year(), now.Month(), now.Day(), d.account, now.UnixMilli()) key := fmt.Sprintf("disk/%d/%d/%d/%s/%016d", now.Year(), now.Month(), now.Day(), d.account, now.UnixMilli())
reader := driver.NewLimitedUploadStream(ctx, &driver.ReaderUpdatingProgress{ reader := driver.NewLimitedUploadStream(ctx, &driver.ReaderUpdatingProgress{
Reader: &driver.SimpleReaderWithSize{ Reader: &driver.SimpleReaderWithSize{
Reader: tempFile, Reader: s,
Size: s.GetSize(), Size: s.GetSize(),
}, },
UpdateProgress: up, UpdateProgress: up,

View File

@ -269,9 +269,6 @@ func (d *MoPan) Put(ctx context.Context, dstDir model.Obj, stream model.FileStre
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer func() {
_ = file.Close()
}()
// step.1 // step.1
uploadPartData, err := mopan.InitUploadPartData(ctx, mopan.UpdloadFileParam{ uploadPartData, err := mopan.InitUploadPartData(ctx, mopan.UpdloadFileParam{

View File

@ -227,7 +227,6 @@ func (d *NeteaseMusic) putSongStream(ctx context.Context, stream model.FileStrea
if err != nil { if err != nil {
return err return err
} }
defer tmp.Close()
u := uploader{driver: d, file: tmp} u := uploader{driver: d, file: tmp}

View File

@ -220,7 +220,7 @@ func (d *Onedrive) upBig(ctx context.Context, dstDir model.Obj, stream model.Fil
if err != nil { if err != nil {
return err return err
} }
req, err := http.NewRequest("PUT", uploadUrl, driver.NewLimitedUploadStream(ctx, bytes.NewBuffer(byteData))) req, err := http.NewRequest("PUT", uploadUrl, driver.NewLimitedUploadStream(ctx, bytes.NewReader(byteData)))
if err != nil { if err != nil {
return err return err
} }

View File

@ -170,7 +170,7 @@ func (d *OnedriveAPP) upBig(ctx context.Context, dstDir model.Obj, stream model.
if err != nil { if err != nil {
return err return err
} }
req, err := http.NewRequest("PUT", uploadUrl, driver.NewLimitedUploadStream(ctx, bytes.NewBuffer(byteData))) req, err := http.NewRequest("PUT", uploadUrl, driver.NewLimitedUploadStream(ctx, bytes.NewReader(byteData)))
if err != nil { if err != nil {
return err return err
} }

View File

@ -7,13 +7,6 @@ import (
"crypto/sha1" "crypto/sha1"
"encoding/hex" "encoding/hex"
"fmt" "fmt"
"github.com/alist-org/alist/v3/internal/driver"
"github.com/alist-org/alist/v3/internal/model"
"github.com/alist-org/alist/v3/internal/op"
"github.com/alist-org/alist/v3/pkg/utils"
"github.com/aliyun/aliyun-oss-go-sdk/oss"
jsoniter "github.com/json-iterator/go"
"github.com/pkg/errors"
"io" "io"
"net/http" "net/http"
"path/filepath" "path/filepath"
@ -24,7 +17,14 @@ import (
"time" "time"
"github.com/alist-org/alist/v3/drivers/base" "github.com/alist-org/alist/v3/drivers/base"
"github.com/alist-org/alist/v3/internal/driver"
"github.com/alist-org/alist/v3/internal/model"
"github.com/alist-org/alist/v3/internal/op"
"github.com/alist-org/alist/v3/pkg/utils"
"github.com/aliyun/aliyun-oss-go-sdk/oss"
"github.com/go-resty/resty/v2" "github.com/go-resty/resty/v2"
jsoniter "github.com/json-iterator/go"
"github.com/pkg/errors"
) )
var AndroidAlgorithms = []string{ var AndroidAlgorithms = []string{
@ -516,7 +516,7 @@ func (d *PikPak) UploadByMultipart(ctx context.Context, params *S3Params, fileSi
continue continue
} }
b := driver.NewLimitedUploadStream(ctx, bytes.NewBuffer(buf)) b := driver.NewLimitedUploadStream(ctx, bytes.NewReader(buf))
if part, err = bucket.UploadPart(imur, b, chunk.Size, chunk.Number, OssOption(params)...); err == nil { if part, err = bucket.UploadPart(imur, b, chunk.Size, chunk.Number, OssOption(params)...); err == nil {
break break
} }

View File

@ -3,9 +3,8 @@ package quark
import ( import (
"bytes" "bytes"
"context" "context"
"crypto/md5"
"crypto/sha1"
"encoding/hex" "encoding/hex"
"hash"
"io" "io"
"net/http" "net/http"
"time" "time"
@ -14,6 +13,7 @@ import (
"github.com/alist-org/alist/v3/internal/driver" "github.com/alist-org/alist/v3/internal/driver"
"github.com/alist-org/alist/v3/internal/errs" "github.com/alist-org/alist/v3/internal/errs"
"github.com/alist-org/alist/v3/internal/model" "github.com/alist-org/alist/v3/internal/model"
streamPkg "github.com/alist-org/alist/v3/internal/stream"
"github.com/alist-org/alist/v3/pkg/utils" "github.com/alist-org/alist/v3/pkg/utils"
"github.com/go-resty/resty/v2" "github.com/go-resty/resty/v2"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@ -136,33 +136,33 @@ func (d *QuarkOrUC) Remove(ctx context.Context, obj model.Obj) error {
} }
func (d *QuarkOrUC) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error { func (d *QuarkOrUC) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error {
tempFile, err := stream.CacheFullInTempFile() md5Str, sha1Str := stream.GetHash().GetHash(utils.MD5), stream.GetHash().GetHash(utils.SHA1)
if err != nil { var (
return err md5 hash.Hash
sha1 hash.Hash
)
writers := []io.Writer{}
if len(md5Str) != utils.MD5.Width {
md5 = utils.MD5.NewFunc()
writers = append(writers, md5)
} }
defer func() { if len(sha1Str) != utils.SHA1.Width {
_ = tempFile.Close() sha1 = utils.SHA1.NewFunc()
}() writers = append(writers, sha1)
m := md5.New()
_, err = utils.CopyWithBuffer(m, tempFile)
if err != nil {
return err
} }
_, err = tempFile.Seek(0, io.SeekStart)
if err != nil { if len(writers) > 0 {
return err _, err := streamPkg.CacheFullInTempFileAndWriter(stream, io.MultiWriter(writers...))
if err != nil {
return err
}
if md5 != nil {
md5Str = hex.EncodeToString(md5.Sum(nil))
}
if sha1 != nil {
sha1Str = hex.EncodeToString(sha1.Sum(nil))
}
} }
md5Str := hex.EncodeToString(m.Sum(nil))
s := sha1.New()
_, err = utils.CopyWithBuffer(s, tempFile)
if err != nil {
return err
}
_, err = tempFile.Seek(0, io.SeekStart)
if err != nil {
return err
}
sha1Str := hex.EncodeToString(s.Sum(nil))
// pre // pre
pre, err := d.upPre(stream, dstDir.GetID()) pre, err := d.upPre(stream, dstDir.GetID())
if err != nil { if err != nil {
@ -178,27 +178,28 @@ func (d *QuarkOrUC) Put(ctx context.Context, dstDir model.Obj, stream model.File
return nil return nil
} }
// part up // part up
partSize := pre.Metadata.PartSize
var part []byte
md5s := make([]string, 0)
defaultBytes := make([]byte, partSize)
total := stream.GetSize() total := stream.GetSize()
left := total left := total
partSize := int64(pre.Metadata.PartSize)
part := make([]byte, partSize)
count := int(total / partSize)
if total%partSize > 0 {
count++
}
md5s := make([]string, 0, count)
partNumber := 1 partNumber := 1
for left > 0 { for left > 0 {
if utils.IsCanceled(ctx) { if utils.IsCanceled(ctx) {
return ctx.Err() return ctx.Err()
} }
if left > int64(partSize) { if left < partSize {
part = defaultBytes part = part[:left]
} else {
part = make([]byte, left)
} }
_, err := io.ReadFull(tempFile, part) n, err := io.ReadFull(stream, part)
if err != nil { if err != nil {
return err return err
} }
left -= int64(len(part)) left -= int64(n)
log.Debugf("left: %d", left) log.Debugf("left: %d", left)
reader := driver.NewLimitedUploadStream(ctx, bytes.NewReader(part)) reader := driver.NewLimitedUploadStream(ctx, bytes.NewReader(part))
m, err := d.upPart(ctx, pre, stream.GetMimetype(), partNumber, reader) m, err := d.upPart(ctx, pre, stream.GetMimetype(), partNumber, reader)

View File

@ -12,6 +12,7 @@ import (
"github.com/alist-org/alist/v3/internal/errs" "github.com/alist-org/alist/v3/internal/errs"
"github.com/alist-org/alist/v3/internal/model" "github.com/alist-org/alist/v3/internal/model"
"github.com/alist-org/alist/v3/internal/op" "github.com/alist-org/alist/v3/internal/op"
"github.com/alist-org/alist/v3/internal/stream"
"github.com/alist-org/alist/v3/pkg/utils" "github.com/alist-org/alist/v3/pkg/utils"
hash_extend "github.com/alist-org/alist/v3/pkg/utils/hash" hash_extend "github.com/alist-org/alist/v3/pkg/utils/hash"
"github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws"
@ -333,22 +334,17 @@ func (xc *XunLeiCommon) Remove(ctx context.Context, obj model.Obj) error {
} }
func (xc *XunLeiCommon) Put(ctx context.Context, dstDir model.Obj, file model.FileStreamer, up driver.UpdateProgress) error { func (xc *XunLeiCommon) Put(ctx context.Context, dstDir model.Obj, file model.FileStreamer, up driver.UpdateProgress) error {
hi := file.GetHash() gcid := file.GetHash().GetHash(hash_extend.GCID)
gcid := hi.GetHash(hash_extend.GCID) var err error
if len(gcid) < hash_extend.GCID.Width { if len(gcid) < hash_extend.GCID.Width {
tFile, err := file.CacheFullInTempFile() _, gcid, err = stream.CacheFullInTempFileAndHash(file, hash_extend.GCID, file.GetSize())
if err != nil {
return err
}
gcid, err = utils.HashFile(hash_extend.GCID, tFile, file.GetSize())
if err != nil { if err != nil {
return err return err
} }
} }
var resp UploadTaskResponse var resp UploadTaskResponse
_, err := xc.Request(FILE_API_URL, http.MethodPost, func(r *resty.Request) { _, err = xc.Request(FILE_API_URL, http.MethodPost, func(r *resty.Request) {
r.SetContext(ctx) r.SetContext(ctx)
r.SetBody(&base.Json{ r.SetBody(&base.Json{
"kind": FILE, "kind": FILE,

View File

@ -4,10 +4,15 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"io"
"net/http"
"strings"
"github.com/alist-org/alist/v3/drivers/base" "github.com/alist-org/alist/v3/drivers/base"
"github.com/alist-org/alist/v3/internal/driver" "github.com/alist-org/alist/v3/internal/driver"
"github.com/alist-org/alist/v3/internal/model" "github.com/alist-org/alist/v3/internal/model"
"github.com/alist-org/alist/v3/internal/op" "github.com/alist-org/alist/v3/internal/op"
streamPkg "github.com/alist-org/alist/v3/internal/stream"
"github.com/alist-org/alist/v3/pkg/utils" "github.com/alist-org/alist/v3/pkg/utils"
hash_extend "github.com/alist-org/alist/v3/pkg/utils/hash" hash_extend "github.com/alist-org/alist/v3/pkg/utils/hash"
"github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws"
@ -15,9 +20,6 @@ import (
"github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/s3/s3manager" "github.com/aws/aws-sdk-go/service/s3/s3manager"
"github.com/go-resty/resty/v2" "github.com/go-resty/resty/v2"
"io"
"net/http"
"strings"
) )
type ThunderBrowser struct { type ThunderBrowser struct {
@ -456,15 +458,10 @@ func (xc *XunLeiBrowserCommon) Remove(ctx context.Context, obj model.Obj) error
} }
func (xc *XunLeiBrowserCommon) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error { func (xc *XunLeiBrowserCommon) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error {
hi := stream.GetHash() gcid := stream.GetHash().GetHash(hash_extend.GCID)
gcid := hi.GetHash(hash_extend.GCID) var err error
if len(gcid) < hash_extend.GCID.Width { if len(gcid) < hash_extend.GCID.Width {
tFile, err := stream.CacheFullInTempFile() _, gcid, err = streamPkg.CacheFullInTempFileAndHash(stream, hash_extend.GCID, stream.GetSize())
if err != nil {
return err
}
gcid, err = utils.HashFile(hash_extend.GCID, tFile, stream.GetSize())
if err != nil { if err != nil {
return err return err
} }
@ -481,7 +478,7 @@ func (xc *XunLeiBrowserCommon) Put(ctx context.Context, dstDir model.Obj, stream
} }
var resp UploadTaskResponse var resp UploadTaskResponse
_, err := xc.Request(FILE_API_URL, http.MethodPost, func(r *resty.Request) { _, err = xc.Request(FILE_API_URL, http.MethodPost, func(r *resty.Request) {
r.SetContext(ctx) r.SetContext(ctx)
r.SetBody(&js) r.SetBody(&js)
}, &resp) }, &resp)

View File

@ -3,11 +3,15 @@ package thunderx
import ( import (
"context" "context"
"fmt" "fmt"
"net/http"
"strings"
"github.com/alist-org/alist/v3/drivers/base" "github.com/alist-org/alist/v3/drivers/base"
"github.com/alist-org/alist/v3/internal/driver" "github.com/alist-org/alist/v3/internal/driver"
"github.com/alist-org/alist/v3/internal/errs" "github.com/alist-org/alist/v3/internal/errs"
"github.com/alist-org/alist/v3/internal/model" "github.com/alist-org/alist/v3/internal/model"
"github.com/alist-org/alist/v3/internal/op" "github.com/alist-org/alist/v3/internal/op"
"github.com/alist-org/alist/v3/internal/stream"
"github.com/alist-org/alist/v3/pkg/utils" "github.com/alist-org/alist/v3/pkg/utils"
hash_extend "github.com/alist-org/alist/v3/pkg/utils/hash" hash_extend "github.com/alist-org/alist/v3/pkg/utils/hash"
"github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws"
@ -15,8 +19,6 @@ import (
"github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/s3/s3manager" "github.com/aws/aws-sdk-go/service/s3/s3manager"
"github.com/go-resty/resty/v2" "github.com/go-resty/resty/v2"
"net/http"
"strings"
) )
type ThunderX struct { type ThunderX struct {
@ -364,22 +366,17 @@ func (xc *XunLeiXCommon) Remove(ctx context.Context, obj model.Obj) error {
} }
func (xc *XunLeiXCommon) Put(ctx context.Context, dstDir model.Obj, file model.FileStreamer, up driver.UpdateProgress) error { func (xc *XunLeiXCommon) Put(ctx context.Context, dstDir model.Obj, file model.FileStreamer, up driver.UpdateProgress) error {
hi := file.GetHash() gcid := file.GetHash().GetHash(hash_extend.GCID)
gcid := hi.GetHash(hash_extend.GCID) var err error
if len(gcid) < hash_extend.GCID.Width { if len(gcid) < hash_extend.GCID.Width {
tFile, err := file.CacheFullInTempFile() _, gcid, err = stream.CacheFullInTempFileAndHash(file, hash_extend.GCID, file.GetSize())
if err != nil {
return err
}
gcid, err = utils.HashFile(hash_extend.GCID, tFile, file.GetSize())
if err != nil { if err != nil {
return err return err
} }
} }
var resp UploadTaskResponse var resp UploadTaskResponse
_, err := xc.Request(FILE_API_URL, http.MethodPost, func(r *resty.Request) { _, err = xc.Request(FILE_API_URL, http.MethodPost, func(r *resty.Request) {
r.SetContext(ctx) r.SetContext(ctx)
r.SetBody(&base.Json{ r.SetBody(&base.Json{
"kind": FILE, "kind": FILE,

View File

@ -10,6 +10,7 @@ import (
"github.com/alist-org/alist/v3/internal/errs" "github.com/alist-org/alist/v3/internal/errs"
"github.com/alist-org/alist/v3/internal/model" "github.com/alist-org/alist/v3/internal/model"
"github.com/alist-org/alist/v3/internal/stream" "github.com/alist-org/alist/v3/internal/stream"
"github.com/alist-org/alist/v3/pkg/utils"
"github.com/mholt/archives" "github.com/mholt/archives"
) )
@ -73,7 +74,7 @@ func decompress(fsys fs2.FS, filePath, targetPath string, up model.UpdateProgres
return err return err
} }
defer f.Close() defer f.Close()
_, err = io.Copy(f, &stream.ReaderUpdatingProgress{ _, err = utils.CopyWithBuffer(f, &stream.ReaderUpdatingProgress{
Reader: &stream.SimpleReaderWithSize{ Reader: &stream.SimpleReaderWithSize{
Reader: rc, Reader: rc,
Size: stat.Size(), Size: stat.Size(),

View File

@ -1,14 +1,15 @@
package iso9660 package iso9660
import ( import (
"github.com/alist-org/alist/v3/internal/errs"
"github.com/alist-org/alist/v3/internal/model"
"github.com/alist-org/alist/v3/internal/stream"
"github.com/kdomanski/iso9660"
"io"
"os" "os"
stdpath "path" stdpath "path"
"strings" "strings"
"github.com/alist-org/alist/v3/internal/errs"
"github.com/alist-org/alist/v3/internal/model"
"github.com/alist-org/alist/v3/internal/stream"
"github.com/alist-org/alist/v3/pkg/utils"
"github.com/kdomanski/iso9660"
) )
func getImage(ss *stream.SeekableStream) (*iso9660.Image, error) { func getImage(ss *stream.SeekableStream) (*iso9660.Image, error) {
@ -66,7 +67,7 @@ func decompress(f *iso9660.File, path string, up model.UpdateProgress) error {
return err return err
} }
defer file.Close() defer file.Close()
_, err = io.Copy(file, &stream.ReaderUpdatingProgress{ _, err = utils.CopyWithBuffer(file, &stream.ReaderUpdatingProgress{
Reader: &stream.SimpleReaderWithSize{ Reader: &stream.SimpleReaderWithSize{
Reader: f.Reader(), Reader: f.Reader(),
Size: f.Size(), Size: f.Size(),

View File

@ -90,9 +90,11 @@ func (t *ArchiveDownloadTask) RunWithoutPushUploadTask() (*ArchiveContentUploadT
t.SetTotalBytes(total) t.SetTotalBytes(total)
t.status = "getting src object" t.status = "getting src object"
for _, s := range ss { for _, s := range ss {
_, err = s.CacheFullInTempFileAndUpdateProgress(func(p float64) { if s.GetFile() == nil {
t.SetProgress((float64(cur) + float64(s.GetSize())*p/100.0) / float64(total)) _, err = stream.CacheFullInTempFileAndUpdateProgress(s, func(p float64) {
}) t.SetProgress((float64(cur) + float64(s.GetSize())*p/100.0) / float64(total))
})
}
cur += s.GetSize() cur += s.GetSize()
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -2,6 +2,7 @@ package model
import ( import (
"io" "io"
"os"
"sort" "sort"
"strings" "strings"
"time" "time"
@ -48,7 +49,8 @@ type FileStreamer interface {
RangeRead(http_range.Range) (io.Reader, error) RangeRead(http_range.Range) (io.Reader, error)
//for a non-seekable Stream, if Read is called, this function won't work //for a non-seekable Stream, if Read is called, this function won't work
CacheFullInTempFile() (File, error) CacheFullInTempFile() (File, error)
CacheFullInTempFileAndUpdateProgress(up UpdateProgress) (File, error) SetTmpFile(r *os.File)
GetFile() File
} }
type UpdateProgress func(percentage float64) type UpdateProgress func(percentage float64)

View File

@ -248,8 +248,9 @@ func (d *downloader) sendChunkTask(newConcurrency bool) error {
size: finalSize, size: finalSize,
id: d.nextChunk, id: d.nextChunk,
buf: buf, buf: buf,
newConcurrency: newConcurrency,
} }
ch.newConcurrency = newConcurrency
d.pos += finalSize d.pos += finalSize
d.nextChunk++ d.nextChunk++
d.chunkChannel <- ch d.chunkChannel <- ch

View File

@ -94,27 +94,17 @@ func (f *FileStream) CacheFullInTempFile() (model.File, error) {
f.Add(tmpF) f.Add(tmpF)
f.tmpFile = tmpF f.tmpFile = tmpF
f.Reader = tmpF f.Reader = tmpF
return f.tmpFile, nil return tmpF, nil
} }
func (f *FileStream) CacheFullInTempFileAndUpdateProgress(up model.UpdateProgress) (model.File, error) { func (f *FileStream) GetFile() model.File {
if f.tmpFile != nil { if f.tmpFile != nil {
return f.tmpFile, nil return f.tmpFile
} }
if file, ok := f.Reader.(model.File); ok { if file, ok := f.Reader.(model.File); ok {
return file, nil return file
} }
tmpF, err := utils.CreateTempFile(&ReaderUpdatingProgress{ return nil
Reader: f,
UpdateProgress: up,
}, f.GetSize())
if err != nil {
return nil, err
}
f.Add(tmpF)
f.tmpFile = tmpF
f.Reader = tmpF
return f.tmpFile, nil
} }
const InMemoryBufMaxSize = 10 // Megabytes const InMemoryBufMaxSize = 10 // Megabytes
@ -127,31 +117,36 @@ func (f *FileStream) RangeRead(httpRange http_range.Range) (io.Reader, error) {
// 参考 internal/net/request.go // 参考 internal/net/request.go
httpRange.Length = f.GetSize() - httpRange.Start httpRange.Length = f.GetSize() - httpRange.Start
} }
if f.peekBuff != nil && httpRange.Start < int64(f.peekBuff.Len()) && httpRange.Start+httpRange.Length-1 < int64(f.peekBuff.Len()) { size := httpRange.Start + httpRange.Length
if f.peekBuff != nil && size <= int64(f.peekBuff.Len()) {
return io.NewSectionReader(f.peekBuff, httpRange.Start, httpRange.Length), nil return io.NewSectionReader(f.peekBuff, httpRange.Start, httpRange.Length), nil
} }
if f.tmpFile == nil { var cache io.ReaderAt = f.GetFile()
if httpRange.Start == 0 && httpRange.Length <= InMemoryBufMaxSizeBytes && f.peekBuff == nil { if cache == nil {
bufSize := utils.Min(httpRange.Length, f.GetSize()) if size <= InMemoryBufMaxSizeBytes {
newBuf := bytes.NewBuffer(make([]byte, 0, bufSize)) bufSize := min(size, f.GetSize())
n, err := utils.CopyWithBufferN(newBuf, f.Reader, bufSize) // 使用bytes.Buffer作为io.CopyBuffer的写入对象CopyBuffer会调用Buffer.ReadFrom
// 即使被写入的数据量与Buffer.Cap一致Buffer也会扩大
buf := make([]byte, bufSize)
n, err := io.ReadFull(f.Reader, buf)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if n != bufSize { if n != int(bufSize) {
return nil, fmt.Errorf("stream RangeRead did not get all data in peek, expect =%d ,actual =%d", bufSize, n) return nil, fmt.Errorf("stream RangeRead did not get all data in peek, expect =%d ,actual =%d", bufSize, n)
} }
f.peekBuff = bytes.NewReader(newBuf.Bytes()) f.peekBuff = bytes.NewReader(buf)
f.Reader = io.MultiReader(f.peekBuff, f.Reader) f.Reader = io.MultiReader(f.peekBuff, f.Reader)
return io.NewSectionReader(f.peekBuff, httpRange.Start, httpRange.Length), nil cache = f.peekBuff
} else { } else {
_, err := f.CacheFullInTempFile() var err error
cache, err = f.CacheFullInTempFile()
if err != nil { if err != nil {
return nil, err return nil, err
} }
} }
} }
return io.NewSectionReader(f.tmpFile, httpRange.Start, httpRange.Length), nil return io.NewSectionReader(cache, httpRange.Start, httpRange.Length), nil
} }
var _ model.FileStreamer = (*SeekableStream)(nil) var _ model.FileStreamer = (*SeekableStream)(nil)
@ -176,13 +171,13 @@ func NewSeekableStream(fs FileStream, link *model.Link) (*SeekableStream, error)
if len(fs.Mimetype) == 0 { if len(fs.Mimetype) == 0 {
fs.Mimetype = utils.GetMimeType(fs.Obj.GetName()) fs.Mimetype = utils.GetMimeType(fs.Obj.GetName())
} }
ss := SeekableStream{FileStream: fs, Link: link} ss := &SeekableStream{FileStream: fs, Link: link}
if ss.Reader != nil { if ss.Reader != nil {
result, ok := ss.Reader.(model.File) result, ok := ss.Reader.(model.File)
if ok { if ok {
ss.mFile = result ss.mFile = result
ss.Closers.Add(result) ss.Closers.Add(result)
return &ss, nil return ss, nil
} }
} }
if ss.Link != nil { if ss.Link != nil {
@ -198,7 +193,7 @@ func NewSeekableStream(fs FileStream, link *model.Link) (*SeekableStream, error)
ss.mFile = mFile ss.mFile = mFile
ss.Reader = mFile ss.Reader = mFile
ss.Closers.Add(mFile) ss.Closers.Add(mFile)
return &ss, nil return ss, nil
} }
if ss.Link.RangeReadCloser != nil { if ss.Link.RangeReadCloser != nil {
ss.rangeReadCloser = &RateLimitRangeReadCloser{ ss.rangeReadCloser = &RateLimitRangeReadCloser{
@ -206,7 +201,7 @@ func NewSeekableStream(fs FileStream, link *model.Link) (*SeekableStream, error)
Limiter: ServerDownloadLimit, Limiter: ServerDownloadLimit,
} }
ss.Add(ss.rangeReadCloser) ss.Add(ss.rangeReadCloser)
return &ss, nil return ss, nil
} }
if len(ss.Link.URL) > 0 { if len(ss.Link.URL) > 0 {
rrc, err := GetRangeReadCloserFromLink(ss.GetSize(), link) rrc, err := GetRangeReadCloserFromLink(ss.GetSize(), link)
@ -219,10 +214,12 @@ func NewSeekableStream(fs FileStream, link *model.Link) (*SeekableStream, error)
} }
ss.rangeReadCloser = rrc ss.rangeReadCloser = rrc
ss.Add(rrc) ss.Add(rrc)
return &ss, nil return ss, nil
} }
} }
if fs.Reader != nil {
return ss, nil
}
return nil, fmt.Errorf("illegal seekableStream") return nil, fmt.Errorf("illegal seekableStream")
} }
@ -248,7 +245,7 @@ func (ss *SeekableStream) RangeRead(httpRange http_range.Range) (io.Reader, erro
} }
return rc, nil return rc, nil
} }
return nil, fmt.Errorf("can't find mFile or rangeReadCloser") return ss.FileStream.RangeRead(httpRange)
} }
//func (f *FileStream) GetReader() io.Reader { //func (f *FileStream) GetReader() io.Reader {
@ -278,7 +275,7 @@ func (ss *SeekableStream) CacheFullInTempFile() (model.File, error) {
if ss.tmpFile != nil { if ss.tmpFile != nil {
return ss.tmpFile, nil return ss.tmpFile, nil
} }
if _, ok := ss.mFile.(*os.File); ok { if ss.mFile != nil {
return ss.mFile, nil return ss.mFile, nil
} }
tmpF, err := utils.CreateTempFile(ss, ss.GetSize()) tmpF, err := utils.CreateTempFile(ss, ss.GetSize())
@ -288,27 +285,17 @@ func (ss *SeekableStream) CacheFullInTempFile() (model.File, error) {
ss.Add(tmpF) ss.Add(tmpF)
ss.tmpFile = tmpF ss.tmpFile = tmpF
ss.Reader = tmpF ss.Reader = tmpF
return ss.tmpFile, nil return tmpF, nil
} }
func (ss *SeekableStream) CacheFullInTempFileAndUpdateProgress(up model.UpdateProgress) (model.File, error) { func (ss *SeekableStream) GetFile() model.File {
if ss.tmpFile != nil { if ss.tmpFile != nil {
return ss.tmpFile, nil return ss.tmpFile
} }
if _, ok := ss.mFile.(*os.File); ok { if ss.mFile != nil {
return ss.mFile, nil return ss.mFile
} }
tmpF, err := utils.CreateTempFile(&ReaderUpdatingProgress{ return nil
Reader: ss,
UpdateProgress: up,
}, ss.GetSize())
if err != nil {
return nil, err
}
ss.Add(tmpF)
ss.tmpFile = tmpF
ss.Reader = tmpF
return ss.tmpFile, nil
} }
func (f *FileStream) SetTmpFile(r *os.File) { func (f *FileStream) SetTmpFile(r *os.File) {

View File

@ -2,6 +2,7 @@ package stream
import ( import (
"context" "context"
"encoding/hex"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
@ -96,3 +97,45 @@ func (r *ReaderWithCtx) Close() error {
} }
return nil return nil
} }
func CacheFullInTempFileAndUpdateProgress(stream model.FileStreamer, up model.UpdateProgress) (model.File, error) {
if cache := stream.GetFile(); cache != nil {
up(100)
return cache, nil
}
tmpF, err := utils.CreateTempFile(&ReaderUpdatingProgress{
Reader: stream,
UpdateProgress: up,
}, stream.GetSize())
if err == nil {
stream.SetTmpFile(tmpF)
}
return tmpF, err
}
func CacheFullInTempFileAndWriter(stream model.FileStreamer, w io.Writer) (model.File, error) {
if cache := stream.GetFile(); cache != nil {
_, err := cache.Seek(0, io.SeekStart)
if err == nil {
_, err = utils.CopyWithBuffer(w, cache)
if err == nil {
_, err = cache.Seek(0, io.SeekStart)
}
}
return cache, err
}
tmpF, err := utils.CreateTempFile(io.TeeReader(stream, w), stream.GetSize())
if err == nil {
stream.SetTmpFile(tmpF)
}
return tmpF, err
}
func CacheFullInTempFileAndHash(stream model.FileStreamer, hashType *utils.HashType, params ...any) (model.File, string, error) {
h := hashType.NewFunc(params...)
tmpF, err := CacheFullInTempFileAndWriter(stream, h)
if err != nil {
return nil, "", err
}
return tmpF, hex.EncodeToString(h.Sum(nil)), err
}

View File

@ -1,8 +1,6 @@
package handles package handles
import ( import (
"github.com/alist-org/alist/v3/internal/task"
"github.com/alist-org/alist/v3/pkg/utils"
"io" "io"
"net/url" "net/url"
stdpath "path" stdpath "path"
@ -12,6 +10,8 @@ import (
"github.com/alist-org/alist/v3/internal/fs" "github.com/alist-org/alist/v3/internal/fs"
"github.com/alist-org/alist/v3/internal/model" "github.com/alist-org/alist/v3/internal/model"
"github.com/alist-org/alist/v3/internal/stream" "github.com/alist-org/alist/v3/internal/stream"
"github.com/alist-org/alist/v3/internal/task"
"github.com/alist-org/alist/v3/pkg/utils"
"github.com/alist-org/alist/v3/server/common" "github.com/alist-org/alist/v3/server/common"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
@ -44,7 +44,7 @@ func FsStream(c *gin.Context) {
} }
if !overwrite { if !overwrite {
if res, _ := fs.Get(c, path, &fs.GetArgs{NoLog: true}); res != nil { if res, _ := fs.Get(c, path, &fs.GetArgs{NoLog: true}); res != nil {
_, _ = io.Copy(io.Discard, c.Request.Body) _, _ = utils.CopyWithBuffer(io.Discard, c.Request.Body)
common.ErrorStrResp(c, "file exists", 403) common.ErrorStrResp(c, "file exists", 403)
return return
} }
@ -66,6 +66,10 @@ func FsStream(c *gin.Context) {
if sha256 := c.GetHeader("X-File-Sha256"); sha256 != "" { if sha256 := c.GetHeader("X-File-Sha256"); sha256 != "" {
h[utils.SHA256] = sha256 h[utils.SHA256] = sha256
} }
mimetype := c.GetHeader("Content-Type")
if len(mimetype) == 0 {
mimetype = utils.GetMimeType(name)
}
s := &stream.FileStream{ s := &stream.FileStream{
Obj: &model.Object{ Obj: &model.Object{
Name: name, Name: name,
@ -74,7 +78,7 @@ func FsStream(c *gin.Context) {
HashInfo: utils.NewHashInfoByMap(h), HashInfo: utils.NewHashInfoByMap(h),
}, },
Reader: c.Request.Body, Reader: c.Request.Body,
Mimetype: c.GetHeader("Content-Type"), Mimetype: mimetype,
WebPutAsTask: asTask, WebPutAsTask: asTask,
} }
var t task.TaskExtensionInfo var t task.TaskExtensionInfo
@ -89,6 +93,9 @@ func FsStream(c *gin.Context) {
return return
} }
if t == nil { if t == nil {
if n, _ := io.ReadFull(c.Request.Body, []byte{0}); n == 1 {
_, _ = utils.CopyWithBuffer(io.Discard, c.Request.Body)
}
common.SuccessResp(c) common.SuccessResp(c)
return return
} }
@ -114,7 +121,7 @@ func FsForm(c *gin.Context) {
} }
if !overwrite { if !overwrite {
if res, _ := fs.Get(c, path, &fs.GetArgs{NoLog: true}); res != nil { if res, _ := fs.Get(c, path, &fs.GetArgs{NoLog: true}); res != nil {
_, _ = io.Copy(io.Discard, c.Request.Body) _, _ = utils.CopyWithBuffer(io.Discard, c.Request.Body)
common.ErrorStrResp(c, "file exists", 403) common.ErrorStrResp(c, "file exists", 403)
return return
} }
@ -150,6 +157,10 @@ func FsForm(c *gin.Context) {
if sha256 := c.GetHeader("X-File-Sha256"); sha256 != "" { if sha256 := c.GetHeader("X-File-Sha256"); sha256 != "" {
h[utils.SHA256] = sha256 h[utils.SHA256] = sha256
} }
mimetype := file.Header.Get("Content-Type")
if len(mimetype) == 0 {
mimetype = utils.GetMimeType(name)
}
s := stream.FileStream{ s := stream.FileStream{
Obj: &model.Object{ Obj: &model.Object{
Name: name, Name: name,
@ -158,7 +169,7 @@ func FsForm(c *gin.Context) {
HashInfo: utils.NewHashInfoByMap(h), HashInfo: utils.NewHashInfoByMap(h),
}, },
Reader: f, Reader: f,
Mimetype: file.Header.Get("Content-Type"), Mimetype: mimetype,
WebPutAsTask: asTask, WebPutAsTask: asTask,
} }
var t task.TaskExtensionInfo var t task.TaskExtensionInfo
@ -168,12 +179,7 @@ func FsForm(c *gin.Context) {
}{f} }{f}
t, err = fs.PutAsTask(c, dir, &s) t, err = fs.PutAsTask(c, dir, &s)
} else { } else {
ss, err := stream.NewSeekableStream(s, nil) err = fs.PutDirectly(c, dir, &s, true)
if err != nil {
common.ErrorResp(c, err, 500)
return
}
err = fs.PutDirectly(c, dir, ss, true)
} }
if err != nil { if err != nil {
common.ErrorResp(c, err, 500) common.ErrorResp(c, err, 500)