mirror of https://github.com/Xhofe/alist
feat(alias): add `DownloadConcurrency` and `DownloadPartSize` option (#7829)
* fix(net): goroutine logic bug (AlistGo/alist#7215)
  * Fix goroutine logic bug
  * Fix bug
  Co-authored-by: hpy hs <hshpy.pengyu@gmail.com>
* perf(net): sequential and dynamic concurrency
* fix(net): incorrect error return
* feat(alias): add `DownloadConcurrency` and `DownloadPartSize` option
* feat(net): add `ConcurrencyLimit`
* perf(net): create `chunk` on demand
* refactor
* refactor
* fix(net): `r.Closers.Add` has no effect
* refactor

Co-authored-by: hpy hs <hshpy.pengyu@gmail.com>
parent bdcf450203
commit 2be0c3d1a0
@@ -110,6 +110,16 @@ func (d *Alias) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (
 	for _, dst := range dsts {
 		link, err := d.link(ctx, dst, sub, args)
 		if err == nil {
+			if !args.Redirect && len(link.URL) > 0 {
+				// normally, multi-threaded download is only supported by drivers that return a URL
+				// nesting an alias inside an alias lets drivers that return no URL (crypt, mega, ...) support concurrency
+				if d.DownloadConcurrency > 0 {
+					link.Concurrency = d.DownloadConcurrency
+				}
+				if d.DownloadPartSize > 0 {
+					link.PartSize = d.DownloadPartSize * utils.KB
+				}
+			}
 			return link, nil
 		}
 	}
@@ -9,8 +9,10 @@ type Addition struct {
 	// Usually one of two
 	// driver.RootPath
 	// define other
 	Paths string `json:"paths" required:"true" type:"text"`
 	ProtectSameName bool `json:"protect_same_name" default:"true" required:"false" help:"Protects same-name files from Delete or Rename"`
+	DownloadConcurrency int `json:"download_concurrency" default:"0" required:"false" type:"number" help:"Need to enable proxy"`
+	DownloadPartSize    int `json:"download_part_size" default:"0" type:"number" required:"false" help:"Need to enable proxy. Unit: KB"`
 }
 
 var config = driver.Config{
@@ -9,6 +9,7 @@ import (
 	"github.com/alist-org/alist/v3/internal/errs"
 	"github.com/alist-org/alist/v3/internal/fs"
 	"github.com/alist-org/alist/v3/internal/model"
+	"github.com/alist-org/alist/v3/internal/op"
 	"github.com/alist-org/alist/v3/internal/sign"
 	"github.com/alist-org/alist/v3/pkg/utils"
 	"github.com/alist-org/alist/v3/server/common"
@@ -94,10 +95,15 @@ func (d *Alias) list(ctx context.Context, dst, sub string, args *fs.ListArgs) ([
 
 func (d *Alias) link(ctx context.Context, dst, sub string, args model.LinkArgs) (*model.Link, error) {
 	reqPath := stdpath.Join(dst, sub)
-	storage, err := fs.GetStorage(reqPath, &fs.GetStoragesArgs{})
+	// follows the crypt driver
+	storage, reqActualPath, err := op.GetStorageAndActualPath(reqPath)
 	if err != nil {
 		return nil, err
 	}
+	if _, ok := storage.(*Alias); !ok && !args.Redirect {
+		link, _, err := op.Link(ctx, storage, reqActualPath, args)
+		return link, err
+	}
 	_, err = fs.Get(ctx, reqPath, &fs.GetArgs{NoLog: true})
 	if err != nil {
 		return nil, err
@@ -114,7 +120,7 @@ func (d *Alias) link(ctx context.Context, dst, sub string, args model.LinkArgs)
 		}
 		return link, nil
 	}
-	link, _, err := fs.Link(ctx, reqPath, args)
+	link, _, err := op.Link(ctx, storage, reqActualPath, args)
 	return link, err
 }
 
@@ -275,7 +275,6 @@ func (d *Crypt) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (
 			rrc = converted
 		}
 		if rrc != nil {
-			//remoteRangeReader, err :=
 			remoteReader, err := rrc.RangeRead(ctx, http_range.Range{Start: underlyingOffset, Length: length})
 			remoteClosers.AddClosers(rrc.GetClosers())
 			if err != nil {
@@ -288,10 +287,8 @@ func (d *Crypt) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (
 			if err != nil {
 				return nil, err
 			}
-			//remoteClosers.Add(remoteLink.MFile)
-			//keep reuse same MFile and close at last.
-			remoteClosers.Add(remoteLink.MFile)
-			return io.NopCloser(remoteLink.MFile), nil
+			// can be returned directly: Close is not called when reading finishes, only when the connection is dropped
+			return remoteLink.MFile, nil
 		}
 
 		return nil, errs.NotSupport
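A side note on the crypt change above (an illustrative sketch, not code from the PR): wrapping a file in `io.NopCloser` discards the underlying `Close`, so whoever receives the reader can never release the file. Returning the `MFile` directly keeps the close path intact:

```go
package main

import (
	"fmt"
	"io"
	"os"
)

func main() {
	f, _ := os.Open("/etc/hosts") // any readable file
	rc := io.NopCloser(f)         // rc.Close() is a no-op
	_ = rc.Close()
	_, err := f.Read(make([]byte, 1))
	fmt.Println(err) // <nil>: the file underneath was never closed
}
```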
@@ -5,6 +5,13 @@ import (
 	"encoding/base64"
 	"errors"
 	"fmt"
+	"io"
+	"net/http"
+	stdpath "path"
+	"strings"
+	"sync"
+	"text/template"
+
 	"github.com/alist-org/alist/v3/drivers/base"
 	"github.com/alist-org/alist/v3/internal/driver"
 	"github.com/alist-org/alist/v3/internal/errs"
@@ -12,12 +19,6 @@ import (
 	"github.com/alist-org/alist/v3/pkg/utils"
 	"github.com/go-resty/resty/v2"
 	log "github.com/sirupsen/logrus"
-	"io"
-	"net/http"
-	stdpath "path"
-	"strings"
-	"sync"
-	"text/template"
 )
 
 type Github struct {
@@ -656,7 +657,7 @@ func (d *Github) putBlob(ctx context.Context, stream model.FileStreamer, up driv
 	contentReader, contentWriter := io.Pipe()
 	go func() {
 		encoder := base64.NewEncoder(base64.StdEncoding, contentWriter)
-		if _, err := io.Copy(encoder, stream); err != nil {
+		if _, err := utils.CopyWithBuffer(encoder, stream); err != nil {
 			_ = contentWriter.CloseWithError(err)
 			return
 		}
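`utils.CopyWithBuffer` replaces `io.Copy` in several hot paths of this PR. The diff does not show its body; a plausible sketch, assuming it is a pooled-buffer copy, looks like this:

```go
package netsketch

import (
	"io"
	"sync"
)

// bufPool hands out reusable 32 KB buffers so each copy avoids a fresh allocation.
var bufPool = sync.Pool{
	New: func() any {
		b := make([]byte, 32*1024)
		return &b
	},
}

// copyWithBuffer behaves like io.Copy but reuses a pooled buffer.
func copyWithBuffer(dst io.Writer, src io.Reader) (int64, error) {
	bp := bufPool.Get().(*[]byte)
	defer bufPool.Put(bp)
	return io.CopyBuffer(dst, src, *bp)
}
```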
@@ -4,12 +4,17 @@ import (
 	"context"
 	"crypto/sha1"
 	"fmt"
+	"io"
+	"net/url"
+	"path"
+	"strconv"
+	"time"
+
 	"github.com/alist-org/alist/v3/drivers/base"
 	"github.com/alist-org/alist/v3/internal/driver"
 	"github.com/alist-org/alist/v3/internal/model"
 	"github.com/alist-org/alist/v3/internal/op"
 	"github.com/alist-org/alist/v3/pkg/http_range"
-	"github.com/alist-org/alist/v3/pkg/utils"
 	"github.com/aws/aws-sdk-go/aws"
 	"github.com/aws/aws-sdk-go/aws/credentials"
 	"github.com/aws/aws-sdk-go/aws/session"
@@ -19,11 +24,6 @@ import (
 	pubUserFile "github.com/city404/v6-public-rpc-proto/go/v6/userfile"
 	"github.com/rclone/rclone/lib/readers"
 	"github.com/zzzhr1990/go-common-entity/userfile"
-	"io"
-	"net/url"
-	"path"
-	"strconv"
-	"time"
 )
 
 type HalalCloud struct {
@@ -251,7 +251,6 @@ func (d *HalalCloud) getLink(ctx context.Context, file model.Obj, args model.Lin
 
 	size := result.FileSize
 	chunks := getChunkSizes(result.Sizes)
-	var finalClosers utils.Closers
 	resultRangeReader := func(ctx context.Context, httpRange http_range.Range) (io.ReadCloser, error) {
 		length := httpRange.Length
 		if httpRange.Length >= 0 && httpRange.Start+httpRange.Length >= size {
@@ -269,7 +268,6 @@ func (d *HalalCloud) getLink(ctx context.Context, file model.Obj, args model.Lin
 			sha:     result.Sha1,
 			shaTemp: sha1.New(),
 		}
-		finalClosers.Add(oo)
 
 		return readers.NewLimitedReadCloser(oo, length), nil
 	}
@@ -281,7 +279,7 @@ func (d *HalalCloud) getLink(ctx context.Context, file model.Obj, args model.Lin
 		duration = time.Until(time.Now().Add(time.Hour))
 	}
 
-	resultRangeReadCloser := &model.RangeReadCloser{RangeReader: resultRangeReader, Closers: finalClosers}
+	resultRangeReadCloser := &model.RangeReadCloser{RangeReader: resultRangeReader}
 	return &model.Link{
 		RangeReadCloser: resultRangeReadCloser,
 		Expiration:      &duration,
@@ -84,7 +84,6 @@ func (d *Mega) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*
 	//}
 
 	size := file.GetSize()
-	var finalClosers utils.Closers
 	resultRangeReader := func(ctx context.Context, httpRange http_range.Range) (io.ReadCloser, error) {
 		length := httpRange.Length
 		if httpRange.Length >= 0 && httpRange.Start+httpRange.Length >= size {
@@ -103,11 +102,10 @@ func (d *Mega) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*
 			d:    down,
 			skip: httpRange.Start,
 		}
-		finalClosers.Add(oo)
 
 		return readers.NewLimitedReadCloser(oo, length), nil
 	}
-	resultRangeReadCloser := &model.RangeReadCloser{RangeReader: resultRangeReader, Closers: finalClosers}
+	resultRangeReadCloser := &model.RangeReadCloser{RangeReader: resultRangeReader}
 	resultLink := &model.Link{
 		RangeReadCloser: resultRangeReadCloser,
 	}
@@ -64,7 +64,6 @@ func (lrc *LyricObj) getLyricLink() *model.Link {
 				sr := io.NewSectionReader(reader, httpRange.Start, httpRange.Length)
 				return io.NopCloser(sr), nil
 			},
-			Closers: utils.EmptyClosers(),
 		},
 	}
 }
@@ -47,7 +47,7 @@ func (u *uploader) init(stream model.FileStreamer) error {
 	}
 
 	h := md5.New()
-	io.Copy(h, stream)
+	utils.CopyWithBuffer(h, stream)
 	u.md5 = hex.EncodeToString(h.Sum(nil))
 	_, err := u.file.Seek(0, io.SeekStart)
 	if err != nil {
@@ -300,9 +300,7 @@ func (d *Quqi) linkFromCDN(id string) (*model.Link, error) {
 		bufferReader := bufio.NewReader(decryptReader)
 		bufferReader.Discard(int(decryptedOffset))
 
-		return utils.NewReadCloser(bufferReader, func() error {
-			return nil
-		}), nil
+		return io.NopCloser(bufferReader), nil
 	}
 
 	return &model.Link{
@@ -9,6 +9,7 @@ import (
 	"github.com/alist-org/alist/v3/cmd/flags"
 	"github.com/alist-org/alist/v3/drivers/base"
 	"github.com/alist-org/alist/v3/internal/conf"
+	"github.com/alist-org/alist/v3/internal/net"
 	"github.com/alist-org/alist/v3/pkg/utils"
 	"github.com/caarlos0/env/v9"
 	log "github.com/sirupsen/logrus"
@@ -63,6 +64,9 @@ func InitConfig() {
 			log.Fatalf("update config struct error: %+v", err)
 		}
 	}
+	if conf.Conf.MaxConcurrency > 0 {
+		net.DefaultConcurrencyLimit = &net.ConcurrencyLimit{Limit: conf.Conf.MaxConcurrency}
+	}
 	if !conf.Conf.Force {
 		confFromEnv()
 	}
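For orientation (a sketch using the names this PR adds): the limit configured here is global. `NewDownloader` copies `net.DefaultConcurrencyLimit` into every `Downloader`, so all ranged downloads share one budget of `max_concurrency` slots (env `MAX_CONCURRENCY`):

```go
// wiring sketch, not a literal excerpt
net.DefaultConcurrencyLimit = &net.ConcurrencyLimit{Limit: conf.Conf.MaxConcurrency}

down := net.NewDownloader() // picks up the shared ConcurrencyLimit
_ = down
```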
@@ -106,6 +106,7 @@ type Config struct {
 	Log                   LogConfig   `json:"log"`
 	DelayedStart          int         `json:"delayed_start" env:"DELAYED_START"`
 	MaxConnections        int         `json:"max_connections" env:"MAX_CONNECTIONS"`
+	MaxConcurrency        int         `json:"max_concurrency" env:"MAX_CONCURRENCY"`
 	TlsInsecureSkipVerify bool        `json:"tls_insecure_skip_verify" env:"TLS_INSECURE_SKIP_VERIFY"`
 	Tasks                 TasksConfig `json:"tasks" envPrefix:"TASKS_"`
 	Cors                  Cors        `json:"cors" envPrefix:"CORS_"`
@@ -151,6 +152,7 @@ func DefaultConfig() *Config {
 			MaxAge:   28,
 		},
 		MaxConnections:        0,
+		MaxConcurrency:        64,
 		TlsInsecureSkipVerify: true,
 		Tasks: TasksConfig{
 			Download: TaskConfig{
@@ -17,10 +17,11 @@ type ListArgs struct {
 }
 
 type LinkArgs struct {
 	IP       string
 	Header   http.Header
 	Type     string
 	HttpReq  *http.Request
+	Redirect bool
 }
 
 type Link struct {
@@ -87,7 +88,7 @@ type RangeReadCloser struct {
 	utils.Closers
 }
 
-func (r RangeReadCloser) RangeRead(ctx context.Context, httpRange http_range.Range) (io.ReadCloser, error) {
+func (r *RangeReadCloser) RangeRead(ctx context.Context, httpRange http_range.Range) (io.ReadCloser, error) {
 	rc, err := r.RangeReader(ctx, httpRange)
 	r.Closers.Add(rc)
 	return rc, err
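This receiver change is the entire "`r.Closers.Add` has no effect" fix: with a value receiver, `RangeRead` mutated a copy of the struct, so opened readers were never tracked for closing. A self-contained demonstration of the difference (not PR code):

```go
package main

import "fmt"

type closers struct{ items []string }

type byValue struct{ closers }

func (c byValue) Add(s string) { c.items = append(c.items, s) } // mutates a copy

type byPointer struct{ closers }

func (c *byPointer) Add(s string) { c.items = append(c.items, s) }

func main() {
	v := byValue{}
	v.Add("rc1")
	fmt.Println(len(v.items)) // 0 — the append happened on a throwaway copy

	p := &byPointer{}
	p.Add("rc1")
	fmt.Println(len(p.items)) // 1 — the pointer receiver makes Add stick
}
```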
@@ -5,7 +5,6 @@ import (
 	"context"
 	"fmt"
 	"io"
-	"math"
 	"net/http"
 	"strconv"
 	"strings"
@@ -21,7 +20,7 @@ import (
 
 // DefaultDownloadPartSize is the default range of bytes to get at a time when
 // using Download().
-const DefaultDownloadPartSize = 1024 * 1024 * 10
+const DefaultDownloadPartSize = utils.MB * 10
 
 // DefaultDownloadConcurrency is the default number of goroutines to spin up
 // when using Download().
@@ -30,6 +29,8 @@ const DefaultDownloadConcurrency = 2
 // DefaultPartBodyMaxRetries is the default number of retries to make when a part fails to download.
 const DefaultPartBodyMaxRetries = 3
 
+var DefaultConcurrencyLimit *ConcurrencyLimit
+
 type Downloader struct {
 	PartSize int
 
@@ -44,15 +45,15 @@ type Downloader struct {
 
 	//RequestParam HttpRequestParams
 	HttpClient HttpRequestFunc
 
+	*ConcurrencyLimit
 }
 type HttpRequestFunc func(ctx context.Context, params *HttpRequestParams) (*http.Response, error)
 
 func NewDownloader(options ...func(*Downloader)) *Downloader {
-	d := &Downloader{
-		HttpClient:         DefaultHttpRequestFunc,
-		PartSize:           DefaultDownloadPartSize,
+	d := &Downloader{ // options that may be left unset
 		PartBodyMaxRetries: DefaultPartBodyMaxRetries,
-		Concurrency:        DefaultDownloadConcurrency,
+		ConcurrencyLimit:   DefaultConcurrencyLimit,
 	}
 	for _, option := range options {
 		option(d)
@@ -74,16 +75,16 @@ func (d Downloader) Download(ctx context.Context, p *HttpRequestParams) (readClo
 	impl := downloader{params: &finalP, cfg: d, ctx: ctx}
 
 	// Ensures we don't need nil checks later on
-	impl.partBodyMaxRetries = d.PartBodyMaxRetries
+	// required options
 
 	if impl.cfg.Concurrency == 0 {
 		impl.cfg.Concurrency = DefaultDownloadConcurrency
 	}
 
 	if impl.cfg.PartSize == 0 {
 		impl.cfg.PartSize = DefaultDownloadPartSize
 	}
+	if impl.cfg.HttpClient == nil {
+		impl.cfg.HttpClient = DefaultHttpRequestFunc
+	}
 
 	return impl.download()
 }
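The defaults for `HttpClient` and `PartSize` move out of `NewDownloader` and into `Download`, so a zero-valued option now means "fill in at call time". A usage sketch with the option pattern (values illustrative, names from this PR):

```go
down := net.NewDownloader(func(d *net.Downloader) {
	d.Concurrency = link.Concurrency // 0 falls back to DefaultDownloadConcurrency
	d.PartSize = link.PartSize       // 0 falls back to DefaultDownloadPartSize (10 MB)
})
rc, err := down.Download(ctx, &net.HttpRequestParams{
	URL:       link.URL,
	Range:     httpRange, // the byte range requested by the client
	Size:      file.GetSize(),
	HeaderRef: header,
})
```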
@@ -91,7 +92,7 @@ func (d Downloader) Download(ctx context.Context, p *HttpRequestParams) (readClo
 // downloader is the implementation structure used internally by Downloader.
 type downloader struct {
 	ctx    context.Context
-	cancel context.CancelFunc
+	cancel context.CancelCauseFunc
 	cfg    Downloader
 
 	params *HttpRequestParams //http request params
@@ -101,38 +102,78 @@ type downloader struct {
 	m sync.Mutex
 
 	nextChunk int //next chunk id
-	chunks    []chunk
 	bufs      []*Buf
-	//totalBytes int64
 	written int64 //total bytes of file downloaded from remote
 	err     error
 
-	partBodyMaxRetries int
+	concurrency int // remaining concurrency, decremented; at 0 no new workers start
+	maxPart     int // total number of parts
+	pos         int64
+	maxPos      int64
+	m2          sync.Mutex
+	readingID   int // id of the part currently being read
+}
+
+type ConcurrencyLimit struct {
+	_m    sync.Mutex
+	Limit int // must be greater than 0
+}
+
+var ErrExceedMaxConcurrency = fmt.Errorf("ExceedMaxConcurrency")
+
+func (l *ConcurrencyLimit) sub() error {
+	l._m.Lock()
+	defer l._m.Unlock()
+	if l.Limit-1 < 0 {
+		return ErrExceedMaxConcurrency
+	}
+	l.Limit--
+	// log.Debugf("ConcurrencyLimit.sub: %d", l.Limit)
+	return nil
+}
+func (l *ConcurrencyLimit) add() {
+	l._m.Lock()
+	defer l._m.Unlock()
+	l.Limit++
+	// log.Debugf("ConcurrencyLimit.add: %d", l.Limit)
+}
+
+// check whether the limit is exceeded
+func (d *downloader) concurrencyCheck() error {
+	if d.cfg.ConcurrencyLimit != nil {
+		return d.cfg.ConcurrencyLimit.sub()
+	}
+	return nil
+}
+func (d *downloader) concurrencyFinish() {
+	if d.cfg.ConcurrencyLimit != nil {
+		d.cfg.ConcurrencyLimit.add()
+	}
 }
 
 // download performs the implementation of the object download across ranged GETs.
 func (d *downloader) download() (io.ReadCloser, error) {
-	d.ctx, d.cancel = context.WithCancel(d.ctx)
+	if err := d.concurrencyCheck(); err != nil {
+		return nil, err
+	}
+	d.ctx, d.cancel = context.WithCancelCause(d.ctx)
 
-	pos := d.params.Range.Start
-	maxPos := d.params.Range.Start + d.params.Range.Length
-	id := 0
-	for pos < maxPos {
-		finalSize := int64(d.cfg.PartSize)
-		//check boundary
-		if pos+finalSize > maxPos {
-			finalSize = maxPos - pos
-		}
-		c := chunk{start: pos, size: finalSize, id: id}
-		d.chunks = append(d.chunks, c)
-		pos += finalSize
-		id++
+	maxPart := int(d.params.Range.Length / int64(d.cfg.PartSize))
+	if d.params.Range.Length%int64(d.cfg.PartSize) > 0 {
+		maxPart++
 	}
-	if len(d.chunks) < d.cfg.Concurrency {
-		d.cfg.Concurrency = len(d.chunks)
+	if maxPart < d.cfg.Concurrency {
+		d.cfg.Concurrency = maxPart
 	}
+	log.Debugf("cfgConcurrency:%d", d.cfg.Concurrency)
 
 	if d.cfg.Concurrency == 1 {
+		if d.cfg.ConcurrencyLimit != nil {
+			go func() {
+				<-d.ctx.Done()
+				d.concurrencyFinish()
+			}()
+		}
 		resp, err := d.cfg.HttpClient(d.ctx, d.params)
 		if err != nil {
 			return nil, err
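`ConcurrencyLimit` is a non-blocking counting semaphore: `sub` either takes a slot or fails immediately with `ErrExceedMaxConcurrency` (surfaced as HTTP 429 later in this PR), and `add` returns the slot. A minimal standalone model of the same behavior:

```go
package main

import (
	"fmt"
	"sync"
)

type limit struct {
	mu    sync.Mutex
	slots int
}

func (l *limit) acquire() bool {
	l.mu.Lock()
	defer l.mu.Unlock()
	if l.slots == 0 {
		return false // reject instead of blocking; the caller surfaces 429
	}
	l.slots--
	return true
}

func (l *limit) release() {
	l.mu.Lock()
	defer l.mu.Unlock()
	l.slots++
}

func main() {
	l := &limit{slots: 2}
	fmt.Println(l.acquire(), l.acquire(), l.acquire()) // true true false
	l.release()
	fmt.Println(l.acquire()) // true again after a slot is returned
}
```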
@@ -143,61 +184,114 @@ func (d *downloader) download() (io.ReadCloser, error) {
 	// workers
 	d.chunkChannel = make(chan chunk, d.cfg.Concurrency)
 
-	for i := 0; i < d.cfg.Concurrency; i++ {
-		buf := NewBuf(d.ctx, d.cfg.PartSize, i)
-		d.bufs = append(d.bufs, buf)
-		go d.downloadPart()
-	}
-	// initial tasks
-	for i := 0; i < d.cfg.Concurrency; i++ {
-		d.sendChunkTask()
-	}
+	d.maxPart = maxPart
+	d.pos = d.params.Range.Start
+	d.maxPos = d.params.Range.Start + d.params.Range.Length
+	d.concurrency = d.cfg.Concurrency
+	d.sendChunkTask(true)
 
-	var rc io.ReadCloser = NewMultiReadCloser(d.chunks[0].buf, d.interrupt, d.finishBuf)
+	var rc io.ReadCloser = NewMultiReadCloser(d.bufs[0], d.interrupt, d.finishBuf)
 
 	// Return error
 	return rc, d.err
 }
-func (d *downloader) sendChunkTask() *chunk {
-	ch := &d.chunks[d.nextChunk]
-	ch.buf = d.getBuf(d.nextChunk)
-	ch.buf.Reset(int(ch.size))
-	d.chunkChannel <- *ch
-	d.nextChunk++
-	return ch
+
+func (d *downloader) sendChunkTask(newConcurrency bool) error {
+	d.m.Lock()
+	defer d.m.Unlock()
+	isNewBuf := d.concurrency > 0
+	if newConcurrency {
+		if d.concurrency <= 0 {
+			return nil
+		}
+		if d.nextChunk > 0 { // the first one is not checked; it was already checked
+			if err := d.concurrencyCheck(); err != nil {
+				return err
+			}
+		}
+		d.concurrency--
+		go d.downloadPart()
+	}
+
+	var buf *Buf
+	if isNewBuf {
+		buf = NewBuf(d.ctx, d.cfg.PartSize)
+		d.bufs = append(d.bufs, buf)
+	} else {
+		buf = d.getBuf(d.nextChunk)
+	}
+
+	if d.pos < d.maxPos {
+		finalSize := int64(d.cfg.PartSize)
+		switch d.nextChunk {
+		case 0:
+			// putting the smallest part first may help video playback start?
+			firstSize := d.params.Range.Length % finalSize
+			if firstSize > 0 {
+				minSize := finalSize / 2
+				if firstSize < minSize { // if the smallest part is too small, grow it to half a part
+					finalSize = minSize
+				} else {
+					finalSize = firstSize
+				}
+			}
+		case 1:
+			firstSize := d.params.Range.Length % finalSize
+			minSize := finalSize / 2
+			if firstSize > 0 && firstSize < minSize {
+				finalSize += firstSize - minSize
+			}
+		}
+		buf.Reset(int(finalSize))
+		ch := chunk{
+			start: d.pos,
+			size:  finalSize,
+			id:    d.nextChunk,
+			buf:   buf,
+		}
+		ch.newConcurrency = newConcurrency
+		d.pos += finalSize
+		d.nextChunk++
+		d.chunkChannel <- ch
+		return nil
+	}
+	return nil
 }
 
 // when the final reader Close, we interrupt
 func (d *downloader) interrupt() error {
-	d.cancel()
 	if d.written != d.params.Range.Length {
 		log.Debugf("Downloader interrupt before finish")
 		if d.getErr() == nil {
 			d.setErr(fmt.Errorf("interrupted"))
 		}
 	}
+	d.cancel(d.err)
 	defer func() {
 		close(d.chunkChannel)
 		for _, buf := range d.bufs {
 			buf.Close()
 		}
+		if d.concurrency > 0 {
+			d.concurrency = -d.concurrency
+		}
+		log.Debugf("maxConcurrency:%d", d.cfg.Concurrency+d.concurrency)
 	}()
 	return d.err
 }
 func (d *downloader) getBuf(id int) (b *Buf) {
-	return d.bufs[id%d.cfg.Concurrency]
+	return d.bufs[id%len(d.bufs)]
 }
-func (d *downloader) finishBuf(id int) (isLast bool, buf *Buf) {
-	if id >= len(d.chunks)-1 {
+func (d *downloader) finishBuf(id int) (isLast bool, nextBuf *Buf) {
+	id++
+	if id >= d.maxPart {
 		return true, nil
 	}
-	if d.nextChunk > id+1 {
-		return false, d.getBuf(id + 1)
-	}
-	ch := d.sendChunkTask()
-	return false, ch.buf
+	d.sendChunkTask(false)
+	d.readingID = id
+	return false, d.getBuf(id)
 }
 
 // downloadPart is an individual goroutine worker reading from the ch channel
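The `case 0`/`case 1` logic in `sendChunkTask` sizes the first chunk to the range remainder, but never below half a part, on the theory that a small leading chunk lets playback start sooner. A worked example under assumed values (PartSize = 10 MB, Range.Length = 23 MB):

```go
package main

import "fmt"

func main() {
	partSize := int64(10 << 20)    // assumed PartSize: 10 MB
	length := int64(23 << 20)      // assumed Range.Length: 23 MB
	firstSize := length % partSize // 3 MB remainder
	minSize := partSize / 2        // 5 MB lower bound for chunk 0

	chunk0, chunk1 := firstSize, partSize
	if firstSize > 0 && firstSize < minSize {
		chunk0 = minSize              // too small: grow chunk 0 to half a part
		chunk1 += firstSize - minSize // chunk 1 shrinks to compensate
	}
	fmt.Println(chunk0>>20, chunk1>>20) // 5 8 → schedule: 5 MB, 8 MB, 10 MB
}
```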
@@ -212,58 +306,119 @@ func (d *downloader) downloadPart() {
 		if d.getErr() != nil {
 			// Drain the channel if there is an error, to prevent deadlocking
 			// of download producer.
-			continue
+			break
 		}
-		log.Debugf("downloadPart tried to get chunk")
 		if err := d.downloadChunk(&c); err != nil {
+			if err == errCancelConcurrency {
+				break
+			}
+			if err == context.Canceled {
+				if e := context.Cause(d.ctx); e != nil {
+					err = e
+				}
+			}
 			d.setErr(err)
+			d.cancel(err)
 		}
 	}
+	d.concurrencyFinish()
 }
 
 // downloadChunk downloads the chunk
 func (d *downloader) downloadChunk(ch *chunk) error {
-	log.Debugf("start new chunk %+v buffer_id =%d", ch, ch.id)
+	log.Debugf("start chunk_%d, %+v", ch.id, ch)
+	params := d.getParamsFromChunk(ch)
 	var n int64
 	var err error
-	params := d.getParamsFromChunk(ch)
-	for retry := 0; retry <= d.partBodyMaxRetries; retry++ {
+	for retry := 0; retry <= d.cfg.PartBodyMaxRetries; retry++ {
 		if d.getErr() != nil {
-			return d.getErr()
+			return nil
 		}
 		n, err = d.tryDownloadChunk(params, ch)
 		if err == nil {
+			d.incrWritten(n)
+			log.Debugf("chunk_%d downloaded", ch.id)
 			break
 		}
-		// Check if the returned error is an errReadingBody.
-		// If err is errReadingBody this indicates that an error
-		// occurred while copying the http response body.
+		if d.getErr() != nil {
+			return nil
+		}
+		if utils.IsCanceled(d.ctx) {
+			return d.ctx.Err()
+		}
+		// Check if the returned error is an errNeedRetry.
 		// If this occurs we unwrap the err to set the underlying error
 		// and attempt any remaining retries.
-		if bodyErr, ok := err.(*errReadingBody); ok {
-			err = bodyErr.Unwrap()
+		if e, ok := err.(*errNeedRetry); ok {
+			err = e.Unwrap()
+			if n > 0 {
+				// tested by dropping alist's connection to the cloud drive mid-download
+				// verified: the file hash still matches after the download completes
+				d.incrWritten(n)
+				ch.start += n
+				ch.size -= n
+				params.Range.Start = ch.start
+				params.Range.Length = ch.size
+			}
+			log.Warnf("err chunk_%d, object part download error %s, retrying attempt %d. %v",
+				ch.id, params.URL, retry, err)
+		} else if err == errInfiniteRetry {
+			retry--
+			continue
 		} else {
-			return err
+			break
 		}
-
-		//ch.cur = 0
-
-		log.Debugf("object part body download interrupted %s, err, %v, retrying attempt %d",
-			params.URL, err, retry)
 	}
 
-	d.incrWritten(n)
-	log.Debugf("down_%d downloaded chunk", ch.id)
-	//ch.buf.buffer.wg1.Wait()
-	//log.Debugf("down_%d downloaded chunk,wg wait passed", ch.id)
 	return err
 }
 
-func (d *downloader) tryDownloadChunk(params *HttpRequestParams, ch *chunk) (int64, error) {
+var errCancelConcurrency = fmt.Errorf("cancel concurrency")
+var errInfiniteRetry = fmt.Errorf("infinite retry")
+
+func (d *downloader) tryDownloadChunk(params *HttpRequestParams, ch *chunk) (int64, error) {
 	resp, err := d.cfg.HttpClient(d.ctx, params)
 	if err != nil {
-		return 0, err
+		if resp == nil {
+			return 0, err
+		}
+		if ch.id == 0 { // first task: bounded retries; once exhausted the request ends
+			switch resp.StatusCode {
+			default:
+				return 0, err
+			case http.StatusTooManyRequests:
+			case http.StatusBadGateway:
+			case http.StatusServiceUnavailable:
+			case http.StatusGatewayTimeout:
+			}
+			<-time.After(time.Millisecond * 200)
+			return 0, &errNeedRetry{err: fmt.Errorf("http request failure,status: %d", resp.StatusCode)}
+		}
+
+		// reaching here means the first part connected successfully,
+		// so errors on later parts are treated as overload
+		log.Debugf("err chunk_%d, try downloading:%v", ch.id, err)
+
+		d.m.Lock()
+		isCancelConcurrency := ch.newConcurrency
+		if d.concurrency > 0 { // cancel the remaining concurrent tasks
+			// used to compute the actual concurrency
+			d.concurrency = -d.concurrency
+			isCancelConcurrency = true
+		}
+		if isCancelConcurrency {
+			d.concurrency--
+			d.chunkChannel <- *ch
+			d.m.Unlock()
+			return 0, errCancelConcurrency
+		}
+		d.m.Unlock()
+		if ch.id != d.readingID { // the part being read gets retry priority
+			d.m2.Lock()
+			defer d.m2.Unlock()
+			<-time.After(time.Millisecond * 200)
+		}
+		return 0, errInfiniteRetry
 	}
 	defer resp.Body.Close()
 	//only check file size on the first task
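The retry policy above distinguishes three outcomes: a transient HTTP status on the first chunk becomes an `errNeedRetry` (bounded retries with a 200 ms pause), failures on later chunks are treated as overload and retried indefinitely with the chunk being read given priority, and anything else aborts. A compact sketch of just the status classification:

```go
package main

import (
	"fmt"
	"net/http"
)

// retriable mirrors the status switch above: only these statuses are worth
// retrying on the first chunk; any other failure ends the request.
func retriable(status int) bool {
	switch status {
	case http.StatusTooManyRequests, // 429
		http.StatusBadGateway,         // 502
		http.StatusServiceUnavailable, // 503
		http.StatusGatewayTimeout:     // 504
		return true
	}
	return false
}

func main() {
	fmt.Println(retriable(429), retriable(404)) // true false
}
```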
@@ -273,15 +428,15 @@ func (d *downloader) tryDownloadChunk(params *HttpRequestParams, ch *chunk) (int
 			return 0, err
 		}
 	}
+	d.sendChunkTask(true)
 	n, err := utils.CopyWithBuffer(ch.buf, resp.Body)
 
 	if err != nil {
-		return n, &errReadingBody{err: err}
+		return n, &errNeedRetry{err: err}
 	}
 	if n != ch.size {
 		err = fmt.Errorf("chunk download size incorrect, expected=%d, got=%d", ch.size, n)
-		return n, &errReadingBody{err: err}
+		return n, &errNeedRetry{err: err}
 	}
 
 	return n, nil
@@ -297,7 +452,7 @@ func (d *downloader) getParamsFromChunk(ch *chunk) *HttpRequestParams {
 
 func (d *downloader) checkTotalBytes(resp *http.Response) error {
 	var err error
-	var totalBytes int64 = math.MinInt64
+	totalBytes := int64(-1)
 	contentRange := resp.Header.Get("Content-Range")
 	if len(contentRange) == 0 {
 		// ContentRange is nil when the full file contents is provided, and
@@ -329,8 +484,9 @@ func (d *downloader) checkTotalBytes(resp *http.Response) error {
 		err = fmt.Errorf("expect file size=%d unmatch remote report size=%d, need refresh cache", d.params.Size, totalBytes)
 	}
 	if err != nil {
-		_ = d.interrupt()
+		// _ = d.interrupt()
 		d.setErr(err)
+		d.cancel(err)
 	}
 	return err
@@ -369,9 +525,7 @@ type chunk struct {
 	buf *Buf
 	id  int
 
-	// Downloader takes range (start,length), but this chunk is requesting equal/sub range of it.
-	// To convert the writer to reader eventually, we need to write within the boundary
-	//boundary http_range.Range
+	newConcurrency bool
 }
 
 func DefaultHttpRequestFunc(ctx context.Context, params *HttpRequestParams) (*http.Response, error) {
@@ -379,7 +533,7 @@ func DefaultHttpRequestFunc(ctx context.Context, params *HttpRequestParams) (*ht
 
 	res, err := RequestHttp(ctx, "GET", header, params.URL)
 	if err != nil {
-		return nil, err
+		return res, err
 	}
 	return res, nil
 }
@@ -392,15 +546,15 @@ type HttpRequestParams struct {
 	//total file size
 	Size int64
 }
-type errReadingBody struct {
+type errNeedRetry struct {
 	err error
 }
 
-func (e *errReadingBody) Error() string {
-	return fmt.Sprintf("failed to read part body: %v", e.err)
+func (e *errNeedRetry) Error() string {
+	return e.err.Error()
 }
 
-func (e *errReadingBody) Unwrap() error {
+func (e *errNeedRetry) Unwrap() error {
 	return e.err
 }
 
@@ -438,9 +592,13 @@ func (mr MultiReadCloser) Read(p []byte) (n int, err error) {
 		}
 		mr.cfg.curBuf = next
 		mr.cfg.rPos++
-		//current.Close()
 		return n, nil
 	}
+	if err == context.Canceled {
+		if e := context.Cause(mr.cfg.curBuf.ctx); e != nil {
+			err = e
+		}
+	}
 	return n, err
 }
 func (mr MultiReadCloser) Close() error {
@@ -453,18 +611,16 @@ type Buf struct {
 	ctx context.Context
 	off int
 	rw  sync.Mutex
-	//notify chan struct{}
 }
 
 // NewBuf is a buffer that can have 1 read & 1 write at the same time.
 // when read is faster write, immediately feed data to read after written
-func NewBuf(ctx context.Context, maxSize int, id int) *Buf {
+func NewBuf(ctx context.Context, maxSize int) *Buf {
 	d := make([]byte, 0, maxSize)
 	return &Buf{
 		ctx:    ctx,
 		buffer: bytes.NewBuffer(d),
 		size:   maxSize,
-		//notify: make(chan struct{}),
 	}
 }
 func (br *Buf) Reset(size int) {
@@ -502,8 +658,6 @@ func (br *Buf) Read(p []byte) (n int, err error) {
 	select {
 	case <-br.ctx.Done():
 		return 0, br.ctx.Err()
-	//case <-br.notify:
-	//	return 0, nil
 	case <-time.After(time.Millisecond * 200):
 		return 0, nil
 	}
@@ -516,13 +670,9 @@ func (br *Buf) Write(p []byte) (n int, err error) {
 	br.rw.Lock()
 	defer br.rw.Unlock()
 	n, err = br.buffer.Write(p)
-	select {
-	//case br.notify <- struct{}{}:
-	default:
-	}
 	return
 }
 
 func (br *Buf) Close() {
-	//close(br.notify)
+	br.buffer.Reset()
 }
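`Buf` is a single-producer/single-consumer buffer: a worker goroutine writes a part body into it while `MultiReadCloser` drains it, and when the reader catches up it polls every 200 ms until more data arrives or the context is cancelled. A usage sketch (names from this PR, `resp` hypothetical):

```go
buf := NewBuf(ctx, partSize) // one Buf per in-flight chunk
go func() {
	// writer side: the worker copies the HTTP body into the buffer
	_, _ = utils.CopyWithBuffer(buf, resp.Body)
}()
// reader side: bytes are visible as soon as they are written;
// a (0, nil) return means "nothing yet, poll again"
n, err := buf.Read(p)
```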
@@ -52,7 +52,8 @@ import (
 //
 // If the caller has set w's ETag header formatted per RFC 7232, section 2.3,
 // ServeHTTP uses it to handle requests using If-Match, If-None-Match, or If-Range.
-func ServeHTTP(w http.ResponseWriter, r *http.Request, name string, modTime time.Time, size int64, RangeReaderFunc model.RangeReaderFunc) {
+func ServeHTTP(w http.ResponseWriter, r *http.Request, name string, modTime time.Time, size int64, RangeReadCloser model.RangeReadCloserIF) {
+	defer RangeReadCloser.Close()
 	setLastModified(w, modTime)
 	done, rangeReq := checkPreconditions(w, r, modTime)
 	if done {
@@ -110,11 +111,19 @@ func ServeHTTP(w http.ResponseWriter, r *http.Request, name string, modTime time
 		// or unknown file size, ignore the range request.
 		ranges = nil
 	}
 
+	// use the request's Context:
+	// otherwise no data can be read from sendContent, and CopyBuffer keeps blocking even after the request is dropped
+	ctx := r.Context()
 	switch {
 	case len(ranges) == 0:
-		reader, err := RangeReaderFunc(context.Background(), http_range.Range{Length: -1})
+		reader, err := RangeReadCloser.RangeRead(ctx, http_range.Range{Length: -1})
 		if err != nil {
-			http.Error(w, err.Error(), http.StatusInternalServerError)
+			code = http.StatusRequestedRangeNotSatisfiable
+			if err == ErrExceedMaxConcurrency {
+				code = http.StatusTooManyRequests
+			}
+			http.Error(w, err.Error(), code)
 			return
 		}
 		sendContent = reader
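The same three-line mapping now appears in each error branch of `ServeHTTP` (and twice more below), so the rule can be read as one helper — a sketch, not code from the PR:

```go
// statusFor maps a range-read failure to an HTTP status: concurrency
// exhaustion becomes 429 so clients can back off; every other error
// keeps the branch's fallback code.
func statusFor(err error, fallback int) int {
	if err == ErrExceedMaxConcurrency {
		return http.StatusTooManyRequests
	}
	return fallback
}
```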
@@ -131,9 +140,13 @@ func ServeHTTP(w http.ResponseWriter, r *http.Request, name string, modTime time
 		// does not request multiple parts might not support
 		// multipart responses."
 		ra := ranges[0]
-		sendContent, err = RangeReaderFunc(context.Background(), ra)
+		sendContent, err = RangeReadCloser.RangeRead(ctx, ra)
 		if err != nil {
-			http.Error(w, err.Error(), http.StatusRequestedRangeNotSatisfiable)
+			code = http.StatusRequestedRangeNotSatisfiable
+			if err == ErrExceedMaxConcurrency {
+				code = http.StatusTooManyRequests
+			}
+			http.Error(w, err.Error(), code)
 			return
 		}
 		sendSize = ra.Length
@@ -158,7 +171,7 @@ func ServeHTTP(w http.ResponseWriter, r *http.Request, name string, modTime time
 				pw.CloseWithError(err)
 				return
 			}
-			reader, err := RangeReaderFunc(context.Background(), ra)
+			reader, err := RangeReadCloser.RangeRead(ctx, ra)
 			if err != nil {
 				pw.CloseWithError(err)
 				return
@@ -167,14 +180,12 @@ func ServeHTTP(w http.ResponseWriter, r *http.Request, name string, modTime time
 				pw.CloseWithError(err)
 				return
 			}
-			//defer reader.Close()
 		}
 
 		mw.Close()
 		pw.Close()
 	}()
 	}
-	//defer sendContent.Close()
 
 	w.Header().Set("Accept-Ranges", "bytes")
 	if w.Header().Get("Content-Encoding") == "" {
@@ -190,7 +201,11 @@ func ServeHTTP(w http.ResponseWriter, r *http.Request, name string, modTime time
 		if written != sendSize {
 			log.Warnf("Maybe size incorrect or reader not giving correct/full data, or connection closed before finish. written bytes: %d ,sendSize:%d, ", written, sendSize)
 		}
-		http.Error(w, err.Error(), http.StatusInternalServerError)
+		code = http.StatusInternalServerError
+		if err == ErrExceedMaxConcurrency {
+			code = http.StatusTooManyRequests
+		}
+		http.Error(w, err.Error(), code)
 		}
 	}
 }
@@ -239,7 +254,7 @@ func RequestHttp(ctx context.Context, httpMethod string, headerOverride http.Hea
 		_ = res.Body.Close()
 		msg := string(all)
 		log.Debugln(msg)
-		return nil, fmt.Errorf("http request [%s] failure,status: %d response:%s", URL, res.StatusCode, msg)
+		return res, fmt.Errorf("http request [%s] failure,status: %d response:%s", URL, res.StatusCode, msg)
 	}
 	return res, nil
 }
@@ -2,7 +2,6 @@ package net
 
 import (
 	"fmt"
-	"github.com/alist-org/alist/v3/pkg/utils"
 	"io"
 	"math"
 	"mime/multipart"
@@ -11,6 +10,8 @@ import (
 	"strings"
 	"time"
 
+	"github.com/alist-org/alist/v3/pkg/utils"
+
 	"github.com/alist-org/alist/v3/pkg/http_range"
 	log "github.com/sirupsen/logrus"
 )
@@ -5,7 +5,6 @@ import (
 	"context"
 	"encoding/base64"
 	"fmt"
-	"io"
 	"net/http"
 	"net/url"
 	"strconv"
@@ -15,6 +14,7 @@ import (
 	"github.com/alist-org/alist/v3/internal/model"
 	"github.com/alist-org/alist/v3/internal/offline_download/tool"
 	"github.com/alist-org/alist/v3/internal/setting"
+	"github.com/alist-org/alist/v3/pkg/utils"
 	"github.com/hekmon/transmissionrpc/v3"
 	"github.com/pkg/errors"
 	log "github.com/sirupsen/logrus"
@@ -92,7 +92,7 @@ func (t *Transmission) AddURL(args *tool.AddUrlArgs) (string, error) {
 	buffer := new(bytes.Buffer)
 	encoder := base64.NewEncoder(base64.StdEncoding, buffer)
 	// Stream file to the encoder
-	if _, err = io.Copy(encoder, resp.Body); err != nil {
+	if _, err = utils.CopyWithBuffer(encoder, resp.Body); err != nil {
 		return "", errors.Wrap(err, "can't copy file content into the base64 encoder")
 	}
 	// Flush last bytes
@@ -122,7 +122,8 @@ const InMemoryBufMaxSizeBytes = InMemoryBufMaxSize * 1024 * 1024
 // also support a peeking RangeRead at very start, but won't buffer more than 10MB data in memory
 func (f *FileStream) RangeRead(httpRange http_range.Range) (io.Reader, error) {
 	if httpRange.Length == -1 {
-		httpRange.Length = f.GetSize()
+		// follows internal/net/request.go
+		httpRange.Length = f.GetSize() - httpRange.Start
 	}
 	if f.peekBuff != nil && httpRange.Start < int64(f.peekBuff.Len()) && httpRange.Start+httpRange.Length-1 < int64(f.peekBuff.Len()) {
 		return io.NewSectionReader(f.peekBuff, httpRange.Start, httpRange.Length), nil
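This fix matters for open-ended range requests: `Length == -1` means "to the end of the file", which is `size - start` bytes, not `size`. A quick check (illustrative numbers):

```go
package main

import "fmt"

func main() {
	size, start := int64(100), int64(40) // illustrative file size and offset
	fmt.Println(size - start)            // 60: bytes left in an open-ended range
	// the old code used size (100) here, over-reading by `start` bytes
}
```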
@@ -210,7 +211,7 @@ func NewSeekableStream(fs FileStream, link *model.Link) (*SeekableStream, error)
 // RangeRead is not thread-safe, pls use it in single thread only.
 func (ss *SeekableStream) RangeRead(httpRange http_range.Range) (io.Reader, error) {
 	if httpRange.Length == -1 {
-		httpRange.Length = ss.GetSize()
+		httpRange.Length = ss.GetSize() - httpRange.Start
 	}
 	if ss.mFile != nil {
 		return io.NewSectionReader(ss.mFile, httpRange.Start, httpRange.Length), nil
@@ -6,7 +6,6 @@ import (
 	"io"
 	"net/http"
 
-	"github.com/alist-org/alist/v3/internal/errs"
 	"github.com/alist-org/alist/v3/internal/model"
 	"github.com/alist-org/alist/v3/internal/net"
 	"github.com/alist-org/alist/v3/pkg/http_range"
@@ -17,7 +16,6 @@ func GetRangeReadCloserFromLink(size int64, link *model.Link) (model.RangeReadCl
 	if len(link.URL) == 0 {
 		return nil, fmt.Errorf("can't create RangeReadCloser since URL is empty in link")
 	}
-	//remoteClosers := utils.EmptyClosers()
 	rangeReaderFunc := func(ctx context.Context, r http_range.Range) (io.ReadCloser, error) {
 		if link.Concurrency != 0 || link.PartSize != 0 {
 			header := net.ProcessHeader(http.Header{}, link.Header)
@@ -32,37 +30,29 @@ func GetRangeReadCloserFromLink(size int64, link *model.Link) (model.RangeReadCl
 				HeaderRef: header,
 			}
 			rc, err := down.Download(ctx, req)
-			if err != nil {
-				return nil, errs.NewErr(err, "GetReadCloserFromLink failed")
-			}
-			return rc, nil
-
+			return rc, err
 		}
-		if len(link.URL) > 0 {
-			response, err := RequestRangedHttp(ctx, link, r.Start, r.Length)
-			if err != nil {
-				if response == nil {
-					return nil, fmt.Errorf("http request failure, err:%s", err)
-				}
-				return nil, fmt.Errorf("http request failure,status: %d err:%s", response.StatusCode, err)
-			}
-			if r.Start == 0 && (r.Length == -1 || r.Length == size) || response.StatusCode == http.StatusPartialContent ||
-				checkContentRange(&response.Header, r.Start) {
-				return response.Body, nil
-			} else if response.StatusCode == http.StatusOK {
-				log.Warnf("remote http server not supporting range request, expect low perfromace!")
-				readCloser, err := net.GetRangedHttpReader(response.Body, r.Start, r.Length)
-				if err != nil {
-					return nil, err
-				}
-				return readCloser, nil
-
+		response, err := RequestRangedHttp(ctx, link, r.Start, r.Length)
+		if err != nil {
+			if response == nil {
+				return nil, fmt.Errorf("http request failure, err:%s", err)
 			}
-
-			return response.Body, nil
+			return nil, err
+		}
+		if r.Start == 0 && (r.Length == -1 || r.Length == size) || response.StatusCode == http.StatusPartialContent ||
+			checkContentRange(&response.Header, r.Start) {
+			return response.Body, nil
+		} else if response.StatusCode == http.StatusOK {
+			log.Warnf("remote http server not supporting range request, expect low perfromace!")
+			readCloser, err := net.GetRangedHttpReader(response.Body, r.Start, r.Length)
+			if err != nil {
+				return nil, err
+			}
+			return readCloser, nil
 		}
 
-		return nil, errs.NotSupport
+		return response.Body, nil
 	}
 	resultRangeReadCloser := model.RangeReadCloser{RangeReader: rangeReaderFunc}
 	return &resultRangeReadCloser, nil
@@ -27,16 +27,11 @@ func Proxy(w http.ResponseWriter, r *http.Request, link *model.Link, file model.
 		return nil
 	} else if link.RangeReadCloser != nil {
 		attachFileName(w, file)
-		net.ServeHTTP(w, r, file.GetName(), file.ModTime(), file.GetSize(), link.RangeReadCloser.RangeRead)
-		defer func() {
-			_ = link.RangeReadCloser.Close()
-		}()
+		net.ServeHTTP(w, r, file.GetName(), file.ModTime(), file.GetSize(), link.RangeReadCloser)
 		return nil
 	} else if link.Concurrency != 0 || link.PartSize != 0 {
 		attachFileName(w, file)
 		size := file.GetSize()
-		//var finalClosers model.Closers
-		finalClosers := utils.EmptyClosers()
 		header := net.ProcessHeader(r.Header, link.Header)
 		rangeReader := func(ctx context.Context, httpRange http_range.Range) (io.ReadCloser, error) {
 			down := net.NewDownloader(func(d *net.Downloader) {
@@ -50,16 +45,14 @@ func Proxy(w http.ResponseWriter, r *http.Request, link *model.Link, file model.
 				HeaderRef: header,
 			}
 			rc, err := down.Download(ctx, req)
-			finalClosers.Add(rc)
 			return rc, err
 		}
-		net.ServeHTTP(w, r, file.GetName(), file.ModTime(), file.GetSize(), rangeReader)
-		defer finalClosers.Close()
+		net.ServeHTTP(w, r, file.GetName(), file.ModTime(), file.GetSize(), &model.RangeReadCloser{RangeReader: rangeReader})
 		return nil
 	} else {
 		//transparent proxy
 		header := net.ProcessHeader(r.Header, link.Header)
-		res, err := net.RequestHttp(context.Background(), r.Method, header, link.URL)
+		res, err := net.RequestHttp(r.Context(), r.Method, header, link.URL)
 		if err != nil {
 			return err
 		}
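After this change the explicit `defer ... Close()` blocks disappear because `net.ServeHTTP` now receives the whole `RangeReadCloserIF` and defers `Close` itself (see the serve.go hunk above). The calling convention becomes, as a sketch:

```go
rrc := &model.RangeReadCloser{RangeReader: rangeReader}
// ownership of every reader opened via rrc.RangeRead passes to ServeHTTP,
// which runs `defer rrc.Close()` internally
net.ServeHTTP(w, r, file.GetName(), file.ModTime(), file.GetSize(), rrc)
```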
@@ -72,7 +65,7 @@ func Proxy(w http.ResponseWriter, r *http.Request, link *model.Link, file model.
 		if r.Method == http.MethodHead {
 			return nil
 		}
-		_, err = io.Copy(w, res.Body)
+		_, err = utils.CopyWithBuffer(w, res.Body)
 		if err != nil {
 			return err
 		}
@@ -281,10 +281,11 @@ func ArchiveDown(c *gin.Context) {
 		link, _, err := fs.ArchiveDriverExtract(c, archiveRawPath, model.ArchiveInnerArgs{
 			ArchiveArgs: model.ArchiveArgs{
 				LinkArgs: model.LinkArgs{
 					IP:       c.ClientIP(),
 					Header:   c.Request.Header,
 					Type:     c.Query("type"),
 					HttpReq:  c.Request,
+					Redirect: true,
 				},
 				Password: password,
 			},
@@ -31,10 +31,11 @@ func Down(c *gin.Context) {
 		return
 	} else {
 		link, _, err := fs.Link(c, rawPath, model.LinkArgs{
 			IP:       c.ClientIP(),
 			Header:   c.Request.Header,
 			Type:     c.Query("type"),
 			HttpReq:  c.Request,
+			Redirect: true,
 		})
 		if err != nil {
 			common.ErrorResp(c, err, 500)
@@ -6,13 +6,14 @@ import (
 	"context"
 	"encoding/hex"
 	"fmt"
-	"github.com/pkg/errors"
 	"io"
 	"path"
 	"strings"
 	"sync"
 	"time"
 
+	"github.com/pkg/errors"
+
 	"github.com/alist-org/alist/v3/internal/errs"
 	"github.com/alist-org/alist/v3/internal/fs"
 	"github.com/alist-org/alist/v3/internal/model"
@@ -173,15 +174,27 @@ func (b *s3Backend) GetObject(ctx context.Context, bucketName, objectName string
 	if link.RangeReadCloser == nil && link.MFile == nil && len(link.URL) == 0 {
 		return nil, fmt.Errorf("the remote storage driver need to be enhanced to support s3")
 	}
-	remoteFileSize := file.GetSize()
-	remoteClosers := utils.EmptyClosers()
-	rangeReaderFunc := func(ctx context.Context, start, length int64) (io.ReadCloser, error) {
+	var rdr io.ReadCloser
+	length := int64(-1)
+	start := int64(0)
+	if rnge != nil {
+		start, length = rnge.Start, rnge.Length
+	}
+	// follows server/common/proxy.go
+	if link.MFile != nil {
+		_, err := link.MFile.Seek(start, io.SeekStart)
+		if err != nil {
+			return nil, err
+		}
+		rdr = link.MFile
+	} else {
+		remoteFileSize := file.GetSize()
 		if length >= 0 && start+length >= remoteFileSize {
 			length = -1
 		}
 		rrc := link.RangeReadCloser
 		if len(link.URL) > 0 {
-
 			rangedRemoteLink := &model.Link{
 				URL:    link.URL,
 				Header: link.Header,
@@ -194,35 +207,12 @@ func (b *s3Backend) GetObject(ctx context.Context, bucketName, objectName string
 		}
 		if rrc != nil {
 			remoteReader, err := rrc.RangeRead(ctx, http_range.Range{Start: start, Length: length})
-			remoteClosers.AddClosers(rrc.GetClosers())
 			if err != nil {
 				return nil, err
 			}
-			return remoteReader, nil
-		}
-		if link.MFile != nil {
-			_, err := link.MFile.Seek(start, io.SeekStart)
-			if err != nil {
-				return nil, err
-			}
-			//remoteClosers.Add(remoteLink.MFile)
-			//keep reuse same MFile and close at last.
-			remoteClosers.Add(link.MFile)
-			return io.NopCloser(link.MFile), nil
-		}
-		return nil, errs.NotSupport
-	}
-
-	var rdr io.ReadCloser
-	if rnge != nil {
-		rdr, err = rangeReaderFunc(ctx, rnge.Start, rnge.Length)
-		if err != nil {
-			return nil, err
-		}
-	} else {
-		rdr, err = rangeReaderFunc(ctx, 0, -1)
-		if err != nil {
-			return nil, err
+			rdr = utils.ReadCloser{Reader: remoteReader, Closer: rrc}
+		} else {
+			return nil, errs.NotSupport
 		}
 	}
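`utils.ReadCloser{Reader: ..., Closer: ...}` is assumed here to be the usual "compose any Reader with any Closer" struct, so closing the returned stream also releases the `RangeReadCloser`'s tracked resources. A standalone demonstration of the pattern:

```go
package main

import (
	"fmt"
	"io"
	"strings"
)

// readCloser mirrors the assumed shape of utils.ReadCloser: any Reader
// paired with any Closer, usable wherever an io.ReadCloser is expected.
type readCloser struct {
	io.Reader
	io.Closer
}

type logCloser struct{}

func (logCloser) Close() error { fmt.Println("released"); return nil }

func main() {
	var rc io.ReadCloser = readCloser{Reader: strings.NewReader("body"), Closer: logCloser{}}
	b, _ := io.ReadAll(rc)
	fmt.Println(string(b)) // body
	_ = rc.Close()         // prints "released"
}
```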
@@ -263,7 +263,7 @@ func (h *Handler) handleGetHeadPost(w http.ResponseWriter, r *http.Request) (sta
 		w.Header().Set("Cache-Control", "max-age=0, no-cache, no-store, must-revalidate")
 		http.Redirect(w, r, u, http.StatusFound)
 	} else {
-		link, _, err := fs.Link(ctx, reqPath, model.LinkArgs{IP: utils.ClientIP(r), Header: r.Header, HttpReq: r})
+		link, _, err := fs.Link(ctx, reqPath, model.LinkArgs{IP: utils.ClientIP(r), Header: r.Header, HttpReq: r, Redirect: true})
 		if err != nil {
 			return http.StatusInternalServerError, err
 		}