mirror of https://github.com/Xhofe/alist
fix: copy tasks using multi-thread downloader can't be canceled (#5028)
#4981 relatedpull/5042/head
parent
ed550594da
commit
1e3950c847
|
@ -1,6 +1,7 @@
|
|||
package net
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
|
@ -202,7 +203,6 @@ func (d *downloader) downloadPart() {
|
|||
//defer d.wg.Done()
|
||||
for {
|
||||
c, ok := <-d.chunkChannel
|
||||
log.Debugf("downloadPart tried to get chunk")
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
|
@ -211,7 +211,7 @@ func (d *downloader) downloadPart() {
|
|||
// of download producer.
|
||||
continue
|
||||
}
|
||||
|
||||
log.Debugf("downloadPart tried to get chunk")
|
||||
if err := d.downloadChunk(&c); err != nil {
|
||||
d.setErr(err)
|
||||
}
|
||||
|
@ -220,7 +220,7 @@ func (d *downloader) downloadPart() {
|
|||
|
||||
// downloadChunk downloads the chunk
|
||||
func (d *downloader) downloadChunk(ch *chunk) error {
|
||||
log.Debugf("start new chunk %+v buffer_id =%d", ch, ch.buf.buffer.id)
|
||||
log.Debugf("start new chunk %+v buffer_id =%d", ch, ch.id)
|
||||
var n int64
|
||||
var err error
|
||||
params := d.getParamsFromChunk(ch)
|
||||
|
@ -262,6 +262,7 @@ func (d *downloader) tryDownloadChunk(params *HttpRequestParams, ch *chunk) (int
|
|||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
//only check file size on the first task
|
||||
if ch.id == 0 {
|
||||
err = d.checkTotalBytes(resp)
|
||||
|
@ -279,7 +280,6 @@ func (d *downloader) tryDownloadChunk(params *HttpRequestParams, ch *chunk) (int
|
|||
err = fmt.Errorf("chunk download size incorrect, expected=%d, got=%d", ch.size, n)
|
||||
return n, &errReadingBody{err: err}
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
return n, nil
|
||||
}
|
||||
|
@ -402,13 +402,8 @@ func (e *errReadingBody) Unwrap() error {
|
|||
}
|
||||
|
||||
type MultiReadCloser struct {
|
||||
io.ReadCloser
|
||||
|
||||
//total int //total bufArr
|
||||
//wPos int //current reader wPos
|
||||
cfg *cfg
|
||||
closer closerFunc
|
||||
//getBuf getBufFunc
|
||||
finish finishBufFUnc
|
||||
}
|
||||
|
||||
|
@ -449,99 +444,26 @@ func (mr MultiReadCloser) Close() error {
|
|||
return mr.closer()
|
||||
}
|
||||
|
||||
type Buffer struct {
|
||||
data []byte
|
||||
wPos int //writer position
|
||||
id int
|
||||
rPos int //reader position
|
||||
lock sync.Mutex
|
||||
|
||||
once bool //combined use with notify & lock, to get notify once
|
||||
notify chan int // notifies new writes
|
||||
}
|
||||
|
||||
func (buf *Buffer) Write(p []byte) (n int, err error) {
|
||||
inSize := len(p)
|
||||
if inSize == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
if inSize > len(buf.data)-buf.wPos {
|
||||
return 0, fmt.Errorf("exceeding buffer max size,inSize=%d ,buf.data.len=%d , buf.wPos=%d",
|
||||
inSize, len(buf.data), buf.wPos)
|
||||
}
|
||||
copy(buf.data[buf.wPos:], p)
|
||||
buf.wPos += inSize
|
||||
|
||||
//give read a notice if once==true
|
||||
buf.lock.Lock()
|
||||
if buf.once == true {
|
||||
buf.notify <- inSize //struct{}{}
|
||||
}
|
||||
buf.once = false
|
||||
buf.lock.Unlock()
|
||||
|
||||
return inSize, nil
|
||||
}
|
||||
|
||||
func (buf *Buffer) getPos() (n int) {
|
||||
return buf.wPos
|
||||
}
|
||||
func (buf *Buffer) reset() {
|
||||
buf.wPos = 0
|
||||
buf.rPos = 0
|
||||
}
|
||||
|
||||
// waitTillNewWrite notify caller that new write happens
|
||||
func (buf *Buffer) waitTillNewWrite(pos int) error {
|
||||
//log.Debugf("waitTillNewWrite, current wPos=%d", pos)
|
||||
var err error
|
||||
|
||||
//defer buffer.lock.Unlock()
|
||||
if pos >= len(buf.data) {
|
||||
err = fmt.Errorf("there will not be any new write")
|
||||
} else if pos > buf.wPos {
|
||||
err = fmt.Errorf("illegal read position")
|
||||
} else if pos == buf.wPos {
|
||||
buf.lock.Lock()
|
||||
buf.once = true
|
||||
//buffer.wg1.Add(1)
|
||||
buf.lock.Unlock()
|
||||
//wait for write
|
||||
log.Debugf("waitTillNewWrite wait for notify")
|
||||
writes := <-buf.notify
|
||||
log.Debugf("waitTillNewWrite got new write from notify, last writes:%+v", writes)
|
||||
//if pos >= buf.wPos {
|
||||
// //wrote 0 bytes
|
||||
// return fmt.Errorf("write has error")
|
||||
//}
|
||||
return nil
|
||||
}
|
||||
//only case: wPos < buffer.wPos
|
||||
return err
|
||||
}
|
||||
|
||||
type Buf struct {
|
||||
buffer *Buffer // Buffer we read from
|
||||
size int //expected size
|
||||
buffer *bytes.Buffer
|
||||
size int //expected size
|
||||
ctx context.Context
|
||||
off int
|
||||
rw sync.RWMutex
|
||||
notify chan struct{}
|
||||
}
|
||||
|
||||
// NewBuf is a buffer that can have 1 read & 1 write at the same time.
|
||||
// when read is faster write, immediately feed data to read after written
|
||||
func NewBuf(ctx context.Context, maxSize int, id int) *Buf {
|
||||
d := make([]byte, maxSize)
|
||||
buffer := &Buffer{data: d, id: id, notify: make(chan int)}
|
||||
buffer.reset()
|
||||
return &Buf{ctx: ctx, buffer: buffer, size: maxSize}
|
||||
return &Buf{ctx: ctx, buffer: bytes.NewBuffer(d), size: maxSize, notify: make(chan struct{})}
|
||||
|
||||
}
|
||||
func (br *Buf) Reset(size int) {
|
||||
br.buffer.reset()
|
||||
br.buffer.Reset()
|
||||
br.size = size
|
||||
}
|
||||
func (br *Buf) GetId() int {
|
||||
return br.buffer.id
|
||||
br.off = 0
|
||||
}
|
||||
|
||||
func (br *Buf) Read(p []byte) (n int, err error) {
|
||||
|
@ -551,48 +473,49 @@ func (br *Buf) Read(p []byte) (n int, err error) {
|
|||
if len(p) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
if br.buffer.rPos == br.size {
|
||||
if br.off >= br.size {
|
||||
return 0, io.EOF
|
||||
}
|
||||
//persist buffer position as another thread is keep increasing it
|
||||
bufPos := br.buffer.getPos()
|
||||
outSize := bufPos - br.buffer.rPos
|
||||
|
||||
if outSize == 0 {
|
||||
//var wg sync.WaitGroup
|
||||
err := br.waitTillNewWrite(br.buffer.rPos)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
bufPos = br.buffer.getPos()
|
||||
outSize = bufPos - br.buffer.rPos
|
||||
br.rw.RLock()
|
||||
n, err = br.buffer.Read(p)
|
||||
br.rw.RUnlock()
|
||||
if err == nil {
|
||||
br.off += n
|
||||
return n, err
|
||||
}
|
||||
|
||||
if len(p) < outSize {
|
||||
// p is not big enough
|
||||
outSize = len(p)
|
||||
if err != io.EOF {
|
||||
return n, err
|
||||
}
|
||||
copy(p, br.buffer.data[br.buffer.rPos:br.buffer.rPos+outSize])
|
||||
br.buffer.rPos += outSize
|
||||
if br.buffer.rPos == br.size {
|
||||
err = io.EOF
|
||||
if n != 0 {
|
||||
br.off += n
|
||||
return n, nil
|
||||
}
|
||||
// n==0, err==io.EOF
|
||||
// wait for new write for 200ms
|
||||
select {
|
||||
case <-br.ctx.Done():
|
||||
return 0, br.ctx.Err()
|
||||
case <-br.notify:
|
||||
return 0, nil
|
||||
case <-time.After(time.Millisecond * 200):
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
return outSize, err
|
||||
}
|
||||
|
||||
// waitTillNewWrite is expensive, since we just checked that no new data, wait 0.2s
|
||||
func (br *Buf) waitTillNewWrite(pos int) error {
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
return br.buffer.waitTillNewWrite(br.buffer.rPos)
|
||||
}
|
||||
|
||||
func (br *Buf) Write(p []byte) (n int, err error) {
|
||||
if err := br.ctx.Err(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return br.buffer.Write(p)
|
||||
br.rw.Lock()
|
||||
defer br.rw.Unlock()
|
||||
n, err = br.buffer.Write(p)
|
||||
select {
|
||||
case br.notify <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (br *Buf) Close() {
|
||||
close(br.buffer.notify)
|
||||
close(br.notify)
|
||||
}
|
||||
|
|
|
@ -7,14 +7,15 @@ import (
|
|||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/alist-org/alist/v3/pkg/http_range"
|
||||
"github.com/sirupsen/logrus"
|
||||
"golang.org/x/exp/slices"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/alist-org/alist/v3/pkg/http_range"
|
||||
"github.com/sirupsen/logrus"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
var buf22MB = make([]byte, 1024*1024*22)
|
||||
|
@ -55,7 +56,7 @@ func TestDownloadOrder(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatalf("expect no error, got %v", err)
|
||||
}
|
||||
resultBuf, err := io.ReadAll(*readCloser)
|
||||
resultBuf, err := io.ReadAll(readCloser)
|
||||
if err != nil {
|
||||
t.Fatalf("expect no error, got %v", err)
|
||||
}
|
||||
|
@ -111,7 +112,7 @@ func TestDownloadSingle(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatalf("expect no error, got %v", err)
|
||||
}
|
||||
resultBuf, err := io.ReadAll(*readCloser)
|
||||
resultBuf, err := io.ReadAll(readCloser)
|
||||
if err != nil {
|
||||
t.Fatalf("expect no error, got %v", err)
|
||||
}
|
||||
|
|
|
@ -215,14 +215,10 @@ func RequestHttp(httpMethod string, headerOverride http.Header, URL string) (*ht
|
|||
return nil, err
|
||||
}
|
||||
req.Header = headerOverride
|
||||
log.Debugln("request Header: ", req.Header)
|
||||
log.Debugln("request URL: ", URL)
|
||||
res, err := HttpClient().Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
log.Debugf("response status: %d", res.StatusCode)
|
||||
log.Debugln("response Header: ", res.Header)
|
||||
// TODO clean header with blocklist or passlist
|
||||
res.Header.Del("set-cookie")
|
||||
if res.StatusCode >= 400 {
|
||||
|
@ -231,7 +227,6 @@ func RequestHttp(httpMethod string, headerOverride http.Header, URL string) (*ht
|
|||
log.Debugln(msg)
|
||||
return res, errors.New(msg)
|
||||
}
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -109,16 +109,11 @@ func ParseRange(s string, size int64) ([]Range, error) { // nolint:gocognit
|
|||
|
||||
func (r Range) MimeHeader(contentType string, size int64) textproto.MIMEHeader {
|
||||
return textproto.MIMEHeader{
|
||||
"Content-Range": {r.contentRange(size)},
|
||||
"Content-Range": {r.ContentRange(size)},
|
||||
"Content-Type": {contentType},
|
||||
}
|
||||
}
|
||||
|
||||
// for http response header
|
||||
func (r Range) contentRange(size int64) string {
|
||||
return fmt.Sprintf("bytes %d-%d/%d", r.Start, r.Start+r.Length-1, size)
|
||||
}
|
||||
|
||||
// ApplyRangeToHttpHeader for http request header
|
||||
func ApplyRangeToHttpHeader(p Range, headerRef http.Header) http.Header {
|
||||
header := headerRef
|
||||
|
|
Loading…
Reference in New Issue