diff --git a/pkg/filesystem/driver/oss/handler.go b/pkg/filesystem/driver/oss/handler.go index 9dd5353..fc4f9a3 100644 --- a/pkg/filesystem/driver/oss/handler.go +++ b/pkg/filesystem/driver/oss/handler.go @@ -186,7 +186,7 @@ func (handler Driver) Get(ctx context.Context, path string) (response.RSCloser, ctx = context.WithValue(ctx, VersionID, time.Now().UnixNano()) // 尽可能使用私有 Endpoint - ctx = context.WithValue(ctx, fsctx.ForceUsePublicEndpoint, false) + ctx = context.WithValue(ctx, fsctx.ForceUsePublicEndpointCtx, false) // 获取文件源地址 downloadURL, err := handler.Source( @@ -317,7 +317,7 @@ func (handler Driver) Source( ) (string, error) { // 初始化客户端 usePublicEndpoint := true - if forceUsePublicEndpoint, ok := ctx.Value(fsctx.ForceUsePublicEndpoint).(bool); ok { + if forceUsePublicEndpoint, ok := ctx.Value(fsctx.ForceUsePublicEndpointCtx).(bool); ok { usePublicEndpoint = forceUsePublicEndpoint } if err := handler.InitOSSClient(usePublicEndpoint); err != nil { diff --git a/pkg/filesystem/driver/oss/handler_test.go b/pkg/filesystem/driver/oss/handler_test.go index 267f8fc..3758bba 100644 --- a/pkg/filesystem/driver/oss/handler_test.go +++ b/pkg/filesystem/driver/oss/handler_test.go @@ -193,7 +193,7 @@ func TestDriver_Source(t *testing.T) { { handler.Policy.BaseURL = "" handler.Policy.OptionsSerialized.ServerSideEndpoint = "endpoint.com" - res, err := handler.Source(context.WithValue(context.Background(), fsctx.ForceUsePublicEndpoint, false), "/123", url.URL{}, 10, false, 0) + res, err := handler.Source(context.WithValue(context.Background(), fsctx.ForceUsePublicEndpointCtx, false), "/123", url.URL{}, 10, false, 0) asserts.NoError(err) resURL, err := url.Parse(res) asserts.NoError(err) diff --git a/pkg/filesystem/fsctx/context.go b/pkg/filesystem/fsctx/context.go index 6b3f81d..0056043 100644 --- a/pkg/filesystem/fsctx/context.go +++ b/pkg/filesystem/fsctx/context.go @@ -33,6 +33,10 @@ const ( IgnoreConflictCtx // RetryCtx 失败重试次数 RetryCtx - // ForceUsePublicEndpoint 强制使用公网 Endpoint - ForceUsePublicEndpoint + // ForceUsePublicEndpointCtx 强制使用公网 Endpoint + ForceUsePublicEndpointCtx + // CancelFuncCtx Context 取消函數 + CancelFuncCtx + // ValidateCapacityOnceCtx 限定归还容量的操作只執行一次 + ValidateCapacityOnceCtx ) diff --git a/pkg/filesystem/hooks.go b/pkg/filesystem/hooks.go index eb8bbed..3b5755d 100644 --- a/pkg/filesystem/hooks.go +++ b/pkg/filesystem/hooks.go @@ -5,6 +5,7 @@ import ( "errors" "io/ioutil" "strings" + "sync" model "github.com/cloudreve/Cloudreve/v3/models" "github.com/cloudreve/Cloudreve/v3/pkg/conf" @@ -186,12 +187,30 @@ func HookClearFileSize(ctx context.Context, fs *FileSystem) error { return originFile.UpdateSize(0) } +// HookCancelContext 取消上下文 +func HookCancelContext(ctx context.Context, fs *FileSystem) error { + cancelFunc, ok := ctx.Value(fsctx.CancelFuncCtx).(context.CancelFunc) + if ok { + cancelFunc() + } + return nil +} + // HookGiveBackCapacity 归还用户容量 func HookGiveBackCapacity(ctx context.Context, fs *FileSystem) error { file := ctx.Value(fsctx.FileHeaderCtx).(FileHeader) + once, ok := ctx.Value(fsctx.ValidateCapacityOnceCtx).(*sync.Once) + if !ok { + once = &sync.Once{} + } // 归还用户容量 - if !fs.User.DeductionStorage(file.GetSize()) { + res := true + once.Do(func() { + res = fs.User.DeductionStorage(file.GetSize()) + }) + + if !res { return errors.New("无法继续降低用户已用存储") } return nil diff --git a/pkg/filesystem/hooks_test.go b/pkg/filesystem/hooks_test.go index 59fe56b..2548430 100644 --- a/pkg/filesystem/hooks_test.go +++ b/pkg/filesystem/hooks_test.go @@ -8,6 +8,7 @@ import ( "net/http" "os" "strings" + "sync" "testing" "github.com/DATA-DOG/go-sqlmock" @@ -654,3 +655,56 @@ func TestFileSystem_CleanHooks(t *testing.T) { asserts.Len(fs.Hooks, 0) } } + +func TestHookCancelContext(t *testing.T) { + asserts := assert.New(t) + fs := &FileSystem{} + ctx, cancel := context.WithCancel(context.Background()) + + // empty ctx + { + asserts.NoError(HookCancelContext(ctx, fs)) + select { + case <-ctx.Done(): + t.Errorf("Channel should not be closed") + default: + + } + } + + // with cancel ctx + { + ctx = context.WithValue(ctx, fsctx.CancelFuncCtx, cancel) + asserts.NoError(HookCancelContext(ctx, fs)) + _, ok := <-ctx.Done() + asserts.False(ok) + } +} + +func TestHookGiveBackCapacity(t *testing.T) { + asserts := assert.New(t) + fs := &FileSystem{ + User: &model.User{ + Model: gorm.Model{ID: 1}, + Storage: 10, + }, + } + ctx := context.WithValue(context.Background(), fsctx.FileHeaderCtx, local.FileStream{Size: 1}) + + // without once limit + { + asserts.NoError(HookGiveBackCapacity(ctx, fs)) + asserts.EqualValues(9, fs.User.Storage) + asserts.NoError(HookGiveBackCapacity(ctx, fs)) + asserts.EqualValues(8, fs.User.Storage) + } + + // with once limit + { + ctx = context.WithValue(ctx, fsctx.ValidateCapacityOnceCtx, &sync.Once{}) + asserts.NoError(HookGiveBackCapacity(ctx, fs)) + asserts.EqualValues(7, fs.User.Storage) + asserts.NoError(HookGiveBackCapacity(ctx, fs)) + asserts.EqualValues(7, fs.User.Storage) + } +} diff --git a/pkg/webdav/webdav.go b/pkg/webdav/webdav.go index a4f9f20..dbfbdd3 100644 --- a/pkg/webdav/webdav.go +++ b/pkg/webdav/webdav.go @@ -14,6 +14,7 @@ import ( "path" "strconv" "strings" + "sync" "time" model "github.com/cloudreve/Cloudreve/v3/models" @@ -315,6 +316,8 @@ func (h *Handler) handlePut(w http.ResponseWriter, r *http.Request, fs *filesyst ctx, cancel := context.WithCancel(context.Background()) defer cancel() ctx = context.WithValue(ctx, fsctx.HTTPCtx, r.Context()) + ctx = context.WithValue(ctx, fsctx.CancelFuncCtx, cancel) + ctx = context.WithValue(ctx, fsctx.ValidateCapacityOnceCtx, &sync.Once{}) fileSize, err := strconv.ParseUint(r.Header.Get("Content-Length"), 10, 64) if err != nil { @@ -351,6 +354,7 @@ func (h *Handler) handlePut(w http.ResponseWriter, r *http.Request, fs *filesyst fs.Use("AfterUploadCanceled", filesystem.HookCleanFileContent) fs.Use("AfterUploadCanceled", filesystem.HookClearFileSize) fs.Use("AfterUploadCanceled", filesystem.HookGiveBackCapacity) + fs.Use("AfterUploadCanceled", filesystem.HookCancelContext) fs.Use("AfterUpload", filesystem.GenericAfterUpdate) fs.Use("AfterValidateFailed", filesystem.HookCleanFileContent) fs.Use("AfterValidateFailed", filesystem.HookClearFileSize) @@ -362,6 +366,7 @@ func (h *Handler) handlePut(w http.ResponseWriter, r *http.Request, fs *filesyst fs.Use("BeforeUpload", filesystem.HookValidateCapacity) fs.Use("AfterUploadCanceled", filesystem.HookDeleteTempFile) fs.Use("AfterUploadCanceled", filesystem.HookGiveBackCapacity) + fs.Use("AfterUploadCanceled", filesystem.HookCancelContext) fs.Use("AfterUpload", filesystem.GenericAfterUpload) fs.Use("AfterValidateFailed", filesystem.HookDeleteTempFile) fs.Use("AfterValidateFailed", filesystem.HookGiveBackCapacity)