mirror of https://github.com/Xhofe/alist
				
				
				
			
							parent
							
								
									e9cb37122e
								
							
						
					
					
						commit
						7877184bee
					
				| 
						 | 
				
			
			@ -83,7 +83,7 @@ func (d *Pan115) Remove(ctx context.Context, obj model.Obj) error {
 | 
			
		|||
}
 | 
			
		||||
 | 
			
		||||
func (d *Pan115) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error {
 | 
			
		||||
	tempFile, err := utils.CreateTempFile(stream.GetReadCloser())
 | 
			
		||||
	tempFile, err := utils.CreateTempFile(stream.GetReadCloser(), stream.GetSize())
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -184,7 +184,7 @@ func (d *Pan123) Put(ctx context.Context, dstDir model.Obj, stream model.FileStr
 | 
			
		|||
	// const DEFAULT int64 = 10485760
 | 
			
		||||
	h := md5.New()
 | 
			
		||||
	// need to calculate md5 of the full content
 | 
			
		||||
	tempFile, err := utils.CreateTempFile(stream.GetReadCloser())
 | 
			
		||||
	tempFile, err := utils.CreateTempFile(stream.GetReadCloser(), stream.GetSize())
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -545,7 +545,7 @@ func (y *Cloud189PC) StreamUpload(ctx context.Context, dstDir model.Obj, file mo
 | 
			
		|||
// 快传
 | 
			
		||||
func (y *Cloud189PC) FastUpload(ctx context.Context, dstDir model.Obj, file model.FileStreamer, up driver.UpdateProgress) (model.Obj, error) {
 | 
			
		||||
	// 需要获取完整文件md5,必须支持 io.Seek
 | 
			
		||||
	tempFile, err := utils.CreateTempFile(file.GetReadCloser())
 | 
			
		||||
	tempFile, err := utils.CreateTempFile(file.GetReadCloser(), file.GetSize())
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
| 
						 | 
				
			
			@ -672,7 +672,7 @@ func (y *Cloud189PC) FastUpload(ctx context.Context, dstDir model.Obj, file mode
 | 
			
		|||
// 旧版本上传,家庭云不支持覆盖
 | 
			
		||||
func (y *Cloud189PC) OldUpload(ctx context.Context, dstDir model.Obj, file model.FileStreamer, up driver.UpdateProgress) (model.Obj, error) {
 | 
			
		||||
	// 需要获取完整文件md5,必须支持 io.Seek
 | 
			
		||||
	tempFile, err := utils.CreateTempFile(file.GetReadCloser())
 | 
			
		||||
	tempFile, err := utils.CreateTempFile(file.GetReadCloser(), file.GetSize())
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -224,7 +224,7 @@ func (d *AliyundriveOpen) upload(ctx context.Context, dstDir model.Obj, stream m
 | 
			
		|||
		}
 | 
			
		||||
		log.Debugf("[aliyundrive_open] pre_hash matched, start rapid upload")
 | 
			
		||||
		// convert to local file
 | 
			
		||||
		file, err := utils.CreateTempFile(stream)
 | 
			
		||||
		file, err := utils.CreateTempFile(stream, stream.GetSize())
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -5,18 +5,19 @@ import (
 | 
			
		|||
	"crypto/md5"
 | 
			
		||||
	"encoding/hex"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/alist-org/alist/v3/drivers/base"
 | 
			
		||||
	"github.com/alist-org/alist/v3/internal/driver"
 | 
			
		||||
	"github.com/alist-org/alist/v3/internal/errs"
 | 
			
		||||
	"github.com/alist-org/alist/v3/internal/model"
 | 
			
		||||
	"github.com/alist-org/alist/v3/pkg/utils"
 | 
			
		||||
	"github.com/avast/retry-go"
 | 
			
		||||
	log "github.com/sirupsen/logrus"
 | 
			
		||||
	"io"
 | 
			
		||||
	"math"
 | 
			
		||||
	"os"
 | 
			
		||||
	stdpath "path"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"strings"
 | 
			
		||||
 | 
			
		||||
	"github.com/alist-org/alist/v3/drivers/base"
 | 
			
		||||
	"github.com/alist-org/alist/v3/internal/driver"
 | 
			
		||||
	"github.com/alist-org/alist/v3/internal/model"
 | 
			
		||||
	"github.com/alist-org/alist/v3/pkg/utils"
 | 
			
		||||
	log "github.com/sirupsen/logrus"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type BaiduNetdisk struct {
 | 
			
		||||
| 
						 | 
				
			
			@ -24,6 +25,9 @@ type BaiduNetdisk struct {
 | 
			
		|||
	Addition
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
const BaiduFileAPI = "https://d.pcs.baidu.com/rest/2.0/pcs/superfile2"
 | 
			
		||||
const DefaultSliceSize int64 = 4 * 1024 * 1024
 | 
			
		||||
 | 
			
		||||
func (d *BaiduNetdisk) Config() driver.Config {
 | 
			
		||||
	return config
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -108,7 +112,9 @@ func (d *BaiduNetdisk) Remove(ctx context.Context, obj model.Obj) error {
 | 
			
		|||
}
 | 
			
		||||
 | 
			
		||||
func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error {
 | 
			
		||||
	tempFile, err := utils.CreateTempFile(stream.GetReadCloser())
 | 
			
		||||
	streamSize := stream.GetSize()
 | 
			
		||||
 | 
			
		||||
	tempFile, err := utils.CreateTempFile(stream.GetReadCloser(), stream.GetSize())
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
| 
						 | 
				
			
			@ -116,19 +122,20 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F
 | 
			
		|||
		_ = tempFile.Close()
 | 
			
		||||
		_ = os.Remove(tempFile.Name())
 | 
			
		||||
	}()
 | 
			
		||||
	var Default int64 = 4 * 1024 * 1024
 | 
			
		||||
	count := int(math.Ceil(float64(stream.GetSize()) / float64(Default)))
 | 
			
		||||
	var SliceSize int64 = 256 * 1024
 | 
			
		||||
 | 
			
		||||
	count := int(math.Ceil(float64(streamSize) / float64(DefaultSliceSize)))
 | 
			
		||||
	//cal md5 for first 256k data
 | 
			
		||||
	const SliceSize int64 = 256 * 1024
 | 
			
		||||
	// cal md5
 | 
			
		||||
	h1 := md5.New()
 | 
			
		||||
	h2 := md5.New()
 | 
			
		||||
	block_list := make([]string, 0)
 | 
			
		||||
	content_md5 := ""
 | 
			
		||||
	slice_md5 := ""
 | 
			
		||||
	left := stream.GetSize()
 | 
			
		||||
	blockList := make([]string, 0)
 | 
			
		||||
	contentMd5 := ""
 | 
			
		||||
	sliceMd5 := ""
 | 
			
		||||
	left := streamSize
 | 
			
		||||
	for i := 0; i < count; i++ {
 | 
			
		||||
		byteSize := Default
 | 
			
		||||
		if left < Default {
 | 
			
		||||
		byteSize := DefaultSliceSize
 | 
			
		||||
		if left < DefaultSliceSize {
 | 
			
		||||
			byteSize = left
 | 
			
		||||
		}
 | 
			
		||||
		left -= byteSize
 | 
			
		||||
| 
						 | 
				
			
			@ -136,16 +143,16 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F
 | 
			
		|||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
		block_list = append(block_list, fmt.Sprintf("\"%s\"", hex.EncodeToString(h2.Sum(nil))))
 | 
			
		||||
		blockList = append(blockList, fmt.Sprintf("\"%s\"", hex.EncodeToString(h2.Sum(nil))))
 | 
			
		||||
		h2.Reset()
 | 
			
		||||
	}
 | 
			
		||||
	content_md5 = hex.EncodeToString(h1.Sum(nil))
 | 
			
		||||
	contentMd5 = hex.EncodeToString(h1.Sum(nil))
 | 
			
		||||
	_, err = tempFile.Seek(0, io.SeekStart)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	if stream.GetSize() <= SliceSize {
 | 
			
		||||
		slice_md5 = content_md5
 | 
			
		||||
	if streamSize <= SliceSize {
 | 
			
		||||
		sliceMd5 = contentMd5
 | 
			
		||||
	} else {
 | 
			
		||||
		sliceData := make([]byte, SliceSize)
 | 
			
		||||
		_, err = io.ReadFull(tempFile, sliceData)
 | 
			
		||||
| 
						 | 
				
			
			@ -153,19 +160,15 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F
 | 
			
		|||
			return err
 | 
			
		||||
		}
 | 
			
		||||
		h2.Write(sliceData)
 | 
			
		||||
		slice_md5 = hex.EncodeToString(h2.Sum(nil))
 | 
			
		||||
		_, err = tempFile.Seek(0, io.SeekStart)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
		sliceMd5 = hex.EncodeToString(h2.Sum(nil))
 | 
			
		||||
	}
 | 
			
		||||
	rawPath := stdpath.Join(dstDir.GetPath(), stream.GetName())
 | 
			
		||||
	path := encodeURIComponent(rawPath)
 | 
			
		||||
	block_list_str := fmt.Sprintf("[%s]", strings.Join(block_list, ","))
 | 
			
		||||
	block_list_str := fmt.Sprintf("[%s]", strings.Join(blockList, ","))
 | 
			
		||||
	data := fmt.Sprintf("path=%s&size=%d&isdir=0&autoinit=1&block_list=%s&content-md5=%s&slice-md5=%s",
 | 
			
		||||
		path, stream.GetSize(),
 | 
			
		||||
		path, streamSize,
 | 
			
		||||
		block_list_str,
 | 
			
		||||
		content_md5, slice_md5)
 | 
			
		||||
		contentMd5, sliceMd5)
 | 
			
		||||
	params := map[string]string{
 | 
			
		||||
		"method": "precreate",
 | 
			
		||||
	}
 | 
			
		||||
| 
						 | 
				
			
			@ -177,6 +180,7 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F
 | 
			
		|||
	}
 | 
			
		||||
	log.Debugf("%+v", precreateResp)
 | 
			
		||||
	if precreateResp.ReturnType == 2 {
 | 
			
		||||
		//rapid upload, since got md5 match from baidu server
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
	params = map[string]string{
 | 
			
		||||
| 
						 | 
				
			
			@ -186,33 +190,49 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F
 | 
			
		|||
		"path":         path,
 | 
			
		||||
		"uploadid":     precreateResp.Uploadid,
 | 
			
		||||
	}
 | 
			
		||||
	left = stream.GetSize()
 | 
			
		||||
 | 
			
		||||
	var offset int64 = 0
 | 
			
		||||
	for i, partseq := range precreateResp.BlockList {
 | 
			
		||||
		if utils.IsCanceled(ctx) {
 | 
			
		||||
			return ctx.Err()
 | 
			
		||||
		}
 | 
			
		||||
		byteSize := Default
 | 
			
		||||
		if left < Default {
 | 
			
		||||
			byteSize = left
 | 
			
		||||
		}
 | 
			
		||||
		left -= byteSize
 | 
			
		||||
		u := "https://d.pcs.baidu.com/rest/2.0/pcs/superfile2"
 | 
			
		||||
		params["partseq"] = strconv.Itoa(partseq)
 | 
			
		||||
		res, err := base.RestyClient.R().
 | 
			
		||||
			SetContext(ctx).
 | 
			
		||||
			SetQueryParams(params).
 | 
			
		||||
			SetFileReader("file", stream.GetName(), io.LimitReader(tempFile, byteSize)).
 | 
			
		||||
			Post(u)
 | 
			
		||||
		byteSize := int64(math.Min(float64(streamSize-offset), float64(DefaultSliceSize)))
 | 
			
		||||
		err := retry.Do(func() error {
 | 
			
		||||
			return d.uploadSlice(ctx, ¶ms, stream.GetName(), tempFile, offset, byteSize)
 | 
			
		||||
		},
 | 
			
		||||
			retry.Context(ctx),
 | 
			
		||||
			retry.Attempts(3))
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
		log.Debugln(res.String())
 | 
			
		||||
		offset += byteSize
 | 
			
		||||
 | 
			
		||||
		if len(precreateResp.BlockList) > 0 {
 | 
			
		||||
			up(i * 100 / len(precreateResp.BlockList))
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	_, err = d.create(rawPath, stream.GetSize(), 0, precreateResp.Uploadid, block_list_str)
 | 
			
		||||
	_, err = d.create(rawPath, streamSize, 0, precreateResp.Uploadid, block_list_str)
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
func (d *BaiduNetdisk) uploadSlice(ctx context.Context, params *map[string]string, fileName string, file *os.File, offset int64, byteSize int64) error {
 | 
			
		||||
	_, err := file.Seek(offset, io.SeekStart)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	res, err := base.RestyClient.R().
 | 
			
		||||
		SetContext(ctx).
 | 
			
		||||
		SetQueryParams(*params).
 | 
			
		||||
		SetFileReader("file", fileName, io.LimitReader(file, byteSize)).
 | 
			
		||||
		Post(BaiduFileAPI)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	log.Debugln(res.RawResponse.Status + res.String())
 | 
			
		||||
	errCode := utils.Json.Get(res.Body(), "error_code").ToInt()
 | 
			
		||||
	errNo := utils.Json.Get(res.Body(), "errno").ToInt()
 | 
			
		||||
	if errCode != 0 || errNo != 0 {
 | 
			
		||||
		return errs.NewErr(errs.StreamIncomplete, "error in uploading to baidu, will retry. response=%s", res.String())
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var _ driver.Driver = (*BaiduNetdisk)(nil)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -2,6 +2,7 @@ package baidu_netdisk
 | 
			
		|||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/avast/retry-go"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"net/url"
 | 
			
		||||
	"strconv"
 | 
			
		||||
| 
						 | 
				
			
			@ -51,31 +52,37 @@ func (d *BaiduNetdisk) _refreshToken() error {
 | 
			
		|||
}
 | 
			
		||||
 | 
			
		||||
func (d *BaiduNetdisk) request(furl string, method string, callback base.ReqCallback, resp interface{}) ([]byte, error) {
 | 
			
		||||
	req := base.RestyClient.R()
 | 
			
		||||
	req.SetQueryParam("access_token", d.AccessToken)
 | 
			
		||||
	if callback != nil {
 | 
			
		||||
		callback(req)
 | 
			
		||||
	}
 | 
			
		||||
	if resp != nil {
 | 
			
		||||
		req.SetResult(resp)
 | 
			
		||||
	}
 | 
			
		||||
	res, err := req.Execute(method, furl)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	log.Debugf("[baidu_netdisk] req: %s, resp: %s", furl, res.String())
 | 
			
		||||
	errno := utils.Json.Get(res.Body(), "errno").ToInt()
 | 
			
		||||
	if errno != 0 {
 | 
			
		||||
		if utils.SliceContains([]int{111, -6}, errno) {
 | 
			
		||||
			err = d.refreshToken()
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return nil, err
 | 
			
		||||
			}
 | 
			
		||||
			return d.request(furl, method, callback, resp)
 | 
			
		||||
	var result []byte
 | 
			
		||||
	err := retry.Do(func() error {
 | 
			
		||||
		req := base.RestyClient.R()
 | 
			
		||||
		req.SetQueryParam("access_token", d.AccessToken)
 | 
			
		||||
		if callback != nil {
 | 
			
		||||
			callback(req)
 | 
			
		||||
		}
 | 
			
		||||
		return nil, fmt.Errorf("req: [%s] ,errno: %d, refer to https://pan.baidu.com/union/doc/", furl, errno)
 | 
			
		||||
	}
 | 
			
		||||
	return res.Body(), nil
 | 
			
		||||
		if resp != nil {
 | 
			
		||||
			req.SetResult(resp)
 | 
			
		||||
		}
 | 
			
		||||
		res, err := req.Execute(method, furl)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
		log.Debugf("[baidu_netdisk] req: %s, resp: %s", furl, res.String())
 | 
			
		||||
		errno := utils.Json.Get(res.Body(), "errno").ToInt()
 | 
			
		||||
		if errno != 0 {
 | 
			
		||||
			if utils.SliceContains([]int{111, -6}, errno) {
 | 
			
		||||
				log.Info("refreshing baidu_netdisk token.")
 | 
			
		||||
				err2 := d.refreshToken()
 | 
			
		||||
				if err2 != nil {
 | 
			
		||||
					return err2
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			return fmt.Errorf("req: [%s] ,errno: %d, refer to https://pan.baidu.com/union/doc/", furl, errno)
 | 
			
		||||
		}
 | 
			
		||||
		result = res.Body()
 | 
			
		||||
		return nil
 | 
			
		||||
	},
 | 
			
		||||
		retry.Attempts(3))
 | 
			
		||||
	return result, err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (d *BaiduNetdisk) get(pathname string, params map[string]string, resp interface{}) ([]byte, error) {
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -212,7 +212,7 @@ func (d *BaiduPhoto) Remove(ctx context.Context, obj model.Obj) error {
 | 
			
		|||
 | 
			
		||||
func (d *BaiduPhoto) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) (model.Obj, error) {
 | 
			
		||||
	// 需要获取完整文件md5,必须支持 io.Seek
 | 
			
		||||
	tempFile, err := utils.CreateTempFile(stream.GetReadCloser())
 | 
			
		||||
	tempFile, err := utils.CreateTempFile(stream.GetReadCloser(), stream.GetSize())
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -181,7 +181,7 @@ func (d *MediaTrack) Put(ctx context.Context, dstDir model.Obj, stream model.Fil
 | 
			
		|||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	tempFile, err := utils.CreateTempFile(stream.GetReadCloser())
 | 
			
		||||
	tempFile, err := utils.CreateTempFile(stream.GetReadCloser(), stream.GetSize())
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -212,7 +212,7 @@ func (d *MoPan) Remove(ctx context.Context, obj model.Obj) error {
 | 
			
		|||
}
 | 
			
		||||
 | 
			
		||||
func (d *MoPan) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) (model.Obj, error) {
 | 
			
		||||
	file, err := utils.CreateTempFile(stream)
 | 
			
		||||
	file, err := utils.CreateTempFile(stream, stream.GetSize())
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -124,7 +124,7 @@ func (d *PikPak) Remove(ctx context.Context, obj model.Obj) error {
 | 
			
		|||
}
 | 
			
		||||
 | 
			
		||||
func (d *PikPak) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error {
 | 
			
		||||
	tempFile, err := utils.CreateTempFile(stream.GetReadCloser())
 | 
			
		||||
	tempFile, err := utils.CreateTempFile(stream.GetReadCloser(), stream.GetSize())
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -136,7 +136,7 @@ func (d *QuarkOrUC) Remove(ctx context.Context, obj model.Obj) error {
 | 
			
		|||
}
 | 
			
		||||
 | 
			
		||||
func (d *QuarkOrUC) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error {
 | 
			
		||||
	tempFile, err := utils.CreateTempFile(stream.GetReadCloser())
 | 
			
		||||
	tempFile, err := utils.CreateTempFile(stream.GetReadCloser(), stream.GetSize())
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -116,7 +116,7 @@ func (d *Terabox) Remove(ctx context.Context, obj model.Obj) error {
 | 
			
		|||
}
 | 
			
		||||
 | 
			
		||||
func (d *Terabox) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error {
 | 
			
		||||
	tempFile, err := utils.CreateTempFile(stream.GetReadCloser())
 | 
			
		||||
	tempFile, err := utils.CreateTempFile(stream.GetReadCloser(), stream.GetSize())
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -333,7 +333,7 @@ func (xc *XunLeiCommon) Remove(ctx context.Context, obj model.Obj) error {
 | 
			
		|||
}
 | 
			
		||||
 | 
			
		||||
func (xc *XunLeiCommon) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error {
 | 
			
		||||
	tempFile, err := utils.CreateTempFile(stream.GetReadCloser())
 | 
			
		||||
	tempFile, err := utils.CreateTempFile(stream.GetReadCloser(), stream.GetSize())
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -298,7 +298,7 @@ func (d *WeiYun) Remove(ctx context.Context, obj model.Obj) error {
 | 
			
		|||
 | 
			
		||||
func (d *WeiYun) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) (model.Obj, error) {
 | 
			
		||||
	if folder, ok := dstDir.(*Folder); ok {
 | 
			
		||||
		file, err := utils.CreateTempFile(stream)
 | 
			
		||||
		file, err := utils.CreateTempFile(stream, stream.GetSize())
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -14,8 +14,9 @@ var (
 | 
			
		|||
	MoveBetweenTwoStorages = errors.New("can't move files between two storages, try to copy")
 | 
			
		||||
	UploadNotSupported     = errors.New("upload not supported")
 | 
			
		||||
 | 
			
		||||
	MetaNotFound    = errors.New("meta not found")
 | 
			
		||||
	StorageNotFound = errors.New("storage not found")
 | 
			
		||||
	MetaNotFound     = errors.New("meta not found")
 | 
			
		||||
	StorageNotFound  = errors.New("storage not found")
 | 
			
		||||
	StreamIncomplete = errors.New("upload/download stream incomplete, possible network issue")
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// NewErr wrap constant error with an extra message
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -27,7 +27,7 @@ func putAsTask(dstDirPath string, file *model.FileStream) error {
 | 
			
		|||
		return errors.WithStack(errs.UploadNotSupported)
 | 
			
		||||
	}
 | 
			
		||||
	if file.NeedStore() {
 | 
			
		||||
		tempFile, err := utils.CreateTempFile(file)
 | 
			
		||||
		tempFile, err := utils.CreateTempFile(file, file.GetSize())
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return errors.Wrapf(err, "failed to create temp file")
 | 
			
		||||
		}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -2,6 +2,7 @@ package utils
 | 
			
		|||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/alist-org/alist/v3/internal/errs"
 | 
			
		||||
	"io"
 | 
			
		||||
	"mime"
 | 
			
		||||
	"os"
 | 
			
		||||
| 
						 | 
				
			
			@ -111,7 +112,7 @@ func CreateNestedFile(path string) (*os.File, error) {
 | 
			
		|||
}
 | 
			
		||||
 | 
			
		||||
// CreateTempFile create temp file from io.ReadCloser, and seek to 0
 | 
			
		||||
func CreateTempFile(r io.ReadCloser) (*os.File, error) {
 | 
			
		||||
func CreateTempFile(r io.ReadCloser, size int64) (*os.File, error) {
 | 
			
		||||
	if f, ok := r.(*os.File); ok {
 | 
			
		||||
		return f, nil
 | 
			
		||||
	}
 | 
			
		||||
| 
						 | 
				
			
			@ -119,15 +120,19 @@ func CreateTempFile(r io.ReadCloser) (*os.File, error) {
 | 
			
		|||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	_, err = io.Copy(f, r)
 | 
			
		||||
	readBytes, err := io.Copy(f, r)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		_ = os.Remove(f.Name())
 | 
			
		||||
		return nil, err
 | 
			
		||||
		return nil, errs.NewErr(err, "CreateTempFile failed")
 | 
			
		||||
	}
 | 
			
		||||
	if size != 0 && readBytes != size {
 | 
			
		||||
		_ = os.Remove(f.Name())
 | 
			
		||||
		return nil, errs.NewErr(err, "CreateTempFile failed, incoming stream actual size= %s, expect = %s ", readBytes, size)
 | 
			
		||||
	}
 | 
			
		||||
	_, err = f.Seek(0, io.SeekStart)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		_ = os.Remove(f.Name())
 | 
			
		||||
		return nil, err
 | 
			
		||||
		return nil, errs.NewErr(err, "CreateTempFile failed, can't seek to 0 ")
 | 
			
		||||
	}
 | 
			
		||||
	return f, nil
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue