fix: copy tasks using multi-thread downloader can't be canceled (#5028)

Related to #4981
pull/5042/head
Andy Hsu 2023-08-19 14:06:59 +08:00 committed by GitHub
parent ed550594da
commit 1e3950c847
4 changed files with 51 additions and 137 deletions
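
A sketch of the idea behind the fix (illustration only, not code from this commit; ctxReader and the timings below are invented): once the download buffer's Read honors the task's context, canceling that context unblocks the copy loop instead of leaving it stuck.

package main

import (
	"context"
	"fmt"
	"io"
	"time"
)

// ctxReader is a stand-in for the context-aware buffer introduced in this commit.
type ctxReader struct{ ctx context.Context }

func (r ctxReader) Read(p []byte) (int, error) {
	select {
	case <-r.ctx.Done():
		return 0, r.ctx.Err() // the copy task sees the cancellation
	case <-time.After(50 * time.Millisecond):
		return len(p), nil // pretend another chunk of data arrived
	}
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	go func() {
		time.Sleep(200 * time.Millisecond)
		cancel() // user cancels the copy task
	}()
	n, err := io.Copy(io.Discard, ctxReader{ctx})
	fmt.Println(n, err) // copy stops with "context canceled" instead of running forever
}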


@@ -1,6 +1,7 @@
package net
import (
"bytes"
"context"
"fmt"
"io"
@@ -202,7 +203,6 @@ func (d *downloader) downloadPart() {
//defer d.wg.Done()
for {
c, ok := <-d.chunkChannel
log.Debugf("downloadPart tried to get chunk")
if !ok {
break
}
@@ -211,7 +211,7 @@ func (d *downloader) downloadPart() {
// of download producer.
continue
}
log.Debugf("downloadPart tried to get chunk")
if err := d.downloadChunk(&c); err != nil {
d.setErr(err)
}
@@ -220,7 +220,7 @@ func (d *downloader) downloadPart() {
// downloadChunk downloads the chunk
func (d *downloader) downloadChunk(ch *chunk) error {
log.Debugf("start new chunk %+v buffer_id =%d", ch, ch.buf.buffer.id)
log.Debugf("start new chunk %+v buffer_id =%d", ch, ch.id)
var n int64
var err error
params := d.getParamsFromChunk(ch)
@@ -262,6 +262,7 @@ func (d *downloader) tryDownloadChunk(params *HttpRequestParams, ch *chunk) (int
if err != nil {
return 0, err
}
defer resp.Body.Close()
//only check file size on the first task
if ch.id == 0 {
err = d.checkTotalBytes(resp)
@@ -279,7 +280,6 @@ func (d *downloader) tryDownloadChunk(params *HttpRequestParams, ch *chunk) (int
err = fmt.Errorf("chunk download size incorrect, expected=%d, got=%d", ch.size, n)
return n, &errReadingBody{err: err}
}
defer resp.Body.Close()
return n, nil
}
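
In the two hunks above, defer resp.Body.Close() now runs right after the request is known to have succeeded, so every early-return path (size check, incomplete chunk) also releases the body. A generic sketch of that ordering (not the project's code; fetchLen and its error value are placeholders):

package example

import (
	"errors"
	"io"
	"net/http"
)

// fetchLen shows the ordering: defer the Close immediately after a successful
// request, before any validation that might return early.
func fetchLen(url string) (int64, error) {
	resp, err := http.Get(url)
	if err != nil {
		return 0, err
	}
	defer resp.Body.Close() // runs on every return path below
	if resp.StatusCode >= 400 {
		return 0, errors.New("unexpected status") // early return still closes the body
	}
	return io.Copy(io.Discard, resp.Body)
}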
@@ -402,13 +402,8 @@ func (e *errReadingBody) Unwrap() error {
}
type MultiReadCloser struct {
io.ReadCloser
//total int //total bufArr
//wPos int //current reader wPos
cfg *cfg
closer closerFunc
//getBuf getBufFunc
finish finishBufFUnc
}
@@ -449,99 +444,26 @@ func (mr MultiReadCloser) Close() error {
return mr.closer()
}
type Buffer struct {
data []byte
wPos int //writer position
id int
rPos int //reader position
lock sync.Mutex
once bool //combined use with notify & lock, to get notify once
notify chan int // notifies new writes
}
func (buf *Buffer) Write(p []byte) (n int, err error) {
inSize := len(p)
if inSize == 0 {
return 0, nil
}
if inSize > len(buf.data)-buf.wPos {
return 0, fmt.Errorf("exceeding buffer max size,inSize=%d ,buf.data.len=%d , buf.wPos=%d",
inSize, len(buf.data), buf.wPos)
}
copy(buf.data[buf.wPos:], p)
buf.wPos += inSize
//give read a notice if once==true
buf.lock.Lock()
if buf.once == true {
buf.notify <- inSize //struct{}{}
}
buf.once = false
buf.lock.Unlock()
return inSize, nil
}
func (buf *Buffer) getPos() (n int) {
return buf.wPos
}
func (buf *Buffer) reset() {
buf.wPos = 0
buf.rPos = 0
}
// waitTillNewWrite notify caller that new write happens
func (buf *Buffer) waitTillNewWrite(pos int) error {
//log.Debugf("waitTillNewWrite, current wPos=%d", pos)
var err error
//defer buffer.lock.Unlock()
if pos >= len(buf.data) {
err = fmt.Errorf("there will not be any new write")
} else if pos > buf.wPos {
err = fmt.Errorf("illegal read position")
} else if pos == buf.wPos {
buf.lock.Lock()
buf.once = true
//buffer.wg1.Add(1)
buf.lock.Unlock()
//wait for write
log.Debugf("waitTillNewWrite wait for notify")
writes := <-buf.notify
log.Debugf("waitTillNewWrite got new write from notify, last writes:%+v", writes)
//if pos >= buf.wPos {
// //wrote 0 bytes
// return fmt.Errorf("write has error")
//}
return nil
}
//only case: wPos < buffer.wPos
return err
}
type Buf struct {
buffer *Buffer // Buffer we read from
size int //expected size
buffer *bytes.Buffer
size int //expected size
ctx context.Context
off int
rw sync.RWMutex
notify chan struct{}
}
// NewBuf is a buffer that can have 1 read & 1 write at the same time.
// when read is faster than write, immediately feed data to read after it is written
func NewBuf(ctx context.Context, maxSize int, id int) *Buf {
d := make([]byte, maxSize)
buffer := &Buffer{data: d, id: id, notify: make(chan int)}
buffer.reset()
return &Buf{ctx: ctx, buffer: buffer, size: maxSize}
return &Buf{ctx: ctx, buffer: bytes.NewBuffer(d), size: maxSize, notify: make(chan struct{})}
}
func (br *Buf) Reset(size int) {
br.buffer.reset()
br.buffer.Reset()
br.size = size
}
func (br *Buf) GetId() int {
return br.buffer.id
br.off = 0
}
func (br *Buf) Read(p []byte) (n int, err error) {
@@ -551,48 +473,49 @@ func (br *Buf) Read(p []byte) (n int, err error) {
if len(p) == 0 {
return 0, nil
}
if br.buffer.rPos == br.size {
if br.off >= br.size {
return 0, io.EOF
}
//persist buffer position as another thread is keep increasing it
bufPos := br.buffer.getPos()
outSize := bufPos - br.buffer.rPos
if outSize == 0 {
//var wg sync.WaitGroup
err := br.waitTillNewWrite(br.buffer.rPos)
if err != nil {
return 0, err
}
bufPos = br.buffer.getPos()
outSize = bufPos - br.buffer.rPos
br.rw.RLock()
n, err = br.buffer.Read(p)
br.rw.RUnlock()
if err == nil {
br.off += n
return n, err
}
if len(p) < outSize {
// p is not big enough
outSize = len(p)
if err != io.EOF {
return n, err
}
copy(p, br.buffer.data[br.buffer.rPos:br.buffer.rPos+outSize])
br.buffer.rPos += outSize
if br.buffer.rPos == br.size {
err = io.EOF
if n != 0 {
br.off += n
return n, nil
}
// n==0, err==io.EOF
// wait for new write for 200ms
select {
case <-br.ctx.Done():
return 0, br.ctx.Err()
case <-br.notify:
return 0, nil
case <-time.After(time.Millisecond * 200):
return 0, nil
}
return outSize, err
}
// waitTillNewWrite is expensive, since we just checked that no new data, wait 0.2s
func (br *Buf) waitTillNewWrite(pos int) error {
time.Sleep(200 * time.Millisecond)
return br.buffer.waitTillNewWrite(br.buffer.rPos)
}
func (br *Buf) Write(p []byte) (n int, err error) {
if err := br.ctx.Err(); err != nil {
return 0, err
}
return br.buffer.Write(p)
br.rw.Lock()
defer br.rw.Unlock()
n, err = br.buffer.Write(p)
select {
case br.notify <- struct{}{}:
default:
}
return
}
func (br *Buf) Close() {
close(br.buffer.notify)
close(br.notify)
}
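
For reference, a self-contained sketch of the pattern the rewritten Buf follows (a simplified re-implementation for illustration, not the project's code): one writer and one reader share a bytes.Buffer behind a mutex, the writer signals through a notify channel, and a blocked reader wakes on a write, on context cancellation, or after a 200ms tick.

package example

import (
	"bytes"
	"context"
	"io"
	"sync"
	"time"
)

// pipeBuf is a minimal single-writer/single-reader buffer whose Read can be
// interrupted by canceling the context. Names and details are illustrative.
type pipeBuf struct {
	mu     sync.Mutex
	buf    bytes.Buffer
	ctx    context.Context
	size   int // total bytes expected to flow through
	off    int // bytes already handed to the reader
	notify chan struct{}
}

func newPipeBuf(ctx context.Context, size int) *pipeBuf {
	return &pipeBuf{ctx: ctx, size: size, notify: make(chan struct{}, 1)}
}

func (b *pipeBuf) Write(p []byte) (int, error) {
	if err := b.ctx.Err(); err != nil {
		return 0, err // a canceled task refuses further writes
	}
	b.mu.Lock()
	n, err := b.buf.Write(p)
	b.mu.Unlock()
	select {
	case b.notify <- struct{}{}: // wake a reader blocked in Read
	default:
	}
	return n, err
}

func (b *pipeBuf) Read(p []byte) (int, error) {
	if len(p) == 0 {
		return 0, nil
	}
	for {
		if b.off >= b.size {
			return 0, io.EOF // everything expected has been consumed
		}
		b.mu.Lock()
		n, err := b.buf.Read(p)
		b.mu.Unlock()
		if n > 0 {
			b.off += n
			return n, nil
		}
		if err != nil && err != io.EOF {
			return n, err
		}
		// Nothing buffered yet: wait for a write, a cancellation, or a tick.
		select {
		case <-b.ctx.Done():
			return 0, b.ctx.Err() // this is what lets a copy task be canceled
		case <-b.notify:
		case <-time.After(200 * time.Millisecond):
		}
	}
}

Unlike the removed Buffer type, there is no separate read/write position bookkeeping; bytes.Buffer handles that, and only the expected total (size against off) is tracked to decide when to report io.EOF.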


@@ -7,14 +7,15 @@ import (
"bytes"
"context"
"fmt"
"github.com/alist-org/alist/v3/pkg/http_range"
"github.com/sirupsen/logrus"
"golang.org/x/exp/slices"
"io"
"io/ioutil"
"net/http"
"sync"
"testing"
"github.com/alist-org/alist/v3/pkg/http_range"
"github.com/sirupsen/logrus"
"golang.org/x/exp/slices"
)
var buf22MB = make([]byte, 1024*1024*22)
@@ -55,7 +56,7 @@ func TestDownloadOrder(t *testing.T) {
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
resultBuf, err := io.ReadAll(*readCloser)
resultBuf, err := io.ReadAll(readCloser)
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
@@ -111,7 +112,7 @@ func TestDownloadSingle(t *testing.T) {
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
resultBuf, err := io.ReadAll(*readCloser)
resultBuf, err := io.ReadAll(readCloser)
if err != nil {
t.Fatalf("expect no error, got %v", err)
}


@@ -215,14 +215,10 @@ func RequestHttp(httpMethod string, headerOverride http.Header, URL string) (*ht
return nil, err
}
req.Header = headerOverride
log.Debugln("request Header: ", req.Header)
log.Debugln("request URL: ", URL)
res, err := HttpClient().Do(req)
if err != nil {
return nil, err
}
log.Debugf("response status: %d", res.StatusCode)
log.Debugln("response Header: ", res.Header)
// TODO clean header with blocklist or passlist
res.Header.Del("set-cookie")
if res.StatusCode >= 400 {
@@ -231,7 +227,6 @@ func RequestHttp(httpMethod string, headerOverride http.Header, URL string) (*ht
log.Debugln(msg)
return res, errors.New(msg)
}
return res, nil
}


@@ -109,16 +109,11 @@ func ParseRange(s string, size int64) ([]Range, error) { // nolint:gocognit
func (r Range) MimeHeader(contentType string, size int64) textproto.MIMEHeader {
return textproto.MIMEHeader{
"Content-Range": {r.contentRange(size)},
"Content-Range": {r.ContentRange(size)},
"Content-Type": {contentType},
}
}
// for http response header
func (r Range) contentRange(size int64) string {
return fmt.Sprintf("bytes %d-%d/%d", r.Start, r.Start+r.Length-1, size)
}
// ApplyRangeToHttpHeader for http request header
func ApplyRangeToHttpHeader(p Range, headerRef http.Header) http.Header {
header := headerRef
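
The last hunk switches MimeHeader to the exported ContentRange and drops the unexported contentRange helper. For orientation, a sketch (not from the repository) contrasting the response-side value with the request-side form that, per its comment, ApplyRangeToHttpHeader targets:

package example

import (
	"fmt"
	"net/http"
)

// responseContentRange mirrors the format shown in the hunk above: a response
// header value such as "Content-Range: bytes 0-99/1000".
func responseContentRange(start, length, size int64) string {
	return fmt.Sprintf("bytes %d-%d/%d", start, start+length-1, size)
}

// requestRange is the request-side counterpart a client would send, e.g.
// "Range: bytes=0-99". The helper name is hypothetical.
func requestRange(start, length int64) http.Header {
	h := http.Header{}
	h.Set("Range", fmt.Sprintf("bytes=%d-%d", start, start+length-1))
	return h
}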