diff --git a/drivers/baidu_netdisk/driver.go b/drivers/baidu_netdisk/driver.go index 64510dcf..c81225e4 100644 --- a/drivers/baidu_netdisk/driver.go +++ b/drivers/baidu_netdisk/driver.go @@ -1,7 +1,6 @@ package baidu_netdisk import ( - "bytes" "context" "crypto/md5" "encoding/hex" @@ -118,7 +117,6 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F _ = os.Remove(tempFile.Name()) }() var Default int64 = 4 * 1024 * 1024 - defaultByteData := make([]byte, Default) count := int(math.Ceil(float64(stream.GetSize()) / float64(Default))) var SliceSize int64 = 256 * 1024 // cal md5 @@ -130,20 +128,14 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F left := stream.GetSize() for i := 0; i < count; i++ { byteSize := Default - var byteData []byte if left < Default { byteSize = left - byteData = make([]byte, byteSize) - } else { - byteData = defaultByteData } left -= byteSize - _, err = io.ReadFull(tempFile, byteData) + _, err = io.Copy(io.MultiWriter(h1, h2), io.LimitReader(tempFile, byteSize)) if err != nil { return err } - h1.Write(byteData) - h2.Write(byteData) block_list = append(block_list, fmt.Sprintf("\"%s\"", hex.EncodeToString(h2.Sum(nil)))) h2.Reset() } @@ -177,6 +169,7 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F params := map[string]string{ "method": "precreate", } + log.Debugf("[baidu_netdisk] precreate data: %s", data) var precreateResp PrecreateResp _, err = d.post("/xpan/file", params, data, &precreateResp) if err != nil { @@ -199,24 +192,16 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F return ctx.Err() } byteSize := Default - var byteData []byte if left < Default { byteSize = left - byteData = make([]byte, byteSize) - } else { - byteData = defaultByteData } left -= byteSize - _, err = io.ReadFull(tempFile, byteData) - if err != nil { - return err - } u := "https://d.pcs.baidu.com/rest/2.0/pcs/superfile2" params["partseq"] = strconv.Itoa(partseq) res, err := base.RestyClient.R(). SetContext(ctx). SetQueryParams(params). - SetFileReader("file", stream.GetName(), bytes.NewReader(byteData)). + SetFileReader("file", stream.GetName(), io.LimitReader(tempFile, byteSize)). Post(u) if err != nil { return err diff --git a/drivers/baidu_netdisk/util.go b/drivers/baidu_netdisk/util.go index a4519fdb..5e863036 100644 --- a/drivers/baidu_netdisk/util.go +++ b/drivers/baidu_netdisk/util.go @@ -13,6 +13,7 @@ import ( "github.com/alist-org/alist/v3/internal/op" "github.com/alist-org/alist/v3/pkg/utils" "github.com/go-resty/resty/v2" + log "github.com/sirupsen/logrus" ) // do others that not defined in Driver interface @@ -62,16 +63,17 @@ func (d *BaiduNetdisk) request(furl string, method string, callback base.ReqCall if err != nil { return nil, err } + log.Debugf("[baidu_netdisk] req: %s, resp: %s", furl, res.String()) errno := utils.Json.Get(res.Body(), "errno").ToInt() if errno != 0 { - if errno == -6 { + if utils.SliceContains([]int{111, -6}, errno) { err = d.refreshToken() if err != nil { return nil, err } return d.request(furl, method, callback, resp) } - return nil, fmt.Errorf("errno: %d, refer to https://pan.baidu.com/union/doc/", errno) + return nil, fmt.Errorf("req: [%s] ,errno: %d, refer to https://pan.baidu.com/union/doc/", furl, errno) } return res.Body(), nil }