From 2be0c3d1a088d2c74bb429c9d6072a73bd30fb1b Mon Sep 17 00:00:00 2001 From: j2rong4cn <36783515+j2rong4cn@users.noreply.github.com> Date: Mon, 27 Jan 2025 20:08:39 +0800 Subject: [PATCH] feat(alias): add `DownloadConcurrency` and `DownloadPartSize` option (#7829) * fix(net): goroutine logic bug (AlistGo/alist#7215) * Fix goroutine logic bug * Fix bug --------- Co-authored-by: hpy hs * perf(net): sequential and dynamic concurrency * fix(net): incorrect error return * feat(alias): add `DownloadConcurrency` and `DownloadPartSize` option * feat(net): add `ConcurrencyLimit` * pref(net): create `chunk` on demand * refactor * refactor * fix(net): `r.Closers.Add` has no effect * refactor --------- Co-authored-by: hpy hs --- drivers/alias/driver.go | 10 + drivers/alias/meta.go | 6 +- drivers/alias/util.go | 10 +- drivers/crypt/driver.go | 7 +- drivers/github/driver.go | 15 +- drivers/halalcloud/driver.go | 16 +- drivers/mega/driver.go | 4 +- drivers/netease_music/types.go | 1 - drivers/netease_music/upload.go | 2 +- drivers/quqi/util.go | 4 +- internal/bootstrap/config.go | 4 + internal/conf/config.go | 2 + internal/model/args.go | 11 +- internal/net/request.go | 364 +++++++++++++----- internal/net/serve.go | 35 +- internal/net/util.go | 3 +- .../offline_download/transmission/client.go | 4 +- internal/stream/stream.go | 5 +- internal/stream/util.go | 44 +-- server/common/proxy.go | 15 +- server/handles/archive.go | 9 +- server/handles/down.go | 9 +- server/s3/backend.go | 52 +-- server/webdav/webdav.go | 2 +- 24 files changed, 396 insertions(+), 238 deletions(-) diff --git a/drivers/alias/driver.go b/drivers/alias/driver.go index 1b439a2c..16215c8e 100644 --- a/drivers/alias/driver.go +++ b/drivers/alias/driver.go @@ -110,6 +110,16 @@ func (d *Alias) Link(ctx context.Context, file model.Obj, args model.LinkArgs) ( for _, dst := range dsts { link, err := d.link(ctx, dst, sub, args) if err == nil { + if !args.Redirect && len(link.URL) > 0 { + // 正常情况下 多并发 仅支持返回URL的驱动 + // alias套娃alias 可以让crypt、mega等驱动(不返回URL的) 支持并发 + if d.DownloadConcurrency > 0 { + link.Concurrency = d.DownloadConcurrency + } + if d.DownloadPartSize > 0 { + link.PartSize = d.DownloadPartSize * utils.KB + } + } return link, nil } } diff --git a/drivers/alias/meta.go b/drivers/alias/meta.go index 45b88575..ed657a5d 100644 --- a/drivers/alias/meta.go +++ b/drivers/alias/meta.go @@ -9,8 +9,10 @@ type Addition struct { // Usually one of two // driver.RootPath // define other - Paths string `json:"paths" required:"true" type:"text"` - ProtectSameName bool `json:"protect_same_name" default:"true" required:"false" help:"Protects same-name files from Delete or Rename"` + Paths string `json:"paths" required:"true" type:"text"` + ProtectSameName bool `json:"protect_same_name" default:"true" required:"false" help:"Protects same-name files from Delete or Rename"` + DownloadConcurrency int `json:"download_concurrency" default:"0" required:"false" type:"number" help:"Need to enable proxy"` + DownloadPartSize int `json:"download_part_size" default:"0" type:"number" required:"false" help:"Need to enable proxy. Unit: KB"` } var config = driver.Config{ diff --git a/drivers/alias/util.go b/drivers/alias/util.go index c0e9081b..ee17b622 100644 --- a/drivers/alias/util.go +++ b/drivers/alias/util.go @@ -9,6 +9,7 @@ import ( "github.com/alist-org/alist/v3/internal/errs" "github.com/alist-org/alist/v3/internal/fs" "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/internal/op" "github.com/alist-org/alist/v3/internal/sign" "github.com/alist-org/alist/v3/pkg/utils" "github.com/alist-org/alist/v3/server/common" @@ -94,10 +95,15 @@ func (d *Alias) list(ctx context.Context, dst, sub string, args *fs.ListArgs) ([ func (d *Alias) link(ctx context.Context, dst, sub string, args model.LinkArgs) (*model.Link, error) { reqPath := stdpath.Join(dst, sub) - storage, err := fs.GetStorage(reqPath, &fs.GetStoragesArgs{}) + // 参考 crypt 驱动 + storage, reqActualPath, err := op.GetStorageAndActualPath(reqPath) if err != nil { return nil, err } + if _, ok := storage.(*Alias); !ok && !args.Redirect { + link, _, err := op.Link(ctx, storage, reqActualPath, args) + return link, err + } _, err = fs.Get(ctx, reqPath, &fs.GetArgs{NoLog: true}) if err != nil { return nil, err @@ -114,7 +120,7 @@ func (d *Alias) link(ctx context.Context, dst, sub string, args model.LinkArgs) } return link, nil } - link, _, err := fs.Link(ctx, reqPath, args) + link, _, err := op.Link(ctx, storage, reqActualPath, args) return link, err } diff --git a/drivers/crypt/driver.go b/drivers/crypt/driver.go index b6115896..e6f253d1 100644 --- a/drivers/crypt/driver.go +++ b/drivers/crypt/driver.go @@ -275,7 +275,6 @@ func (d *Crypt) Link(ctx context.Context, file model.Obj, args model.LinkArgs) ( rrc = converted } if rrc != nil { - //remoteRangeReader, err := remoteReader, err := rrc.RangeRead(ctx, http_range.Range{Start: underlyingOffset, Length: length}) remoteClosers.AddClosers(rrc.GetClosers()) if err != nil { @@ -288,10 +287,8 @@ func (d *Crypt) Link(ctx context.Context, file model.Obj, args model.LinkArgs) ( if err != nil { return nil, err } - //remoteClosers.Add(remoteLink.MFile) - //keep reuse same MFile and close at last. - remoteClosers.Add(remoteLink.MFile) - return io.NopCloser(remoteLink.MFile), nil + // 可以直接返回,读取完也不会调用Close,直到连接断开Close + return remoteLink.MFile, nil } return nil, errs.NotSupport diff --git a/drivers/github/driver.go b/drivers/github/driver.go index ea8f6276..eed06882 100644 --- a/drivers/github/driver.go +++ b/drivers/github/driver.go @@ -5,6 +5,13 @@ import ( "encoding/base64" "errors" "fmt" + "io" + "net/http" + stdpath "path" + "strings" + "sync" + "text/template" + "github.com/alist-org/alist/v3/drivers/base" "github.com/alist-org/alist/v3/internal/driver" "github.com/alist-org/alist/v3/internal/errs" @@ -12,12 +19,6 @@ import ( "github.com/alist-org/alist/v3/pkg/utils" "github.com/go-resty/resty/v2" log "github.com/sirupsen/logrus" - "io" - "net/http" - stdpath "path" - "strings" - "sync" - "text/template" ) type Github struct { @@ -656,7 +657,7 @@ func (d *Github) putBlob(ctx context.Context, stream model.FileStreamer, up driv contentReader, contentWriter := io.Pipe() go func() { encoder := base64.NewEncoder(base64.StdEncoding, contentWriter) - if _, err := io.Copy(encoder, stream); err != nil { + if _, err := utils.CopyWithBuffer(encoder, stream); err != nil { _ = contentWriter.CloseWithError(err) return } diff --git a/drivers/halalcloud/driver.go b/drivers/halalcloud/driver.go index 08bb3808..d3235828 100644 --- a/drivers/halalcloud/driver.go +++ b/drivers/halalcloud/driver.go @@ -4,12 +4,17 @@ import ( "context" "crypto/sha1" "fmt" + "io" + "net/url" + "path" + "strconv" + "time" + "github.com/alist-org/alist/v3/drivers/base" "github.com/alist-org/alist/v3/internal/driver" "github.com/alist-org/alist/v3/internal/model" "github.com/alist-org/alist/v3/internal/op" "github.com/alist-org/alist/v3/pkg/http_range" - "github.com/alist-org/alist/v3/pkg/utils" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/credentials" "github.com/aws/aws-sdk-go/aws/session" @@ -19,11 +24,6 @@ import ( pubUserFile "github.com/city404/v6-public-rpc-proto/go/v6/userfile" "github.com/rclone/rclone/lib/readers" "github.com/zzzhr1990/go-common-entity/userfile" - "io" - "net/url" - "path" - "strconv" - "time" ) type HalalCloud struct { @@ -251,7 +251,6 @@ func (d *HalalCloud) getLink(ctx context.Context, file model.Obj, args model.Lin size := result.FileSize chunks := getChunkSizes(result.Sizes) - var finalClosers utils.Closers resultRangeReader := func(ctx context.Context, httpRange http_range.Range) (io.ReadCloser, error) { length := httpRange.Length if httpRange.Length >= 0 && httpRange.Start+httpRange.Length >= size { @@ -269,7 +268,6 @@ func (d *HalalCloud) getLink(ctx context.Context, file model.Obj, args model.Lin sha: result.Sha1, shaTemp: sha1.New(), } - finalClosers.Add(oo) return readers.NewLimitedReadCloser(oo, length), nil } @@ -281,7 +279,7 @@ func (d *HalalCloud) getLink(ctx context.Context, file model.Obj, args model.Lin duration = time.Until(time.Now().Add(time.Hour)) } - resultRangeReadCloser := &model.RangeReadCloser{RangeReader: resultRangeReader, Closers: finalClosers} + resultRangeReadCloser := &model.RangeReadCloser{RangeReader: resultRangeReader} return &model.Link{ RangeReadCloser: resultRangeReadCloser, Expiration: &duration, diff --git a/drivers/mega/driver.go b/drivers/mega/driver.go index 162aeef3..198c1f98 100644 --- a/drivers/mega/driver.go +++ b/drivers/mega/driver.go @@ -84,7 +84,6 @@ func (d *Mega) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (* //} size := file.GetSize() - var finalClosers utils.Closers resultRangeReader := func(ctx context.Context, httpRange http_range.Range) (io.ReadCloser, error) { length := httpRange.Length if httpRange.Length >= 0 && httpRange.Start+httpRange.Length >= size { @@ -103,11 +102,10 @@ func (d *Mega) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (* d: down, skip: httpRange.Start, } - finalClosers.Add(oo) return readers.NewLimitedReadCloser(oo, length), nil } - resultRangeReadCloser := &model.RangeReadCloser{RangeReader: resultRangeReader, Closers: finalClosers} + resultRangeReadCloser := &model.RangeReadCloser{RangeReader: resultRangeReader} resultLink := &model.Link{ RangeReadCloser: resultRangeReadCloser, } diff --git a/drivers/netease_music/types.go b/drivers/netease_music/types.go index edbd40ee..0e156ad1 100644 --- a/drivers/netease_music/types.go +++ b/drivers/netease_music/types.go @@ -64,7 +64,6 @@ func (lrc *LyricObj) getLyricLink() *model.Link { sr := io.NewSectionReader(reader, httpRange.Start, httpRange.Length) return io.NopCloser(sr), nil }, - Closers: utils.EmptyClosers(), }, } } diff --git a/drivers/netease_music/upload.go b/drivers/netease_music/upload.go index ece496b3..7f580bd1 100644 --- a/drivers/netease_music/upload.go +++ b/drivers/netease_music/upload.go @@ -47,7 +47,7 @@ func (u *uploader) init(stream model.FileStreamer) error { } h := md5.New() - io.Copy(h, stream) + utils.CopyWithBuffer(h, stream) u.md5 = hex.EncodeToString(h.Sum(nil)) _, err := u.file.Seek(0, io.SeekStart) if err != nil { diff --git a/drivers/quqi/util.go b/drivers/quqi/util.go index c025f6ee..c57e641b 100644 --- a/drivers/quqi/util.go +++ b/drivers/quqi/util.go @@ -300,9 +300,7 @@ func (d *Quqi) linkFromCDN(id string) (*model.Link, error) { bufferReader := bufio.NewReader(decryptReader) bufferReader.Discard(int(decryptedOffset)) - return utils.NewReadCloser(bufferReader, func() error { - return nil - }), nil + return io.NopCloser(bufferReader), nil } return &model.Link{ diff --git a/internal/bootstrap/config.go b/internal/bootstrap/config.go index 38b1aa9e..db3e2094 100644 --- a/internal/bootstrap/config.go +++ b/internal/bootstrap/config.go @@ -9,6 +9,7 @@ import ( "github.com/alist-org/alist/v3/cmd/flags" "github.com/alist-org/alist/v3/drivers/base" "github.com/alist-org/alist/v3/internal/conf" + "github.com/alist-org/alist/v3/internal/net" "github.com/alist-org/alist/v3/pkg/utils" "github.com/caarlos0/env/v9" log "github.com/sirupsen/logrus" @@ -63,6 +64,9 @@ func InitConfig() { log.Fatalf("update config struct error: %+v", err) } } + if conf.Conf.MaxConcurrency > 0 { + net.DefaultConcurrencyLimit = &net.ConcurrencyLimit{Limit: conf.Conf.MaxConcurrency} + } if !conf.Conf.Force { confFromEnv() } diff --git a/internal/conf/config.go b/internal/conf/config.go index 4f5c2ae0..39b23227 100644 --- a/internal/conf/config.go +++ b/internal/conf/config.go @@ -106,6 +106,7 @@ type Config struct { Log LogConfig `json:"log"` DelayedStart int `json:"delayed_start" env:"DELAYED_START"` MaxConnections int `json:"max_connections" env:"MAX_CONNECTIONS"` + MaxConcurrency int `json:"max_concurrency" env:"MAX_CONCURRENCY"` TlsInsecureSkipVerify bool `json:"tls_insecure_skip_verify" env:"TLS_INSECURE_SKIP_VERIFY"` Tasks TasksConfig `json:"tasks" envPrefix:"TASKS_"` Cors Cors `json:"cors" envPrefix:"CORS_"` @@ -151,6 +152,7 @@ func DefaultConfig() *Config { MaxAge: 28, }, MaxConnections: 0, + MaxConcurrency: 64, TlsInsecureSkipVerify: true, Tasks: TasksConfig{ Download: TaskConfig{ diff --git a/internal/model/args.go b/internal/model/args.go index a9feeb20..f29c7e45 100644 --- a/internal/model/args.go +++ b/internal/model/args.go @@ -17,10 +17,11 @@ type ListArgs struct { } type LinkArgs struct { - IP string - Header http.Header - Type string - HttpReq *http.Request + IP string + Header http.Header + Type string + HttpReq *http.Request + Redirect bool } type Link struct { @@ -87,7 +88,7 @@ type RangeReadCloser struct { utils.Closers } -func (r RangeReadCloser) RangeRead(ctx context.Context, httpRange http_range.Range) (io.ReadCloser, error) { +func (r *RangeReadCloser) RangeRead(ctx context.Context, httpRange http_range.Range) (io.ReadCloser, error) { rc, err := r.RangeReader(ctx, httpRange) r.Closers.Add(rc) return rc, err diff --git a/internal/net/request.go b/internal/net/request.go index 1a7405e4..d2f3028f 100644 --- a/internal/net/request.go +++ b/internal/net/request.go @@ -5,7 +5,6 @@ import ( "context" "fmt" "io" - "math" "net/http" "strconv" "strings" @@ -21,7 +20,7 @@ import ( // DefaultDownloadPartSize is the default range of bytes to get at a time when // using Download(). -const DefaultDownloadPartSize = 1024 * 1024 * 10 +const DefaultDownloadPartSize = utils.MB * 10 // DefaultDownloadConcurrency is the default number of goroutines to spin up // when using Download(). @@ -30,6 +29,8 @@ const DefaultDownloadConcurrency = 2 // DefaultPartBodyMaxRetries is the default number of retries to make when a part fails to download. const DefaultPartBodyMaxRetries = 3 +var DefaultConcurrencyLimit *ConcurrencyLimit + type Downloader struct { PartSize int @@ -44,15 +45,15 @@ type Downloader struct { //RequestParam HttpRequestParams HttpClient HttpRequestFunc + + *ConcurrencyLimit } type HttpRequestFunc func(ctx context.Context, params *HttpRequestParams) (*http.Response, error) func NewDownloader(options ...func(*Downloader)) *Downloader { - d := &Downloader{ - HttpClient: DefaultHttpRequestFunc, - PartSize: DefaultDownloadPartSize, + d := &Downloader{ //允许不设置的选项 PartBodyMaxRetries: DefaultPartBodyMaxRetries, - Concurrency: DefaultDownloadConcurrency, + ConcurrencyLimit: DefaultConcurrencyLimit, } for _, option := range options { option(d) @@ -74,16 +75,16 @@ func (d Downloader) Download(ctx context.Context, p *HttpRequestParams) (readClo impl := downloader{params: &finalP, cfg: d, ctx: ctx} // Ensures we don't need nil checks later on - - impl.partBodyMaxRetries = d.PartBodyMaxRetries - + // 必需的选项 if impl.cfg.Concurrency == 0 { impl.cfg.Concurrency = DefaultDownloadConcurrency } - if impl.cfg.PartSize == 0 { impl.cfg.PartSize = DefaultDownloadPartSize } + if impl.cfg.HttpClient == nil { + impl.cfg.HttpClient = DefaultHttpRequestFunc + } return impl.download() } @@ -91,7 +92,7 @@ func (d Downloader) Download(ctx context.Context, p *HttpRequestParams) (readClo // downloader is the implementation structure used internally by Downloader. type downloader struct { ctx context.Context - cancel context.CancelFunc + cancel context.CancelCauseFunc cfg Downloader params *HttpRequestParams //http request params @@ -101,38 +102,78 @@ type downloader struct { m sync.Mutex nextChunk int //next chunk id - chunks []chunk bufs []*Buf - //totalBytes int64 - written int64 //total bytes of file downloaded from remote - err error + written int64 //total bytes of file downloaded from remote + err error - partBodyMaxRetries int + concurrency int //剩余的并发数,递减。到0时停止并发 + maxPart int //有多少个分片 + pos int64 + maxPos int64 + m2 sync.Mutex + readingID int // 正在被读取的id +} + +type ConcurrencyLimit struct { + _m sync.Mutex + Limit int // 需要大于0 +} + +var ErrExceedMaxConcurrency = fmt.Errorf("ExceedMaxConcurrency") + +func (l *ConcurrencyLimit) sub() error { + l._m.Lock() + defer l._m.Unlock() + if l.Limit-1 < 0 { + return ErrExceedMaxConcurrency + } + l.Limit-- + // log.Debugf("ConcurrencyLimit.sub: %d", l.Limit) + return nil +} +func (l *ConcurrencyLimit) add() { + l._m.Lock() + defer l._m.Unlock() + l.Limit++ + // log.Debugf("ConcurrencyLimit.add: %d", l.Limit) +} + +// 检测是否超过限制 +func (d *downloader) concurrencyCheck() error { + if d.cfg.ConcurrencyLimit != nil { + return d.cfg.ConcurrencyLimit.sub() + } + return nil +} +func (d *downloader) concurrencyFinish() { + if d.cfg.ConcurrencyLimit != nil { + d.cfg.ConcurrencyLimit.add() + } } // download performs the implementation of the object download across ranged GETs. func (d *downloader) download() (io.ReadCloser, error) { - d.ctx, d.cancel = context.WithCancel(d.ctx) + if err := d.concurrencyCheck(); err != nil { + return nil, err + } + d.ctx, d.cancel = context.WithCancelCause(d.ctx) - pos := d.params.Range.Start - maxPos := d.params.Range.Start + d.params.Range.Length - id := 0 - for pos < maxPos { - finalSize := int64(d.cfg.PartSize) - //check boundary - if pos+finalSize > maxPos { - finalSize = maxPos - pos - } - c := chunk{start: pos, size: finalSize, id: id} - d.chunks = append(d.chunks, c) - pos += finalSize - id++ + maxPart := int(d.params.Range.Length / int64(d.cfg.PartSize)) + if d.params.Range.Length%int64(d.cfg.PartSize) > 0 { + maxPart++ } - if len(d.chunks) < d.cfg.Concurrency { - d.cfg.Concurrency = len(d.chunks) + if maxPart < d.cfg.Concurrency { + d.cfg.Concurrency = maxPart } + log.Debugf("cfgConcurrency:%d", d.cfg.Concurrency) if d.cfg.Concurrency == 1 { + if d.cfg.ConcurrencyLimit != nil { + go func() { + <-d.ctx.Done() + d.concurrencyFinish() + }() + } resp, err := d.cfg.HttpClient(d.ctx, d.params) if err != nil { return nil, err @@ -143,61 +184,114 @@ func (d *downloader) download() (io.ReadCloser, error) { // workers d.chunkChannel = make(chan chunk, d.cfg.Concurrency) - for i := 0; i < d.cfg.Concurrency; i++ { - buf := NewBuf(d.ctx, d.cfg.PartSize, i) - d.bufs = append(d.bufs, buf) - go d.downloadPart() - } - // initial tasks - for i := 0; i < d.cfg.Concurrency; i++ { - d.sendChunkTask() - } + d.maxPart = maxPart + d.pos = d.params.Range.Start + d.maxPos = d.params.Range.Start + d.params.Range.Length + d.concurrency = d.cfg.Concurrency + d.sendChunkTask(true) - var rc io.ReadCloser = NewMultiReadCloser(d.chunks[0].buf, d.interrupt, d.finishBuf) + var rc io.ReadCloser = NewMultiReadCloser(d.bufs[0], d.interrupt, d.finishBuf) // Return error return rc, d.err } -func (d *downloader) sendChunkTask() *chunk { - ch := &d.chunks[d.nextChunk] - ch.buf = d.getBuf(d.nextChunk) - ch.buf.Reset(int(ch.size)) - d.chunkChannel <- *ch - d.nextChunk++ - return ch + +func (d *downloader) sendChunkTask(newConcurrency bool) error { + d.m.Lock() + defer d.m.Unlock() + isNewBuf := d.concurrency > 0 + if newConcurrency { + if d.concurrency <= 0 { + return nil + } + if d.nextChunk > 0 { // 第一个不检查,因为已经检查过了 + if err := d.concurrencyCheck(); err != nil { + return err + } + } + d.concurrency-- + go d.downloadPart() + } + + var buf *Buf + if isNewBuf { + buf = NewBuf(d.ctx, d.cfg.PartSize) + d.bufs = append(d.bufs, buf) + } else { + buf = d.getBuf(d.nextChunk) + } + + if d.pos < d.maxPos { + finalSize := int64(d.cfg.PartSize) + switch d.nextChunk { + case 0: + // 最小分片在前面有助视频播放? + firstSize := d.params.Range.Length % finalSize + if firstSize > 0 { + minSize := finalSize / 2 + if firstSize < minSize { // 最小分片太小就调整到一半 + finalSize = minSize + } else { + finalSize = firstSize + } + } + case 1: + firstSize := d.params.Range.Length % finalSize + minSize := finalSize / 2 + if firstSize > 0 && firstSize < minSize { + finalSize += firstSize - minSize + } + } + buf.Reset(int(finalSize)) + ch := chunk{ + start: d.pos, + size: finalSize, + id: d.nextChunk, + buf: buf, + } + ch.newConcurrency = newConcurrency + d.pos += finalSize + d.nextChunk++ + d.chunkChannel <- ch + return nil + } + return nil } // when the final reader Close, we interrupt func (d *downloader) interrupt() error { - - d.cancel() if d.written != d.params.Range.Length { log.Debugf("Downloader interrupt before finish") if d.getErr() == nil { d.setErr(fmt.Errorf("interrupted")) } } + d.cancel(d.err) defer func() { close(d.chunkChannel) for _, buf := range d.bufs { buf.Close() } + if d.concurrency > 0 { + d.concurrency = -d.concurrency + } + log.Debugf("maxConcurrency:%d", d.cfg.Concurrency+d.concurrency) }() return d.err } func (d *downloader) getBuf(id int) (b *Buf) { - - return d.bufs[id%d.cfg.Concurrency] + return d.bufs[id%len(d.bufs)] } -func (d *downloader) finishBuf(id int) (isLast bool, buf *Buf) { - if id >= len(d.chunks)-1 { +func (d *downloader) finishBuf(id int) (isLast bool, nextBuf *Buf) { + id++ + if id >= d.maxPart { return true, nil } - if d.nextChunk > id+1 { - return false, d.getBuf(id + 1) - } - ch := d.sendChunkTask() - return false, ch.buf + + d.sendChunkTask(false) + + d.readingID = id + return false, d.getBuf(id) } // downloadPart is an individual goroutine worker reading from the ch channel @@ -212,58 +306,119 @@ func (d *downloader) downloadPart() { if d.getErr() != nil { // Drain the channel if there is an error, to prevent deadlocking // of download producer. - continue + break } - log.Debugf("downloadPart tried to get chunk") if err := d.downloadChunk(&c); err != nil { + if err == errCancelConcurrency { + break + } + if err == context.Canceled { + if e := context.Cause(d.ctx); e != nil { + err = e + } + } d.setErr(err) + d.cancel(err) } } + d.concurrencyFinish() } // downloadChunk downloads the chunk func (d *downloader) downloadChunk(ch *chunk) error { - log.Debugf("start new chunk %+v buffer_id =%d", ch, ch.id) + log.Debugf("start chunk_%d, %+v", ch.id, ch) + params := d.getParamsFromChunk(ch) var n int64 var err error - params := d.getParamsFromChunk(ch) - for retry := 0; retry <= d.partBodyMaxRetries; retry++ { + for retry := 0; retry <= d.cfg.PartBodyMaxRetries; retry++ { if d.getErr() != nil { - return d.getErr() + return nil } n, err = d.tryDownloadChunk(params, ch) if err == nil { + d.incrWritten(n) + log.Debugf("chunk_%d downloaded", ch.id) break } - // Check if the returned error is an errReadingBody. - // If err is errReadingBody this indicates that an error - // occurred while copying the http response body. + if d.getErr() != nil { + return nil + } + if utils.IsCanceled(d.ctx) { + return d.ctx.Err() + } + // Check if the returned error is an errNeedRetry. // If this occurs we unwrap the err to set the underlying error // and attempt any remaining retries. - if bodyErr, ok := err.(*errReadingBody); ok { - err = bodyErr.Unwrap() + if e, ok := err.(*errNeedRetry); ok { + err = e.Unwrap() + if n > 0 { + // 测试:下载时 断开 alist向云盘发起的下载连接 + // 校验:下载完后校验文件哈希值 一致 + d.incrWritten(n) + ch.start += n + ch.size -= n + params.Range.Start = ch.start + params.Range.Length = ch.size + } + log.Warnf("err chunk_%d, object part download error %s, retrying attempt %d. %v", + ch.id, params.URL, retry, err) + } else if err == errInfiniteRetry { + retry-- + continue } else { - return err + break } - - //ch.cur = 0 - - log.Debugf("object part body download interrupted %s, err, %v, retrying attempt %d", - params.URL, err, retry) } - d.incrWritten(n) - log.Debugf("down_%d downloaded chunk", ch.id) - //ch.buf.buffer.wg1.Wait() - //log.Debugf("down_%d downloaded chunk,wg wait passed", ch.id) return err } -func (d *downloader) tryDownloadChunk(params *HttpRequestParams, ch *chunk) (int64, error) { +var errCancelConcurrency = fmt.Errorf("cancel concurrency") +var errInfiniteRetry = fmt.Errorf("infinite retry") +func (d *downloader) tryDownloadChunk(params *HttpRequestParams, ch *chunk) (int64, error) { resp, err := d.cfg.HttpClient(d.ctx, params) if err != nil { - return 0, err + if resp == nil { + return 0, err + } + if ch.id == 0 { //第1个任务 有限的重试,超过重试就会结束请求 + switch resp.StatusCode { + default: + return 0, err + case http.StatusTooManyRequests: + case http.StatusBadGateway: + case http.StatusServiceUnavailable: + case http.StatusGatewayTimeout: + } + <-time.After(time.Millisecond * 200) + return 0, &errNeedRetry{err: fmt.Errorf("http request failure,status: %d", resp.StatusCode)} + } + + // 来到这 说明第1个分片下载 连接成功了 + // 后续分片下载出错都当超载处理 + log.Debugf("err chunk_%d, try downloading:%v", ch.id, err) + + d.m.Lock() + isCancelConcurrency := ch.newConcurrency + if d.concurrency > 0 { // 取消剩余的并发任务 + // 用于计算实际的并发数 + d.concurrency = -d.concurrency + isCancelConcurrency = true + } + if isCancelConcurrency { + d.concurrency-- + d.chunkChannel <- *ch + d.m.Unlock() + return 0, errCancelConcurrency + } + d.m.Unlock() + if ch.id != d.readingID { //正在被读取的优先重试 + d.m2.Lock() + defer d.m2.Unlock() + <-time.After(time.Millisecond * 200) + } + return 0, errInfiniteRetry } defer resp.Body.Close() //only check file size on the first task @@ -273,15 +428,15 @@ func (d *downloader) tryDownloadChunk(params *HttpRequestParams, ch *chunk) (int return 0, err } } - + d.sendChunkTask(true) n, err := utils.CopyWithBuffer(ch.buf, resp.Body) if err != nil { - return n, &errReadingBody{err: err} + return n, &errNeedRetry{err: err} } if n != ch.size { err = fmt.Errorf("chunk download size incorrect, expected=%d, got=%d", ch.size, n) - return n, &errReadingBody{err: err} + return n, &errNeedRetry{err: err} } return n, nil @@ -297,7 +452,7 @@ func (d *downloader) getParamsFromChunk(ch *chunk) *HttpRequestParams { func (d *downloader) checkTotalBytes(resp *http.Response) error { var err error - var totalBytes int64 = math.MinInt64 + totalBytes := int64(-1) contentRange := resp.Header.Get("Content-Range") if len(contentRange) == 0 { // ContentRange is nil when the full file contents is provided, and @@ -329,8 +484,9 @@ func (d *downloader) checkTotalBytes(resp *http.Response) error { err = fmt.Errorf("expect file size=%d unmatch remote report size=%d, need refresh cache", d.params.Size, totalBytes) } if err != nil { - _ = d.interrupt() + // _ = d.interrupt() d.setErr(err) + d.cancel(err) } return err @@ -369,9 +525,7 @@ type chunk struct { buf *Buf id int - // Downloader takes range (start,length), but this chunk is requesting equal/sub range of it. - // To convert the writer to reader eventually, we need to write within the boundary - //boundary http_range.Range + newConcurrency bool } func DefaultHttpRequestFunc(ctx context.Context, params *HttpRequestParams) (*http.Response, error) { @@ -379,7 +533,7 @@ func DefaultHttpRequestFunc(ctx context.Context, params *HttpRequestParams) (*ht res, err := RequestHttp(ctx, "GET", header, params.URL) if err != nil { - return nil, err + return res, err } return res, nil } @@ -392,15 +546,15 @@ type HttpRequestParams struct { //total file size Size int64 } -type errReadingBody struct { +type errNeedRetry struct { err error } -func (e *errReadingBody) Error() string { - return fmt.Sprintf("failed to read part body: %v", e.err) +func (e *errNeedRetry) Error() string { + return e.err.Error() } -func (e *errReadingBody) Unwrap() error { +func (e *errNeedRetry) Unwrap() error { return e.err } @@ -438,9 +592,13 @@ func (mr MultiReadCloser) Read(p []byte) (n int, err error) { } mr.cfg.curBuf = next mr.cfg.rPos++ - //current.Close() return n, nil } + if err == context.Canceled { + if e := context.Cause(mr.cfg.curBuf.ctx); e != nil { + err = e + } + } return n, err } func (mr MultiReadCloser) Close() error { @@ -453,18 +611,16 @@ type Buf struct { ctx context.Context off int rw sync.Mutex - //notify chan struct{} } // NewBuf is a buffer that can have 1 read & 1 write at the same time. // when read is faster write, immediately feed data to read after written -func NewBuf(ctx context.Context, maxSize int, id int) *Buf { +func NewBuf(ctx context.Context, maxSize int) *Buf { d := make([]byte, 0, maxSize) return &Buf{ ctx: ctx, buffer: bytes.NewBuffer(d), size: maxSize, - //notify: make(chan struct{}), } } func (br *Buf) Reset(size int) { @@ -502,8 +658,6 @@ func (br *Buf) Read(p []byte) (n int, err error) { select { case <-br.ctx.Done(): return 0, br.ctx.Err() - //case <-br.notify: - // return 0, nil case <-time.After(time.Millisecond * 200): return 0, nil } @@ -516,13 +670,9 @@ func (br *Buf) Write(p []byte) (n int, err error) { br.rw.Lock() defer br.rw.Unlock() n, err = br.buffer.Write(p) - select { - //case br.notify <- struct{}{}: - default: - } return } func (br *Buf) Close() { - //close(br.notify) + br.buffer.Reset() } diff --git a/internal/net/serve.go b/internal/net/serve.go index e85f61a8..6216cd21 100644 --- a/internal/net/serve.go +++ b/internal/net/serve.go @@ -52,7 +52,8 @@ import ( // // If the caller has set w's ETag header formatted per RFC 7232, section 2.3, // ServeHTTP uses it to handle requests using If-Match, If-None-Match, or If-Range. -func ServeHTTP(w http.ResponseWriter, r *http.Request, name string, modTime time.Time, size int64, RangeReaderFunc model.RangeReaderFunc) { +func ServeHTTP(w http.ResponseWriter, r *http.Request, name string, modTime time.Time, size int64, RangeReadCloser model.RangeReadCloserIF) { + defer RangeReadCloser.Close() setLastModified(w, modTime) done, rangeReq := checkPreconditions(w, r, modTime) if done { @@ -110,11 +111,19 @@ func ServeHTTP(w http.ResponseWriter, r *http.Request, name string, modTime time // or unknown file size, ignore the range request. ranges = nil } + + // 使用请求的Context + // 不然从sendContent读不到数据,即使请求断开CopyBuffer也会一直堵塞 + ctx := r.Context() switch { case len(ranges) == 0: - reader, err := RangeReaderFunc(context.Background(), http_range.Range{Length: -1}) + reader, err := RangeReadCloser.RangeRead(ctx, http_range.Range{Length: -1}) if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) + code = http.StatusRequestedRangeNotSatisfiable + if err == ErrExceedMaxConcurrency { + code = http.StatusTooManyRequests + } + http.Error(w, err.Error(), code) return } sendContent = reader @@ -131,9 +140,13 @@ func ServeHTTP(w http.ResponseWriter, r *http.Request, name string, modTime time // does not request multiple parts might not support // multipart responses." ra := ranges[0] - sendContent, err = RangeReaderFunc(context.Background(), ra) + sendContent, err = RangeReadCloser.RangeRead(ctx, ra) if err != nil { - http.Error(w, err.Error(), http.StatusRequestedRangeNotSatisfiable) + code = http.StatusRequestedRangeNotSatisfiable + if err == ErrExceedMaxConcurrency { + code = http.StatusTooManyRequests + } + http.Error(w, err.Error(), code) return } sendSize = ra.Length @@ -158,7 +171,7 @@ func ServeHTTP(w http.ResponseWriter, r *http.Request, name string, modTime time pw.CloseWithError(err) return } - reader, err := RangeReaderFunc(context.Background(), ra) + reader, err := RangeReadCloser.RangeRead(ctx, ra) if err != nil { pw.CloseWithError(err) return @@ -167,14 +180,12 @@ func ServeHTTP(w http.ResponseWriter, r *http.Request, name string, modTime time pw.CloseWithError(err) return } - //defer reader.Close() } mw.Close() pw.Close() }() } - //defer sendContent.Close() w.Header().Set("Accept-Ranges", "bytes") if w.Header().Get("Content-Encoding") == "" { @@ -190,7 +201,11 @@ func ServeHTTP(w http.ResponseWriter, r *http.Request, name string, modTime time if written != sendSize { log.Warnf("Maybe size incorrect or reader not giving correct/full data, or connection closed before finish. written bytes: %d ,sendSize:%d, ", written, sendSize) } - http.Error(w, err.Error(), http.StatusInternalServerError) + code = http.StatusInternalServerError + if err == ErrExceedMaxConcurrency { + code = http.StatusTooManyRequests + } + http.Error(w, err.Error(), code) } } } @@ -239,7 +254,7 @@ func RequestHttp(ctx context.Context, httpMethod string, headerOverride http.Hea _ = res.Body.Close() msg := string(all) log.Debugln(msg) - return nil, fmt.Errorf("http request [%s] failure,status: %d response:%s", URL, res.StatusCode, msg) + return res, fmt.Errorf("http request [%s] failure,status: %d response:%s", URL, res.StatusCode, msg) } return res, nil } diff --git a/internal/net/util.go b/internal/net/util.go index 44201859..45301dde 100644 --- a/internal/net/util.go +++ b/internal/net/util.go @@ -2,7 +2,6 @@ package net import ( "fmt" - "github.com/alist-org/alist/v3/pkg/utils" "io" "math" "mime/multipart" @@ -11,6 +10,8 @@ import ( "strings" "time" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/alist-org/alist/v3/pkg/http_range" log "github.com/sirupsen/logrus" ) diff --git a/internal/offline_download/transmission/client.go b/internal/offline_download/transmission/client.go index 8049afd6..ae136009 100644 --- a/internal/offline_download/transmission/client.go +++ b/internal/offline_download/transmission/client.go @@ -5,7 +5,6 @@ import ( "context" "encoding/base64" "fmt" - "io" "net/http" "net/url" "strconv" @@ -15,6 +14,7 @@ import ( "github.com/alist-org/alist/v3/internal/model" "github.com/alist-org/alist/v3/internal/offline_download/tool" "github.com/alist-org/alist/v3/internal/setting" + "github.com/alist-org/alist/v3/pkg/utils" "github.com/hekmon/transmissionrpc/v3" "github.com/pkg/errors" log "github.com/sirupsen/logrus" @@ -92,7 +92,7 @@ func (t *Transmission) AddURL(args *tool.AddUrlArgs) (string, error) { buffer := new(bytes.Buffer) encoder := base64.NewEncoder(base64.StdEncoding, buffer) // Stream file to the encoder - if _, err = io.Copy(encoder, resp.Body); err != nil { + if _, err = utils.CopyWithBuffer(encoder, resp.Body); err != nil { return "", errors.Wrap(err, "can't copy file content into the base64 encoder") } // Flush last bytes diff --git a/internal/stream/stream.go b/internal/stream/stream.go index b19eb077..0915ee6b 100644 --- a/internal/stream/stream.go +++ b/internal/stream/stream.go @@ -122,7 +122,8 @@ const InMemoryBufMaxSizeBytes = InMemoryBufMaxSize * 1024 * 1024 // also support a peeking RangeRead at very start, but won't buffer more than 10MB data in memory func (f *FileStream) RangeRead(httpRange http_range.Range) (io.Reader, error) { if httpRange.Length == -1 { - httpRange.Length = f.GetSize() + // 参考 internal/net/request.go + httpRange.Length = f.GetSize() - httpRange.Start } if f.peekBuff != nil && httpRange.Start < int64(f.peekBuff.Len()) && httpRange.Start+httpRange.Length-1 < int64(f.peekBuff.Len()) { return io.NewSectionReader(f.peekBuff, httpRange.Start, httpRange.Length), nil @@ -210,7 +211,7 @@ func NewSeekableStream(fs FileStream, link *model.Link) (*SeekableStream, error) // RangeRead is not thread-safe, pls use it in single thread only. func (ss *SeekableStream) RangeRead(httpRange http_range.Range) (io.Reader, error) { if httpRange.Length == -1 { - httpRange.Length = ss.GetSize() + httpRange.Length = ss.GetSize() - httpRange.Start } if ss.mFile != nil { return io.NewSectionReader(ss.mFile, httpRange.Start, httpRange.Length), nil diff --git a/internal/stream/util.go b/internal/stream/util.go index 7d2b7ef7..16854c38 100644 --- a/internal/stream/util.go +++ b/internal/stream/util.go @@ -6,7 +6,6 @@ import ( "io" "net/http" - "github.com/alist-org/alist/v3/internal/errs" "github.com/alist-org/alist/v3/internal/model" "github.com/alist-org/alist/v3/internal/net" "github.com/alist-org/alist/v3/pkg/http_range" @@ -17,7 +16,6 @@ func GetRangeReadCloserFromLink(size int64, link *model.Link) (model.RangeReadCl if len(link.URL) == 0 { return nil, fmt.Errorf("can't create RangeReadCloser since URL is empty in link") } - //remoteClosers := utils.EmptyClosers() rangeReaderFunc := func(ctx context.Context, r http_range.Range) (io.ReadCloser, error) { if link.Concurrency != 0 || link.PartSize != 0 { header := net.ProcessHeader(http.Header{}, link.Header) @@ -32,37 +30,29 @@ func GetRangeReadCloserFromLink(size int64, link *model.Link) (model.RangeReadCl HeaderRef: header, } rc, err := down.Download(ctx, req) - if err != nil { - return nil, errs.NewErr(err, "GetReadCloserFromLink failed") - } - return rc, nil + return rc, err } - if len(link.URL) > 0 { - response, err := RequestRangedHttp(ctx, link, r.Start, r.Length) - if err != nil { - if response == nil { - return nil, fmt.Errorf("http request failure, err:%s", err) - } - return nil, fmt.Errorf("http request failure,status: %d err:%s", response.StatusCode, err) + response, err := RequestRangedHttp(ctx, link, r.Start, r.Length) + if err != nil { + if response == nil { + return nil, fmt.Errorf("http request failure, err:%s", err) } - if r.Start == 0 && (r.Length == -1 || r.Length == size) || response.StatusCode == http.StatusPartialContent || - checkContentRange(&response.Header, r.Start) { - return response.Body, nil - } else if response.StatusCode == http.StatusOK { - log.Warnf("remote http server not supporting range request, expect low perfromace!") - readCloser, err := net.GetRangedHttpReader(response.Body, r.Start, r.Length) - if err != nil { - return nil, err - } - return readCloser, nil - - } - + return nil, err + } + if r.Start == 0 && (r.Length == -1 || r.Length == size) || response.StatusCode == http.StatusPartialContent || + checkContentRange(&response.Header, r.Start) { return response.Body, nil + } else if response.StatusCode == http.StatusOK { + log.Warnf("remote http server not supporting range request, expect low perfromace!") + readCloser, err := net.GetRangedHttpReader(response.Body, r.Start, r.Length) + if err != nil { + return nil, err + } + return readCloser, nil } - return nil, errs.NotSupport + return response.Body, nil } resultRangeReadCloser := model.RangeReadCloser{RangeReader: rangeReaderFunc} return &resultRangeReadCloser, nil diff --git a/server/common/proxy.go b/server/common/proxy.go index 10923613..2d828efd 100644 --- a/server/common/proxy.go +++ b/server/common/proxy.go @@ -27,16 +27,11 @@ func Proxy(w http.ResponseWriter, r *http.Request, link *model.Link, file model. return nil } else if link.RangeReadCloser != nil { attachFileName(w, file) - net.ServeHTTP(w, r, file.GetName(), file.ModTime(), file.GetSize(), link.RangeReadCloser.RangeRead) - defer func() { - _ = link.RangeReadCloser.Close() - }() + net.ServeHTTP(w, r, file.GetName(), file.ModTime(), file.GetSize(), link.RangeReadCloser) return nil } else if link.Concurrency != 0 || link.PartSize != 0 { attachFileName(w, file) size := file.GetSize() - //var finalClosers model.Closers - finalClosers := utils.EmptyClosers() header := net.ProcessHeader(r.Header, link.Header) rangeReader := func(ctx context.Context, httpRange http_range.Range) (io.ReadCloser, error) { down := net.NewDownloader(func(d *net.Downloader) { @@ -50,16 +45,14 @@ func Proxy(w http.ResponseWriter, r *http.Request, link *model.Link, file model. HeaderRef: header, } rc, err := down.Download(ctx, req) - finalClosers.Add(rc) return rc, err } - net.ServeHTTP(w, r, file.GetName(), file.ModTime(), file.GetSize(), rangeReader) - defer finalClosers.Close() + net.ServeHTTP(w, r, file.GetName(), file.ModTime(), file.GetSize(), &model.RangeReadCloser{RangeReader: rangeReader}) return nil } else { //transparent proxy header := net.ProcessHeader(r.Header, link.Header) - res, err := net.RequestHttp(context.Background(), r.Method, header, link.URL) + res, err := net.RequestHttp(r.Context(), r.Method, header, link.URL) if err != nil { return err } @@ -72,7 +65,7 @@ func Proxy(w http.ResponseWriter, r *http.Request, link *model.Link, file model. if r.Method == http.MethodHead { return nil } - _, err = io.Copy(w, res.Body) + _, err = utils.CopyWithBuffer(w, res.Body) if err != nil { return err } diff --git a/server/handles/archive.go b/server/handles/archive.go index 29dbf3c2..bad99bac 100644 --- a/server/handles/archive.go +++ b/server/handles/archive.go @@ -281,10 +281,11 @@ func ArchiveDown(c *gin.Context) { link, _, err := fs.ArchiveDriverExtract(c, archiveRawPath, model.ArchiveInnerArgs{ ArchiveArgs: model.ArchiveArgs{ LinkArgs: model.LinkArgs{ - IP: c.ClientIP(), - Header: c.Request.Header, - Type: c.Query("type"), - HttpReq: c.Request, + IP: c.ClientIP(), + Header: c.Request.Header, + Type: c.Query("type"), + HttpReq: c.Request, + Redirect: true, }, Password: password, }, diff --git a/server/handles/down.go b/server/handles/down.go index f01c9d66..b2f9a21b 100644 --- a/server/handles/down.go +++ b/server/handles/down.go @@ -31,10 +31,11 @@ func Down(c *gin.Context) { return } else { link, _, err := fs.Link(c, rawPath, model.LinkArgs{ - IP: c.ClientIP(), - Header: c.Request.Header, - Type: c.Query("type"), - HttpReq: c.Request, + IP: c.ClientIP(), + Header: c.Request.Header, + Type: c.Query("type"), + HttpReq: c.Request, + Redirect: true, }) if err != nil { common.ErrorResp(c, err, 500) diff --git a/server/s3/backend.go b/server/s3/backend.go index e0cfd967..bca45008 100644 --- a/server/s3/backend.go +++ b/server/s3/backend.go @@ -6,13 +6,14 @@ import ( "context" "encoding/hex" "fmt" - "github.com/pkg/errors" "io" "path" "strings" "sync" "time" + "github.com/pkg/errors" + "github.com/alist-org/alist/v3/internal/errs" "github.com/alist-org/alist/v3/internal/fs" "github.com/alist-org/alist/v3/internal/model" @@ -173,15 +174,27 @@ func (b *s3Backend) GetObject(ctx context.Context, bucketName, objectName string if link.RangeReadCloser == nil && link.MFile == nil && len(link.URL) == 0 { return nil, fmt.Errorf("the remote storage driver need to be enhanced to support s3") } - remoteFileSize := file.GetSize() - remoteClosers := utils.EmptyClosers() - rangeReaderFunc := func(ctx context.Context, start, length int64) (io.ReadCloser, error) { + + var rdr io.ReadCloser + length := int64(-1) + start := int64(0) + if rnge != nil { + start, length = rnge.Start, rnge.Length + } + // 参考 server/common/proxy.go + if link.MFile != nil { + _, err := link.MFile.Seek(start, io.SeekStart) + if err != nil { + return nil, err + } + rdr = link.MFile + } else { + remoteFileSize := file.GetSize() if length >= 0 && start+length >= remoteFileSize { length = -1 } rrc := link.RangeReadCloser if len(link.URL) > 0 { - rangedRemoteLink := &model.Link{ URL: link.URL, Header: link.Header, @@ -194,35 +207,12 @@ func (b *s3Backend) GetObject(ctx context.Context, bucketName, objectName string } if rrc != nil { remoteReader, err := rrc.RangeRead(ctx, http_range.Range{Start: start, Length: length}) - remoteClosers.AddClosers(rrc.GetClosers()) if err != nil { return nil, err } - return remoteReader, nil - } - if link.MFile != nil { - _, err := link.MFile.Seek(start, io.SeekStart) - if err != nil { - return nil, err - } - //remoteClosers.Add(remoteLink.MFile) - //keep reuse same MFile and close at last. - remoteClosers.Add(link.MFile) - return io.NopCloser(link.MFile), nil - } - return nil, errs.NotSupport - } - - var rdr io.ReadCloser - if rnge != nil { - rdr, err = rangeReaderFunc(ctx, rnge.Start, rnge.Length) - if err != nil { - return nil, err - } - } else { - rdr, err = rangeReaderFunc(ctx, 0, -1) - if err != nil { - return nil, err + rdr = utils.ReadCloser{Reader: remoteReader, Closer: rrc} + } else { + return nil, errs.NotSupport } } diff --git a/server/webdav/webdav.go b/server/webdav/webdav.go index b84e65b0..6585056b 100644 --- a/server/webdav/webdav.go +++ b/server/webdav/webdav.go @@ -263,7 +263,7 @@ func (h *Handler) handleGetHeadPost(w http.ResponseWriter, r *http.Request) (sta w.Header().Set("Cache-Control", "max-age=0, no-cache, no-store, must-revalidate") http.Redirect(w, r, u, http.StatusFound) } else { - link, _, err := fs.Link(ctx, reqPath, model.LinkArgs{IP: utils.ClientIP(r), Header: r.Header, HttpReq: r}) + link, _, err := fs.Link(ctx, reqPath, model.LinkArgs{IP: utils.ClientIP(r), Header: r.Header, HttpReq: r, Redirect: true}) if err != nil { return http.StatusInternalServerError, err }