perf(copy): use multi-thread downloader (close #5000)

pull/4908/head
Andy Hsu 2023-08-13 15:31:49 +08:00
parent 0b675d6c02
commit 5606c23768
7 changed files with 57 additions and 36 deletions

View File

@ -13,7 +13,7 @@ import (
) )
func RequestRangedHttp(r *http.Request, link *model.Link, offset, length int64) (*http.Response, error) { func RequestRangedHttp(r *http.Request, link *model.Link, offset, length int64) (*http.Response, error) {
header := net.ProcessHeader(&http.Header{}, &link.Header) header := net.ProcessHeader(http.Header{}, link.Header)
header = http_range.ApplyRangeToHttpHeader(http_range.Range{Start: offset, Length: length}, header) header = http_range.ApplyRangeToHttpHeader(http_range.Range{Start: offset, Length: length}, header)
return net.RequestHttp("GET", header, link.URL) return net.RequestHttp("GET", header, link.URL)

View File

@ -94,7 +94,7 @@ func copyFileBetween2Storages(tsk *task.Task[uint64], srcStorage, dstStorage dri
if err != nil { if err != nil {
return errors.WithMessagef(err, "failed get [%s] link", srcFilePath) return errors.WithMessagef(err, "failed get [%s] link", srcFilePath)
} }
stream, err := getFileStreamFromLink(srcFile, link) stream, err := getFileStreamFromLink(tsk.Ctx, srcFile, link)
if err != nil { if err != nil {
return errors.WithMessagef(err, "failed get [%s] stream", srcFilePath) return errors.WithMessagef(err, "failed get [%s] stream", srcFilePath)
} }

View File

@ -1,18 +1,21 @@
package fs package fs
import ( import (
"github.com/alist-org/alist/v3/pkg/http_range" "context"
"io" "io"
"net/http" "net/http"
"strings" "strings"
"github.com/alist-org/alist/v3/internal/net"
"github.com/alist-org/alist/v3/pkg/http_range"
"github.com/alist-org/alist/v3/internal/model" "github.com/alist-org/alist/v3/internal/model"
"github.com/alist-org/alist/v3/pkg/utils" "github.com/alist-org/alist/v3/pkg/utils"
"github.com/alist-org/alist/v3/server/common" "github.com/alist-org/alist/v3/server/common"
"github.com/pkg/errors" "github.com/pkg/errors"
) )
func getFileStreamFromLink(file model.Obj, link *model.Link) (*model.FileStream, error) { func getFileStreamFromLink(ctx context.Context, file model.Obj, link *model.Link) (*model.FileStream, error) {
var rc io.ReadCloser var rc io.ReadCloser
var err error var err error
mimetype := utils.GetMimeType(file.GetName()) mimetype := utils.GetMimeType(file.GetName())
@ -23,6 +26,21 @@ func getFileStreamFromLink(file model.Obj, link *model.Link) (*model.FileStream,
} }
} else if link.ReadSeekCloser != nil { } else if link.ReadSeekCloser != nil {
rc = link.ReadSeekCloser rc = link.ReadSeekCloser
} else if link.Concurrency != 0 || link.PartSize != 0 {
down := net.NewDownloader(func(d *net.Downloader) {
d.Concurrency = link.Concurrency
d.PartSize = link.PartSize
})
req := &net.HttpRequestParams{
URL: link.URL,
Range: http_range.Range{Length: -1},
Size: file.GetSize(),
HeaderRef: link.Header,
}
rc, err = down.Download(ctx, req)
if err != nil {
return nil, err
}
} else { } else {
//TODO: add accelerator //TODO: add accelerator
req, err := http.NewRequest(http.MethodGet, link.URL, nil) req, err := http.NewRequest(http.MethodGet, link.URL, nil)

View File

@ -3,9 +3,6 @@ package net
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/alist-org/alist/v3/pkg/http_range"
"github.com/aws/aws-sdk-go/aws/awsutil"
log "github.com/sirupsen/logrus"
"io" "io"
"math" "math"
"net/http" "net/http"
@ -13,6 +10,10 @@ import (
"strings" "strings"
"sync" "sync"
"time" "time"
"github.com/alist-org/alist/v3/pkg/http_range"
"github.com/aws/aws-sdk-go/aws/awsutil"
log "github.com/sirupsen/logrus"
) )
// DefaultDownloadPartSize is the default range of bytes to get at a time when // DefaultDownloadPartSize is the default range of bytes to get at a time when
@ -60,7 +61,7 @@ func NewDownloader(options ...func(*Downloader)) *Downloader {
// cache some data, then return Reader with assembled data // cache some data, then return Reader with assembled data
// Supports range, do not support unknown FileSize, and will fail if FileSize is incorrect // Supports range, do not support unknown FileSize, and will fail if FileSize is incorrect
// memory usage is at about Concurrency*PartSize, use this wisely // memory usage is at about Concurrency*PartSize, use this wisely
func (d Downloader) Download(ctx context.Context, p *HttpRequestParams) (readCloser *io.ReadCloser, err error) { func (d Downloader) Download(ctx context.Context, p *HttpRequestParams) (readCloser io.ReadCloser, err error) {
var finalP HttpRequestParams var finalP HttpRequestParams
awsutil.Copy(&finalP, p) awsutil.Copy(&finalP, p)
@ -107,7 +108,7 @@ type downloader struct {
} }
// download performs the implementation of the object download across ranged GETs. // download performs the implementation of the object download across ranged GETs.
func (d *downloader) download() (*io.ReadCloser, error) { func (d *downloader) download() (io.ReadCloser, error) {
d.ctx, d.cancel = context.WithCancel(d.ctx) d.ctx, d.cancel = context.WithCancel(d.ctx)
pos := d.params.Range.Start pos := d.params.Range.Start
@ -133,7 +134,7 @@ func (d *downloader) download() (*io.ReadCloser, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &resp.Body, nil return resp.Body, nil
} }
// workers // workers
@ -152,7 +153,7 @@ func (d *downloader) download() (*io.ReadCloser, error) {
var rc io.ReadCloser = NewMultiReadCloser(d.chunks[0].buf, d.interrupt, d.finishBuf) var rc io.ReadCloser = NewMultiReadCloser(d.chunks[0].buf, d.interrupt, d.finishBuf)
// Return error // Return error
return &rc, d.err return rc, d.err
} }
func (d *downloader) sendChunkTask() *chunk { func (d *downloader) sendChunkTask() *chunk {
ch := &d.chunks[d.nextChunk] ch := &d.chunks[d.nextChunk]
@ -384,7 +385,7 @@ type HttpRequestParams struct {
URL string URL string
//only want data within this range //only want data within this range
Range http_range.Range Range http_range.Range
HeaderRef *http.Header HeaderRef http.Header
//total file size //total file size
Size int64 Size int64
} }

View File

@ -2,13 +2,6 @@ package net
import ( import (
"fmt" "fmt"
"github.com/alist-org/alist/v3/drivers/base"
"github.com/alist-org/alist/v3/internal/conf"
"github.com/alist-org/alist/v3/internal/model"
"github.com/alist-org/alist/v3/pkg/http_range"
"github.com/alist-org/alist/v3/pkg/utils"
"github.com/pkg/errors"
log "github.com/sirupsen/logrus"
"io" "io"
"mime" "mime"
"mime/multipart" "mime/multipart"
@ -18,6 +11,14 @@ import (
"strings" "strings"
"sync" "sync"
"time" "time"
"github.com/alist-org/alist/v3/drivers/base"
"github.com/alist-org/alist/v3/internal/conf"
"github.com/alist-org/alist/v3/internal/model"
"github.com/alist-org/alist/v3/pkg/http_range"
"github.com/alist-org/alist/v3/pkg/utils"
"github.com/pkg/errors"
log "github.com/sirupsen/logrus"
) )
//this file is inspired by GO_SDK net.http.ServeContent //this file is inspired by GO_SDK net.http.ServeContent
@ -109,7 +110,7 @@ func ServeHTTP(w http.ResponseWriter, r *http.Request, name string, modTime time
} }
switch { switch {
case len(ranges) == 0: case len(ranges) == 0:
reader, err := RangeReaderFunc(http_range.Range{0, -1}) reader, err := RangeReaderFunc(http_range.Range{Length: -1})
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)
return return
@ -191,29 +192,29 @@ func ServeHTTP(w http.ResponseWriter, r *http.Request, name string, modTime time
} }
//defer sendContent.Close() //defer sendContent.Close()
} }
func ProcessHeader(origin, override *http.Header) *http.Header { func ProcessHeader(origin, override http.Header) http.Header {
result := http.Header{} result := http.Header{}
// client header // client header
for h, val := range *origin { for h, val := range origin {
if utils.SliceContains(conf.SlicesMap[conf.ProxyIgnoreHeaders], strings.ToLower(h)) { if utils.SliceContains(conf.SlicesMap[conf.ProxyIgnoreHeaders], strings.ToLower(h)) {
continue continue
} }
result[h] = val result[h] = val
} }
// needed header // needed header
for h, val := range *override { for h, val := range override {
result[h] = val result[h] = val
} }
return &result return result
} }
// RequestHttp deal with Header properly then send the request // RequestHttp deal with Header properly then send the request
func RequestHttp(httpMethod string, headerOverride *http.Header, URL string) (*http.Response, error) { func RequestHttp(httpMethod string, headerOverride http.Header, URL string) (*http.Response, error) {
req, err := http.NewRequest(httpMethod, URL, nil) req, err := http.NewRequest(httpMethod, URL, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
req.Header = *headerOverride req.Header = headerOverride
log.Debugln("request Header: ", req.Header) log.Debugln("request Header: ", req.Header)
log.Debugln("request URL: ", URL) log.Debugln("request URL: ", URL)
res, err := HttpClient().Do(req) res, err := HttpClient().Do(req)

View File

@ -120,10 +120,10 @@ func (r Range) contentRange(size int64) string {
} }
// ApplyRangeToHttpHeader for http request header // ApplyRangeToHttpHeader for http request header
func ApplyRangeToHttpHeader(p Range, headerRef *http.Header) *http.Header { func ApplyRangeToHttpHeader(p Range, headerRef http.Header) http.Header {
header := headerRef header := headerRef
if header == nil { if header == nil {
header = &http.Header{} header = http.Header{}
} }
if p.Start == 0 && p.Length < 0 { if p.Start == 0 && p.Length < 0 {
header.Del("Range") header.Del("Range")

View File

@ -3,16 +3,17 @@ package common
import ( import (
"context" "context"
"fmt" "fmt"
"io"
"net/http"
"net/url"
"sync"
"github.com/alist-org/alist/v3/drivers/base" "github.com/alist-org/alist/v3/drivers/base"
"github.com/alist-org/alist/v3/internal/model" "github.com/alist-org/alist/v3/internal/model"
"github.com/alist-org/alist/v3/internal/net" "github.com/alist-org/alist/v3/internal/net"
"github.com/alist-org/alist/v3/pkg/http_range" "github.com/alist-org/alist/v3/pkg/http_range"
"github.com/alist-org/alist/v3/pkg/utils" "github.com/alist-org/alist/v3/pkg/utils"
"github.com/pkg/errors" "github.com/pkg/errors"
"io"
"net/http"
"net/url"
"sync"
) )
func HttpClient() *http.Client { func HttpClient() *http.Client {
@ -52,7 +53,7 @@ func Proxy(w http.ResponseWriter, r *http.Request, link *model.Link, file model.
size := file.GetSize() size := file.GetSize()
//var finalClosers model.Closers //var finalClosers model.Closers
finalClosers := utils.NewClosers() finalClosers := utils.NewClosers()
header := net.ProcessHeader(&r.Header, &link.Header) header := net.ProcessHeader(r.Header, link.Header)
rangeReader := func(httpRange http_range.Range) (io.ReadCloser, error) { rangeReader := func(httpRange http_range.Range) (io.ReadCloser, error) {
down := net.NewDownloader(func(d *net.Downloader) { down := net.NewDownloader(func(d *net.Downloader) {
d.Concurrency = link.Concurrency d.Concurrency = link.Concurrency
@ -65,15 +66,15 @@ func Proxy(w http.ResponseWriter, r *http.Request, link *model.Link, file model.
HeaderRef: header, HeaderRef: header,
} }
rc, err := down.Download(context.Background(), req) rc, err := down.Download(context.Background(), req)
finalClosers.Add(*rc) finalClosers.Add(rc)
return *rc, err return rc, err
} }
net.ServeHTTP(w, r, file.GetName(), file.ModTime(), file.GetSize(), rangeReader) net.ServeHTTP(w, r, file.GetName(), file.ModTime(), file.GetSize(), rangeReader)
defer finalClosers.Close() defer finalClosers.Close()
return nil return nil
} else { } else {
//transparent proxy //transparent proxy
header := net.ProcessHeader(&r.Header, &link.Header) header := net.ProcessHeader(r.Header, link.Header)
res, err := net.RequestHttp(r.Method, header, link.URL) res, err := net.RequestHttp(r.Method, header, link.URL)
if err != nil { if err != nil {
return err return err