mirror of https://github.com/Xhofe/alist
fix: copy tasks using multi-thread downloader can't be canceled (#5028)
#4981 relatedpull/5042/head
parent
ed550594da
commit
1e3950c847
|
@ -1,6 +1,7 @@
|
||||||
package net
|
package net
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
@ -202,7 +203,6 @@ func (d *downloader) downloadPart() {
|
||||||
//defer d.wg.Done()
|
//defer d.wg.Done()
|
||||||
for {
|
for {
|
||||||
c, ok := <-d.chunkChannel
|
c, ok := <-d.chunkChannel
|
||||||
log.Debugf("downloadPart tried to get chunk")
|
|
||||||
if !ok {
|
if !ok {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
@ -211,7 +211,7 @@ func (d *downloader) downloadPart() {
|
||||||
// of download producer.
|
// of download producer.
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
log.Debugf("downloadPart tried to get chunk")
|
||||||
if err := d.downloadChunk(&c); err != nil {
|
if err := d.downloadChunk(&c); err != nil {
|
||||||
d.setErr(err)
|
d.setErr(err)
|
||||||
}
|
}
|
||||||
|
@ -220,7 +220,7 @@ func (d *downloader) downloadPart() {
|
||||||
|
|
||||||
// downloadChunk downloads the chunk
|
// downloadChunk downloads the chunk
|
||||||
func (d *downloader) downloadChunk(ch *chunk) error {
|
func (d *downloader) downloadChunk(ch *chunk) error {
|
||||||
log.Debugf("start new chunk %+v buffer_id =%d", ch, ch.buf.buffer.id)
|
log.Debugf("start new chunk %+v buffer_id =%d", ch, ch.id)
|
||||||
var n int64
|
var n int64
|
||||||
var err error
|
var err error
|
||||||
params := d.getParamsFromChunk(ch)
|
params := d.getParamsFromChunk(ch)
|
||||||
|
@ -262,6 +262,7 @@ func (d *downloader) tryDownloadChunk(params *HttpRequestParams, ch *chunk) (int
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
//only check file size on the first task
|
//only check file size on the first task
|
||||||
if ch.id == 0 {
|
if ch.id == 0 {
|
||||||
err = d.checkTotalBytes(resp)
|
err = d.checkTotalBytes(resp)
|
||||||
|
@ -279,7 +280,6 @@ func (d *downloader) tryDownloadChunk(params *HttpRequestParams, ch *chunk) (int
|
||||||
err = fmt.Errorf("chunk download size incorrect, expected=%d, got=%d", ch.size, n)
|
err = fmt.Errorf("chunk download size incorrect, expected=%d, got=%d", ch.size, n)
|
||||||
return n, &errReadingBody{err: err}
|
return n, &errReadingBody{err: err}
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
return n, nil
|
return n, nil
|
||||||
}
|
}
|
||||||
|
@ -402,13 +402,8 @@ func (e *errReadingBody) Unwrap() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
type MultiReadCloser struct {
|
type MultiReadCloser struct {
|
||||||
io.ReadCloser
|
|
||||||
|
|
||||||
//total int //total bufArr
|
|
||||||
//wPos int //current reader wPos
|
|
||||||
cfg *cfg
|
cfg *cfg
|
||||||
closer closerFunc
|
closer closerFunc
|
||||||
//getBuf getBufFunc
|
|
||||||
finish finishBufFUnc
|
finish finishBufFUnc
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -449,99 +444,26 @@ func (mr MultiReadCloser) Close() error {
|
||||||
return mr.closer()
|
return mr.closer()
|
||||||
}
|
}
|
||||||
|
|
||||||
type Buffer struct {
|
|
||||||
data []byte
|
|
||||||
wPos int //writer position
|
|
||||||
id int
|
|
||||||
rPos int //reader position
|
|
||||||
lock sync.Mutex
|
|
||||||
|
|
||||||
once bool //combined use with notify & lock, to get notify once
|
|
||||||
notify chan int // notifies new writes
|
|
||||||
}
|
|
||||||
|
|
||||||
func (buf *Buffer) Write(p []byte) (n int, err error) {
|
|
||||||
inSize := len(p)
|
|
||||||
if inSize == 0 {
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if inSize > len(buf.data)-buf.wPos {
|
|
||||||
return 0, fmt.Errorf("exceeding buffer max size,inSize=%d ,buf.data.len=%d , buf.wPos=%d",
|
|
||||||
inSize, len(buf.data), buf.wPos)
|
|
||||||
}
|
|
||||||
copy(buf.data[buf.wPos:], p)
|
|
||||||
buf.wPos += inSize
|
|
||||||
|
|
||||||
//give read a notice if once==true
|
|
||||||
buf.lock.Lock()
|
|
||||||
if buf.once == true {
|
|
||||||
buf.notify <- inSize //struct{}{}
|
|
||||||
}
|
|
||||||
buf.once = false
|
|
||||||
buf.lock.Unlock()
|
|
||||||
|
|
||||||
return inSize, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (buf *Buffer) getPos() (n int) {
|
|
||||||
return buf.wPos
|
|
||||||
}
|
|
||||||
func (buf *Buffer) reset() {
|
|
||||||
buf.wPos = 0
|
|
||||||
buf.rPos = 0
|
|
||||||
}
|
|
||||||
|
|
||||||
// waitTillNewWrite notify caller that new write happens
|
|
||||||
func (buf *Buffer) waitTillNewWrite(pos int) error {
|
|
||||||
//log.Debugf("waitTillNewWrite, current wPos=%d", pos)
|
|
||||||
var err error
|
|
||||||
|
|
||||||
//defer buffer.lock.Unlock()
|
|
||||||
if pos >= len(buf.data) {
|
|
||||||
err = fmt.Errorf("there will not be any new write")
|
|
||||||
} else if pos > buf.wPos {
|
|
||||||
err = fmt.Errorf("illegal read position")
|
|
||||||
} else if pos == buf.wPos {
|
|
||||||
buf.lock.Lock()
|
|
||||||
buf.once = true
|
|
||||||
//buffer.wg1.Add(1)
|
|
||||||
buf.lock.Unlock()
|
|
||||||
//wait for write
|
|
||||||
log.Debugf("waitTillNewWrite wait for notify")
|
|
||||||
writes := <-buf.notify
|
|
||||||
log.Debugf("waitTillNewWrite got new write from notify, last writes:%+v", writes)
|
|
||||||
//if pos >= buf.wPos {
|
|
||||||
// //wrote 0 bytes
|
|
||||||
// return fmt.Errorf("write has error")
|
|
||||||
//}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
//only case: wPos < buffer.wPos
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
type Buf struct {
|
type Buf struct {
|
||||||
buffer *Buffer // Buffer we read from
|
buffer *bytes.Buffer
|
||||||
size int //expected size
|
size int //expected size
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
|
off int
|
||||||
|
rw sync.RWMutex
|
||||||
|
notify chan struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewBuf is a buffer that can have 1 read & 1 write at the same time.
|
// NewBuf is a buffer that can have 1 read & 1 write at the same time.
|
||||||
// when read is faster write, immediately feed data to read after written
|
// when read is faster write, immediately feed data to read after written
|
||||||
func NewBuf(ctx context.Context, maxSize int, id int) *Buf {
|
func NewBuf(ctx context.Context, maxSize int, id int) *Buf {
|
||||||
d := make([]byte, maxSize)
|
d := make([]byte, maxSize)
|
||||||
buffer := &Buffer{data: d, id: id, notify: make(chan int)}
|
return &Buf{ctx: ctx, buffer: bytes.NewBuffer(d), size: maxSize, notify: make(chan struct{})}
|
||||||
buffer.reset()
|
|
||||||
return &Buf{ctx: ctx, buffer: buffer, size: maxSize}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
func (br *Buf) Reset(size int) {
|
func (br *Buf) Reset(size int) {
|
||||||
br.buffer.reset()
|
br.buffer.Reset()
|
||||||
br.size = size
|
br.size = size
|
||||||
}
|
br.off = 0
|
||||||
func (br *Buf) GetId() int {
|
|
||||||
return br.buffer.id
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (br *Buf) Read(p []byte) (n int, err error) {
|
func (br *Buf) Read(p []byte) (n int, err error) {
|
||||||
|
@ -551,48 +473,49 @@ func (br *Buf) Read(p []byte) (n int, err error) {
|
||||||
if len(p) == 0 {
|
if len(p) == 0 {
|
||||||
return 0, nil
|
return 0, nil
|
||||||
}
|
}
|
||||||
if br.buffer.rPos == br.size {
|
if br.off >= br.size {
|
||||||
return 0, io.EOF
|
return 0, io.EOF
|
||||||
}
|
}
|
||||||
//persist buffer position as another thread is keep increasing it
|
br.rw.RLock()
|
||||||
bufPos := br.buffer.getPos()
|
n, err = br.buffer.Read(p)
|
||||||
outSize := bufPos - br.buffer.rPos
|
br.rw.RUnlock()
|
||||||
|
if err == nil {
|
||||||
if outSize == 0 {
|
br.off += n
|
||||||
//var wg sync.WaitGroup
|
return n, err
|
||||||
err := br.waitTillNewWrite(br.buffer.rPos)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
bufPos = br.buffer.getPos()
|
|
||||||
outSize = bufPos - br.buffer.rPos
|
|
||||||
}
|
}
|
||||||
|
if err != io.EOF {
|
||||||
if len(p) < outSize {
|
return n, err
|
||||||
// p is not big enough
|
|
||||||
outSize = len(p)
|
|
||||||
}
|
}
|
||||||
copy(p, br.buffer.data[br.buffer.rPos:br.buffer.rPos+outSize])
|
if n != 0 {
|
||||||
br.buffer.rPos += outSize
|
br.off += n
|
||||||
if br.buffer.rPos == br.size {
|
return n, nil
|
||||||
err = io.EOF
|
}
|
||||||
|
// n==0, err==io.EOF
|
||||||
|
// wait for new write for 200ms
|
||||||
|
select {
|
||||||
|
case <-br.ctx.Done():
|
||||||
|
return 0, br.ctx.Err()
|
||||||
|
case <-br.notify:
|
||||||
|
return 0, nil
|
||||||
|
case <-time.After(time.Millisecond * 200):
|
||||||
|
return 0, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return outSize, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// waitTillNewWrite is expensive, since we just checked that no new data, wait 0.2s
|
|
||||||
func (br *Buf) waitTillNewWrite(pos int) error {
|
|
||||||
time.Sleep(200 * time.Millisecond)
|
|
||||||
return br.buffer.waitTillNewWrite(br.buffer.rPos)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (br *Buf) Write(p []byte) (n int, err error) {
|
func (br *Buf) Write(p []byte) (n int, err error) {
|
||||||
if err := br.ctx.Err(); err != nil {
|
if err := br.ctx.Err(); err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
return br.buffer.Write(p)
|
br.rw.Lock()
|
||||||
|
defer br.rw.Unlock()
|
||||||
|
n, err = br.buffer.Write(p)
|
||||||
|
select {
|
||||||
|
case br.notify <- struct{}{}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (br *Buf) Close() {
|
func (br *Buf) Close() {
|
||||||
close(br.buffer.notify)
|
close(br.notify)
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,14 +7,15 @@ import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/alist-org/alist/v3/pkg/http_range"
|
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"golang.org/x/exp/slices"
|
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/alist-org/alist/v3/pkg/http_range"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/exp/slices"
|
||||||
)
|
)
|
||||||
|
|
||||||
var buf22MB = make([]byte, 1024*1024*22)
|
var buf22MB = make([]byte, 1024*1024*22)
|
||||||
|
@ -55,7 +56,7 @@ func TestDownloadOrder(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("expect no error, got %v", err)
|
t.Fatalf("expect no error, got %v", err)
|
||||||
}
|
}
|
||||||
resultBuf, err := io.ReadAll(*readCloser)
|
resultBuf, err := io.ReadAll(readCloser)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("expect no error, got %v", err)
|
t.Fatalf("expect no error, got %v", err)
|
||||||
}
|
}
|
||||||
|
@ -111,7 +112,7 @@ func TestDownloadSingle(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("expect no error, got %v", err)
|
t.Fatalf("expect no error, got %v", err)
|
||||||
}
|
}
|
||||||
resultBuf, err := io.ReadAll(*readCloser)
|
resultBuf, err := io.ReadAll(readCloser)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("expect no error, got %v", err)
|
t.Fatalf("expect no error, got %v", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -215,14 +215,10 @@ func RequestHttp(httpMethod string, headerOverride http.Header, URL string) (*ht
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
req.Header = headerOverride
|
req.Header = headerOverride
|
||||||
log.Debugln("request Header: ", req.Header)
|
|
||||||
log.Debugln("request URL: ", URL)
|
|
||||||
res, err := HttpClient().Do(req)
|
res, err := HttpClient().Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
log.Debugf("response status: %d", res.StatusCode)
|
|
||||||
log.Debugln("response Header: ", res.Header)
|
|
||||||
// TODO clean header with blocklist or passlist
|
// TODO clean header with blocklist or passlist
|
||||||
res.Header.Del("set-cookie")
|
res.Header.Del("set-cookie")
|
||||||
if res.StatusCode >= 400 {
|
if res.StatusCode >= 400 {
|
||||||
|
@ -231,7 +227,6 @@ func RequestHttp(httpMethod string, headerOverride http.Header, URL string) (*ht
|
||||||
log.Debugln(msg)
|
log.Debugln(msg)
|
||||||
return res, errors.New(msg)
|
return res, errors.New(msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
return res, nil
|
return res, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -109,16 +109,11 @@ func ParseRange(s string, size int64) ([]Range, error) { // nolint:gocognit
|
||||||
|
|
||||||
func (r Range) MimeHeader(contentType string, size int64) textproto.MIMEHeader {
|
func (r Range) MimeHeader(contentType string, size int64) textproto.MIMEHeader {
|
||||||
return textproto.MIMEHeader{
|
return textproto.MIMEHeader{
|
||||||
"Content-Range": {r.contentRange(size)},
|
"Content-Range": {r.ContentRange(size)},
|
||||||
"Content-Type": {contentType},
|
"Content-Type": {contentType},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// for http response header
|
|
||||||
func (r Range) contentRange(size int64) string {
|
|
||||||
return fmt.Sprintf("bytes %d-%d/%d", r.Start, r.Start+r.Length-1, size)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ApplyRangeToHttpHeader for http request header
|
// ApplyRangeToHttpHeader for http request header
|
||||||
func ApplyRangeToHttpHeader(p Range, headerRef http.Header) http.Header {
|
func ApplyRangeToHttpHeader(p Range, headerRef http.Header) http.Header {
|
||||||
header := headerRef
|
header := headerRef
|
||||||
|
|
Loading…
Reference in New Issue