diff --git a/drivers/smb/driver.go b/drivers/smb/driver.go index b3217a72..5afd0e8e 100644 --- a/drivers/smb/driver.go +++ b/drivers/smb/driver.go @@ -3,12 +3,16 @@ package smb import ( "context" "errors" + "io" + "net/http" "path/filepath" + "strconv" "strings" "time" "github.com/alist-org/alist/v3/internal/driver" "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/pkg/http_range" "github.com/alist-org/alist/v3/pkg/utils" "github.com/hirochachacha/go-smb2" @@ -79,10 +83,27 @@ func (d *SMB) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*m d.cleanLastConnTime() return nil, err } - d.updateLastConnTime() - return &model.Link{ + link := &model.Link{ Data: remoteFile, - }, nil + } + if args.Header.Get("Range") != "" { + r, err := http_range.ParseRange(args.Header.Get("Range"), file.GetSize()) + if err == nil && len(r) > 0 { + _, err := remoteFile.Seek(r[0].Start, io.SeekStart) + if err == nil { + link.Data = utils.NewLimitReadCloser(remoteFile, func() error { + return remoteFile.Close() + }, r[0].Length) + link.Status = 206 + link.Header = http.Header{ + "Content-Range": []string{r[0].ContentRange(file.GetSize())}, + "Content-Length": []string{strconv.FormatInt(r[0].Length, 10)}, + } + } + } + } + d.updateLastConnTime() + return link, nil } func (d *SMB) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) error { diff --git a/pkg/http_range/range.go b/pkg/http_range/range.go new file mode 100644 index 00000000..9af473fc --- /dev/null +++ b/pkg/http_range/range.go @@ -0,0 +1,107 @@ +// Package http_range implements http range parsing. +package http_range + +import ( + "errors" + "fmt" + "net/textproto" + "strconv" + "strings" +) + +// Range specifies the byte range to be sent to the client. +type Range struct { + Start int64 + Length int64 +} + +// ContentRange returns Content-Range header value. +func (r Range) ContentRange(size int64) string { + return fmt.Sprintf("bytes %d-%d/%d", r.Start, r.Start+r.Length-1, size) +} + +var ( + // ErrNoOverlap is returned by ParseRange if first-byte-pos of + // all of the byte-range-spec values is greater than the content size. + ErrNoOverlap = errors.New("invalid range: failed to overlap") + + // ErrInvalid is returned by ParseRange on invalid input. + ErrInvalid = errors.New("invalid range") +) + +// ParseRange parses a Range header string as per RFC 7233. +// ErrNoOverlap is returned if none of the ranges overlap. +// ErrInvalid is returned if s is invalid range. +func ParseRange(s string, size int64) ([]Range, error) { // nolint:gocognit + if s == "" { + return nil, nil // header not present + } + const b = "bytes=" + if !strings.HasPrefix(s, b) { + return nil, ErrInvalid + } + var ranges []Range + noOverlap := false + for _, ra := range strings.Split(s[len(b):], ",") { + ra = textproto.TrimString(ra) + if ra == "" { + continue + } + i := strings.Index(ra, "-") + if i < 0 { + return nil, ErrInvalid + } + start, end := textproto.TrimString(ra[:i]), textproto.TrimString(ra[i+1:]) + var r Range + if start == "" { + // If no start is specified, end specifies the + // range start relative to the end of the file, + // and we are dealing with + // which has to be a non-negative integer as per + // RFC 7233 Section 2.1 "Byte-Ranges". + if end == "" || end[0] == '-' { + return nil, ErrInvalid + } + i, err := strconv.ParseInt(end, 10, 64) + if i < 0 || err != nil { + return nil, ErrInvalid + } + if i > size { + i = size + } + r.Start = size - i + r.Length = size - r.Start + } else { + i, err := strconv.ParseInt(start, 10, 64) + if err != nil || i < 0 { + return nil, ErrInvalid + } + if i >= size { + // If the range begins after the size of the content, + // then it does not overlap. + noOverlap = true + continue + } + r.Start = i + if end == "" { + // If no end is specified, range extends to end of the file. + r.Length = size - r.Start + } else { + i, err := strconv.ParseInt(end, 10, 64) + if err != nil || r.Start > i { + return nil, ErrInvalid + } + if i >= size { + i = size - 1 + } + r.Length = i - r.Start + 1 + } + } + ranges = append(ranges, r) + } + if noOverlap && len(ranges) == 0 { + // The specified ranges did not overlap with the content. + return nil, ErrNoOverlap + } + return ranges, nil +} diff --git a/pkg/utils/io.go b/pkg/utils/io.go index 38450248..76b514a5 100644 --- a/pkg/utils/io.go +++ b/pkg/utils/io.go @@ -69,3 +69,25 @@ func (l limitWriter) Write(p []byte) (n int, err error) { func LimitWriter(w io.Writer, size int64) io.Writer { return &limitWriter{w: w, limit: size} } + +type ReadCloser struct { + io.Reader + io.Closer +} + +type CloseFunc func() error + +func (c CloseFunc) Close() error { + return c() +} + +func NewReadCloser(reader io.Reader, close CloseFunc) io.ReadCloser { + return ReadCloser{ + Reader: reader, + Closer: close, + } +} + +func NewLimitReadCloser(reader io.Reader, close CloseFunc, limit int64) io.ReadCloser { + return NewReadCloser(io.LimitReader(reader, limit), close) +}