enhance(download): Use just-in-time host in download URl, instead of SiteURL in site settings

pull/1741/head
Aaron Liu 2023-05-25 19:49:32 +08:00
parent 4c834e75fa
commit 4aafe1dc7a
16 changed files with 36 additions and 137 deletions

View File

@ -150,14 +150,7 @@ func (handler Driver) CORS() error {
// Get 获取文件 // Get 获取文件
func (handler Driver) Get(ctx context.Context, path string) (response.RSCloser, error) { func (handler Driver) Get(ctx context.Context, path string) (response.RSCloser, error) {
// 获取文件源地址 // 获取文件源地址
downloadURL, err := handler.Source( downloadURL, err := handler.Source(ctx, path, int64(model.GetIntSetting("preview_timeout", 60)), false, 0)
ctx,
path,
url.URL{},
int64(model.GetIntSetting("preview_timeout", 60)),
false,
0,
)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -267,14 +260,7 @@ func (handler Driver) Thumb(ctx context.Context, file *model.File) (*response.Co
} }
// Source 获取外链URL // Source 获取外链URL
func (handler Driver) Source( func (handler Driver) Source(ctx context.Context, path string, ttl int64, isDownload bool, speed int) (string, error) {
ctx context.Context,
path string,
baseURL url.URL,
ttl int64,
isDownload bool,
speed int,
) (string, error) {
// 尝试从上下文获取文件名 // 尝试从上下文获取文件名
fileName := "" fileName := ""
if file, ok := ctx.Value(fsctx.FileModelCtx).(model.File); ok { if file, ok := ctx.Value(fsctx.FileModelCtx).(model.File); ok {

View File

@ -8,7 +8,6 @@ import (
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response"
"github.com/cloudreve/Cloudreve/v3/pkg/request" "github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer" "github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"net/url"
) )
// Driver Google Drive 适配器 // Driver Google Drive 适配器
@ -45,7 +44,7 @@ func (d *Driver) Thumb(ctx context.Context, file *model.File) (*response.Content
panic("implement me") panic("implement me")
} }
func (d *Driver) Source(ctx context.Context, path string, url url.URL, ttl int64, isDownload bool, speed int) (string, error) { func (d *Driver) Source(ctx context.Context, path string, ttl int64, isDownload bool, speed int) (string, error) {
//TODO implement me //TODO implement me
panic("implement me") panic("implement me")
} }

View File

@ -7,7 +7,6 @@ import (
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer" "github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"net/url"
) )
var ( var (
@ -37,7 +36,7 @@ type Handler interface {
// 获取外链/下载地址, // 获取外链/下载地址,
// url - 站点本身地址, // url - 站点本身地址,
// isDownload - 是否直接下载 // isDownload - 是否直接下载
Source(ctx context.Context, path string, url url.URL, ttl int64, isDownload bool, speed int) (string, error) Source(ctx context.Context, path string, ttl int64, isDownload bool, speed int) (string, error)
// Token 获取有效期为ttl的上传凭证和签名 // Token 获取有效期为ttl的上传凭证和签名
Token(ctx context.Context, ttl int64, uploadSession *serializer.UploadSession, file fsctx.FileHeader) (*serializer.UploadCredential, error) Token(ctx context.Context, ttl int64, uploadSession *serializer.UploadSession, file fsctx.FileHeader) (*serializer.UploadCredential, error)

View File

@ -219,26 +219,20 @@ func (handler Driver) Thumb(ctx context.Context, file *model.File) (*response.Co
} }
// Source 获取外链URL // Source 获取外链URL
func (handler Driver) Source( func (handler Driver) Source(ctx context.Context, path string, ttl int64, isDownload bool, speed int) (string, error) {
ctx context.Context,
path string,
baseURL url.URL,
ttl int64,
isDownload bool,
speed int,
) (string, error) {
file, ok := ctx.Value(fsctx.FileModelCtx).(model.File) file, ok := ctx.Value(fsctx.FileModelCtx).(model.File)
if !ok { if !ok {
return "", errors.New("failed to read file model context") return "", errors.New("failed to read file model context")
} }
var baseURL *url.URL
// 是否启用了CDN // 是否启用了CDN
if handler.Policy.BaseURL != "" { if handler.Policy.BaseURL != "" {
cdnURL, err := url.Parse(handler.Policy.BaseURL) cdnURL, err := url.Parse(handler.Policy.BaseURL)
if err != nil { if err != nil {
return "", err return "", err
} }
baseURL = *cdnURL baseURL = cdnURL
} }
var ( var (
@ -272,7 +266,11 @@ func (handler Driver) Source(
return "", serializer.NewError(serializer.CodeEncryptError, "Failed to sign url", err) return "", serializer.NewError(serializer.CodeEncryptError, "Failed to sign url", err)
} }
finalURL := baseURL.ResolveReference(signedURI).String() finalURL := signedURI.String()
if baseURL != nil {
finalURL = baseURL.ResolveReference(signedURI).String()
}
return finalURL, nil return finalURL, nil
} }

View File

@ -91,7 +91,6 @@ func (handler Driver) Get(ctx context.Context, path string) (response.RSCloser,
downloadURL, err := handler.Source( downloadURL, err := handler.Source(
ctx, ctx,
path, path,
url.URL{},
60, 60,
false, false,
0, 0,
@ -164,7 +163,6 @@ func (handler Driver) Thumb(ctx context.Context, file *model.File) (*response.Co
func (handler Driver) Source( func (handler Driver) Source(
ctx context.Context, ctx context.Context,
path string, path string,
baseURL url.URL,
ttl int64, ttl int64,
isDownload bool, isDownload bool,
speed int, speed int,

View File

@ -9,7 +9,6 @@ import (
"io" "io"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"net/url"
"strings" "strings"
"testing" "testing"
"time" "time"
@ -106,7 +105,7 @@ func TestDriver_Source(t *testing.T) {
// 失败 // 失败
{ {
res, err := handler.Source(context.Background(), "123.jpg", url.URL{}, 1, true, 0) res, err := handler.Source(context.Background(), "123.jpg", 1, true, 0)
asserts.Error(err) asserts.Error(err)
asserts.Empty(res) asserts.Empty(res)
} }
@ -116,7 +115,7 @@ func TestDriver_Source(t *testing.T) {
handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
handler.Client.Credential.AccessToken = "1" handler.Client.Credential.AccessToken = "1"
cache.Set("onedrive_source_0_123.jpg", "res", 1) cache.Set("onedrive_source_0_123.jpg", "res", 1)
res, err := handler.Source(context.Background(), "123.jpg", url.URL{}, 0, true, 0) res, err := handler.Source(context.Background(), "123.jpg", 0, true, 0)
cache.Deletes([]string{"0_123.jpg"}, "onedrive_source_") cache.Deletes([]string{"0_123.jpg"}, "onedrive_source_")
asserts.NoError(err) asserts.NoError(err)
asserts.Equal("res", res) asserts.Equal("res", res)
@ -131,7 +130,7 @@ func TestDriver_Source(t *testing.T) {
handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
handler.Client.Credential.AccessToken = "1" handler.Client.Credential.AccessToken = "1"
cache.Set(fmt.Sprintf("onedrive_source_file_%d_1", file.UpdatedAt.Unix()), "res", 0) cache.Set(fmt.Sprintf("onedrive_source_file_%d_1", file.UpdatedAt.Unix()), "res", 0)
res, err := handler.Source(ctx, "123.jpg", url.URL{}, 1, true, 0) res, err := handler.Source(ctx, "123.jpg", 1, true, 0)
cache.Deletes([]string{"0_123.jpg"}, "onedrive_source_") cache.Deletes([]string{"0_123.jpg"}, "onedrive_source_")
asserts.NoError(err) asserts.NoError(err)
asserts.Equal("res", res) asserts.Equal("res", res)
@ -156,7 +155,7 @@ func TestDriver_Source(t *testing.T) {
}) })
handler.Client.Request = clientMock handler.Client.Request = clientMock
handler.Client.Credential.AccessToken = "1" handler.Client.Credential.AccessToken = "1"
res, err := handler.Source(context.Background(), "123.jpg", url.URL{}, 1, true, 0) res, err := handler.Source(context.Background(), "123.jpg", 1, true, 0)
asserts.NoError(err) asserts.NoError(err)
asserts.Equal("123321", res) asserts.Equal("123321", res)
} }

View File

@ -194,14 +194,7 @@ func (handler *Driver) Get(ctx context.Context, path string) (response.RSCloser,
ctx = context.WithValue(ctx, fsctx.ForceUsePublicEndpointCtx, false) ctx = context.WithValue(ctx, fsctx.ForceUsePublicEndpointCtx, false)
// 获取文件源地址 // 获取文件源地址
downloadURL, err := handler.Source( downloadURL, err := handler.Source(ctx, path, int64(model.GetIntSetting("preview_timeout", 60)), false, 0)
ctx,
path,
url.URL{},
int64(model.GetIntSetting("preview_timeout", 60)),
false,
0,
)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -339,14 +332,7 @@ func (handler *Driver) Thumb(ctx context.Context, file *model.File) (*response.C
} }
// Source 获取外链URL // Source 获取外链URL
func (handler *Driver) Source( func (handler *Driver) Source(ctx context.Context, path string, ttl int64, isDownload bool, speed int) (string, error) {
ctx context.Context,
path string,
baseURL url.URL,
ttl int64,
isDownload bool,
speed int,
) (string, error) {
// 初始化客户端 // 初始化客户端
usePublicEndpoint := true usePublicEndpoint := true
if forceUsePublicEndpoint, ok := ctx.Value(fsctx.ForceUsePublicEndpointCtx).(bool); ok { if forceUsePublicEndpoint, ok := ctx.Value(fsctx.ForceUsePublicEndpointCtx).(bool); ok {

View File

@ -119,14 +119,7 @@ func (handler *Driver) Get(ctx context.Context, path string) (response.RSCloser,
path = fmt.Sprintf("%s?v=%d", path, time.Now().UnixNano()) path = fmt.Sprintf("%s?v=%d", path, time.Now().UnixNano())
// 获取文件源地址 // 获取文件源地址
downloadURL, err := handler.Source( downloadURL, err := handler.Source(ctx, path, int64(model.GetIntSetting("preview_timeout", 60)), false, 0)
ctx,
path,
url.URL{},
int64(model.GetIntSetting("preview_timeout", 60)),
false,
0,
)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -264,14 +257,7 @@ func (handler *Driver) Thumb(ctx context.Context, file *model.File) (*response.C
} }
// Source 获取外链URL // Source 获取外链URL
func (handler *Driver) Source( func (handler *Driver) Source(ctx context.Context, path string, ttl int64, isDownload bool, speed int) (string, error) {
ctx context.Context,
path string,
baseURL url.URL,
ttl int64,
isDownload bool,
speed int,
) (string, error) {
// 尝试从上下文获取文件名 // 尝试从上下文获取文件名
fileName := "" fileName := ""
if file, ok := ctx.Value(fsctx.FileModelCtx).(model.File); ok { if file, ok := ctx.Value(fsctx.FileModelCtx).(model.File); ok {

View File

@ -124,7 +124,7 @@ func (handler *Driver) Get(ctx context.Context, path string) (response.RSCloser,
} }
// 获取文件源地址 // 获取文件源地址
downloadURL, err := handler.Source(ctx, path, url.URL{}, 0, true, speedLimit) downloadURL, err := handler.Source(ctx, path, 0, true, speedLimit)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -233,14 +233,7 @@ func (handler *Driver) Thumb(ctx context.Context, file *model.File) (*response.C
} }
// Source 获取外链URL // Source 获取外链URL
func (handler *Driver) Source( func (handler *Driver) Source(ctx context.Context, path string, ttl int64, isDownload bool, speed int) (string, error) {
ctx context.Context,
path string,
baseURL url.URL,
ttl int64,
isDownload bool,
speed int,
) (string, error) {
// 尝试从上下文获取文件名 // 尝试从上下文获取文件名
fileName := "file" fileName := "file"
if file, ok := ctx.Value(fsctx.FileModelCtx).(model.File); ok { if file, ok := ctx.Value(fsctx.FileModelCtx).(model.File); ok {

View File

@ -9,7 +9,6 @@ import (
"io" "io"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"net/url"
"strings" "strings"
"testing" "testing"
@ -51,7 +50,7 @@ func TestHandler_Source(t *testing.T) {
AuthInstance: auth.HMACAuth{}, AuthInstance: auth.HMACAuth{},
} }
ctx := context.Background() ctx := context.Background()
res, err := handler.Source(ctx, "", url.URL{}, 0, true, 0) res, err := handler.Source(ctx, "", 0, true, 0)
asserts.NoError(err) asserts.NoError(err)
asserts.NotEmpty(res) asserts.NotEmpty(res)
} }
@ -66,7 +65,7 @@ func TestHandler_Source(t *testing.T) {
SourceName: "1.txt", SourceName: "1.txt",
} }
ctx := context.WithValue(context.Background(), fsctx.FileModelCtx, file) ctx := context.WithValue(context.Background(), fsctx.FileModelCtx, file)
res, err := handler.Source(ctx, "", url.URL{}, 10, true, 0) res, err := handler.Source(ctx, "", 10, true, 0)
asserts.NoError(err) asserts.NoError(err)
asserts.Contains(res, "api/v3/slave/download/0") asserts.Contains(res, "api/v3/slave/download/0")
} }
@ -81,7 +80,7 @@ func TestHandler_Source(t *testing.T) {
SourceName: "1.txt", SourceName: "1.txt",
} }
ctx := context.WithValue(context.Background(), fsctx.FileModelCtx, file) ctx := context.WithValue(context.Background(), fsctx.FileModelCtx, file)
res, err := handler.Source(ctx, "", url.URL{}, 10, true, 0) res, err := handler.Source(ctx, "", 10, true, 0)
asserts.NoError(err) asserts.NoError(err)
asserts.Contains(res, "api/v3/slave/download/0") asserts.Contains(res, "api/v3/slave/download/0")
asserts.Contains(res, "https://cqu.edu.cn") asserts.Contains(res, "https://cqu.edu.cn")
@ -97,7 +96,7 @@ func TestHandler_Source(t *testing.T) {
SourceName: "1.txt", SourceName: "1.txt",
} }
ctx := context.WithValue(context.Background(), fsctx.FileModelCtx, file) ctx := context.WithValue(context.Background(), fsctx.FileModelCtx, file)
res, err := handler.Source(ctx, "", url.URL{}, 10, true, 0) res, err := handler.Source(ctx, "", 10, true, 0)
asserts.Error(err) asserts.Error(err)
asserts.Empty(res) asserts.Empty(res)
} }
@ -112,7 +111,7 @@ func TestHandler_Source(t *testing.T) {
SourceName: "1.txt", SourceName: "1.txt",
} }
ctx := context.WithValue(context.Background(), fsctx.FileModelCtx, file) ctx := context.WithValue(context.Background(), fsctx.FileModelCtx, file)
res, err := handler.Source(ctx, "", url.URL{}, 10, false, 0) res, err := handler.Source(ctx, "", 10, false, 0)
asserts.NoError(err) asserts.NoError(err)
asserts.Contains(res, "api/v3/slave/source/0") asserts.Contains(res, "api/v3/slave/source/0")
} }

View File

@ -164,14 +164,7 @@ func (handler *Driver) List(ctx context.Context, base string, recursive bool) ([
// Get 获取文件 // Get 获取文件
func (handler *Driver) Get(ctx context.Context, path string) (response.RSCloser, error) { func (handler *Driver) Get(ctx context.Context, path string) (response.RSCloser, error) {
// 获取文件源地址 // 获取文件源地址
downloadURL, err := handler.Source( downloadURL, err := handler.Source(ctx, path, int64(model.GetIntSetting("preview_timeout", 60)), false, 0)
ctx,
path,
url.URL{},
int64(model.GetIntSetting("preview_timeout", 60)),
false,
0,
)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -270,14 +263,7 @@ func (handler *Driver) Thumb(ctx context.Context, file *model.File) (*response.C
} }
// Source 获取外链URL // Source 获取外链URL
func (handler *Driver) Source( func (handler *Driver) Source(ctx context.Context, path string, ttl int64, isDownload bool, speed int) (string, error) {
ctx context.Context,
path string,
baseURL url.URL,
ttl int64,
isDownload bool,
speed int,
) (string, error) {
// 尝试从上下文获取文件名 // 尝试从上下文获取文件名
fileName := "" fileName := ""

View File

@ -8,7 +8,6 @@ import (
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer" "github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"net/url"
) )
// Driver 影子存储策略,用于在从机端上传文件 // Driver 影子存储策略,用于在从机端上传文件
@ -43,7 +42,7 @@ func (d *Driver) Thumb(ctx context.Context, file *model.File) (*response.Content
return nil, ErrNotImplemented return nil, ErrNotImplemented
} }
func (d *Driver) Source(ctx context.Context, path string, url url.URL, ttl int64, isDownload bool, speed int) (string, error) { func (d *Driver) Source(ctx context.Context, path string, ttl int64, isDownload bool, speed int) (string, error) {
return "", ErrNotImplemented return "", ErrNotImplemented
} }

View File

@ -106,7 +106,7 @@ func (d *Driver) Thumb(ctx context.Context, file *model.File) (*response.Content
return nil, ErrNotImplemented return nil, ErrNotImplemented
} }
func (d *Driver) Source(ctx context.Context, path string, url url.URL, ttl int64, isDownload bool, speed int) (string, error) { func (d *Driver) Source(ctx context.Context, path string, ttl int64, isDownload bool, speed int) (string, error) {
return "", ErrNotImplemented return "", ErrNotImplemented
} }

View File

@ -107,14 +107,7 @@ func (handler Driver) List(ctx context.Context, base string, recursive bool) ([]
// Get 获取文件 // Get 获取文件
func (handler Driver) Get(ctx context.Context, path string) (response.RSCloser, error) { func (handler Driver) Get(ctx context.Context, path string) (response.RSCloser, error) {
// 获取文件源地址 // 获取文件源地址
downloadURL, err := handler.Source( downloadURL, err := handler.Source(ctx, path, int64(model.GetIntSetting("preview_timeout", 60)), false, 0)
ctx,
path,
url.URL{},
int64(model.GetIntSetting("preview_timeout", 60)),
false,
0,
)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -243,14 +236,7 @@ func (handler Driver) Thumb(ctx context.Context, file *model.File) (*response.Co
} }
thumbParam := fmt.Sprintf("!/fwfh/%dx%d", thumbSize[0], thumbSize[1]) thumbParam := fmt.Sprintf("!/fwfh/%dx%d", thumbSize[0], thumbSize[1])
thumbURL, err := handler.Source( thumbURL, err := handler.Source(ctx, file.SourceName+thumbParam, int64(model.GetIntSetting("preview_timeout", 60)), false, 0)
ctx,
file.SourceName+thumbParam,
url.URL{},
int64(model.GetIntSetting("preview_timeout", 60)),
false,
0,
)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -262,14 +248,7 @@ func (handler Driver) Thumb(ctx context.Context, file *model.File) (*response.Co
} }
// Source 获取外链URL // Source 获取外链URL
func (handler Driver) Source( func (handler Driver) Source(ctx context.Context, path string, ttl int64, isDownload bool, speed int) (string, error) {
ctx context.Context,
path string,
baseURL url.URL,
ttl int64,
isDownload bool,
speed int,
) (string, error) {
// 尝试从上下文获取文件名 // 尝试从上下文获取文件名
fileName := "" fileName := ""
if file, ok := ctx.Value(fsctx.FileModelCtx).(model.File); ok { if file, ok := ctx.Value(fsctx.FileModelCtx).(model.File); ok {

View File

@ -300,8 +300,7 @@ func (fs *FileSystem) SignURL(ctx context.Context, file *model.File, ttl int64,
// 签名最终URL // 签名最终URL
// 生成外链地址 // 生成外链地址
siteURL := model.GetSiteURL() source, err := fs.Handler.Source(ctx, fs.FileTarget[0].SourceName, ttl, isDownload, fs.User.Group.SpeedLimit)
source, err := fs.Handler.Source(ctx, fs.FileTarget[0].SourceName, *siteURL, ttl, isDownload, fs.User.Group.SpeedLimit)
if err != nil { if err != nil {
return "", serializer.NewError(serializer.CodeNotSet, "Failed to get source link", err) return "", serializer.NewError(serializer.CodeNotSet, "Failed to get source link", err)
} }

View File

@ -57,14 +57,7 @@ func (fs *FileSystem) GetThumb(ctx context.Context, id uint) (*response.ContentR
res = &response.ContentResponse{ res = &response.ContentResponse{
Redirect: true, Redirect: true,
} }
res.URL, err = fs.Handler.Source( res.URL, err = fs.Handler.Source(ctx, file.ThumbFile(), int64(model.GetIntSetting("preview_timeout", 60)), false, 0)
ctx,
file.ThumbFile(),
*model.GetSiteURL(),
int64(model.GetIntSetting("preview_timeout", 60)),
false,
0,
)
} else { } else {
// if not exist, generate and upload the sidecar thumb. // if not exist, generate and upload the sidecar thumb.
if err = fs.generateThumbnail(ctx, &file); err == nil { if err = fs.generateThumbnail(ctx, &file); err == nil {