From 521c5c8dc4fa075bf1f9ad5df803cf3245336ca2 Mon Sep 17 00:00:00 2001 From: HFO4 <912394456@qq.com> Date: Sun, 27 Feb 2022 14:24:17 +0800 Subject: [PATCH] Feat: use transactions to manipulate user's used storage --- models/file.go | 19 ++++++++++--- pkg/aria2/monitor/monitor.go | 2 +- pkg/filesystem/file.go | 3 ++- pkg/filesystem/fsctx/context.go | 2 -- pkg/filesystem/hooks.go | 47 +++++++++------------------------ pkg/filesystem/hooks_test.go | 4 +-- pkg/filesystem/upload.go | 41 ++++++++++++---------------- pkg/task/import.go | 7 ++++- pkg/webdav/webdav.go | 5 ---- service/callback/upload.go | 1 - service/explorer/upload.go | 6 +---- 11 files changed, 56 insertions(+), 81 deletions(-) diff --git a/models/file.go b/models/file.go index e78f6d0..34918bb 100644 --- a/models/file.go +++ b/models/file.go @@ -39,12 +39,23 @@ func init() { } // Create 创建文件记录 -func (file *File) Create() (uint, error) { - if err := DB.Create(file).Error; err != nil { +func (file *File) Create() error { + tx := DB.Begin() + + if err := tx.Create(file).Error; err != nil { util.Log().Warning("无法插入文件记录, %s", err) - return 0, err + tx.Rollback() + return err } - return file.ID, nil + + user := &User{} + user.ID = file.UserID + if err := user.ChangeStorage(tx, "+", file.Size); err != nil { + tx.Rollback() + return err + } + + return tx.Commit().Error } // AfterFind 找到文件后的钩子 diff --git a/pkg/aria2/monitor/monitor.go b/pkg/aria2/monitor/monitor.go index 1f8954b..a515b66 100644 --- a/pkg/aria2/monitor/monitor.go +++ b/pkg/aria2/monitor/monitor.go @@ -195,7 +195,7 @@ func (monitor *Monitor) ValidateFile() error { } // 验证用户容量 - if err := filesystem.HookValidateCapacityWithoutIncrease(context.Background(), fs, file); err != nil { + if err := filesystem.HookValidateCapacity(context.Background(), fs, file); err != nil { return err } diff --git a/pkg/filesystem/file.go b/pkg/filesystem/file.go index 491274f..3d244e4 100644 --- a/pkg/filesystem/file.go +++ b/pkg/filesystem/file.go @@ -70,7 +70,7 @@ func (fs *FileSystem) AddFile(ctx context.Context, parent *model.Folder, file fs newFile.PicInfo = "1,1" } - _, err = newFile.Create() + err = newFile.Create() if err != nil { if err := fs.Trigger(ctx, "AfterValidateFailed", file); err != nil { @@ -79,6 +79,7 @@ func (fs *FileSystem) AddFile(ctx context.Context, parent *model.Folder, file fs return nil, ErrFileExisted.WithError(err) } + fs.User.Storage += newFile.Size return &newFile, nil } diff --git a/pkg/filesystem/fsctx/context.go b/pkg/filesystem/fsctx/context.go index 06664b3..942551c 100644 --- a/pkg/filesystem/fsctx/context.go +++ b/pkg/filesystem/fsctx/context.go @@ -33,8 +33,6 @@ const ( ForceUsePublicEndpointCtx // CancelFuncCtx Context 取消函數 CancelFuncCtx - // ValidateCapacityOnceCtx 限定归还容量的操作只執行一次 - ValidateCapacityOnceCtx // 文件在从机节点中的路径 SlaveSrcPath ) diff --git a/pkg/filesystem/hooks.go b/pkg/filesystem/hooks.go index 826e101..8a8951a 100644 --- a/pkg/filesystem/hooks.go +++ b/pkg/filesystem/hooks.go @@ -2,11 +2,6 @@ package filesystem import ( "context" - "errors" - "io/ioutil" - "strings" - "sync" - model "github.com/cloudreve/Cloudreve/v3/models" "github.com/cloudreve/Cloudreve/v3/pkg/cache" "github.com/cloudreve/Cloudreve/v3/pkg/conf" @@ -15,6 +10,8 @@ import ( "github.com/cloudreve/Cloudreve/v3/pkg/request" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" "github.com/cloudreve/Cloudreve/v3/pkg/util" + "io/ioutil" + "strings" ) // Hook 钩子函数 @@ -115,17 +112,8 @@ func HookResetPolicy(ctx context.Context, fs *FileSystem, file fsctx.FileHeader) return fs.DispatchHandler() } -// HookValidateCapacity 验证并扣除用户容量,包含数据库操作 +// HookValidateCapacity 验证用户容量 func HookValidateCapacity(ctx context.Context, fs *FileSystem, file fsctx.FileHeader) error { - // 验证并扣除容量 - if !fs.ValidateCapacity(ctx, file.Info().Size) { - return ErrInsufficientCapacity - } - return nil -} - -// HookValidateCapacityWithoutIncrease 验证用户容量,不扣除 -func HookValidateCapacityWithoutIncrease(ctx context.Context, fs *FileSystem, file fsctx.FileHeader) error { // 验证并扣除容量 if fs.User.GetRemainingCapacity() < file.Info().Size { return ErrInsufficientCapacity @@ -139,7 +127,7 @@ func HookValidateCapacityDiff(ctx context.Context, fs *FileSystem, newFile fsctx newFileSize := newFile.Info().Size if newFileSize > originFile.Size { - return HookValidateCapacityWithoutIncrease(ctx, fs, newFile) + return HookValidateCapacity(ctx, fs, newFile) } return nil @@ -184,25 +172,6 @@ func HookCancelContext(ctx context.Context, fs *FileSystem, file fsctx.FileHeade return nil } -// HookGiveBackCapacity 归还用户容量 -func HookGiveBackCapacity(ctx context.Context, fs *FileSystem, file fsctx.FileHeader) error { - once, ok := ctx.Value(fsctx.ValidateCapacityOnceCtx).(*sync.Once) - if !ok { - once = &sync.Once{} - } - - // 归还用户容量 - res := true - once.Do(func() { - res = fs.User.DeductionStorage(file.Info().Size) - }) - - if !res { - return errors.New("无法继续降低用户已用存储") - } - return nil -} - // HookUpdateSourceName 更新文件SourceName // TODO:测试 func HookUpdateSourceName(ctx context.Context, fs *FileSystem, file fsctx.FileHeader) error { @@ -335,6 +304,14 @@ func HookChunkUploaded(ctx context.Context, fs *FileSystem, fileHeader fsctx.Fil return fileInfo.Model.(*model.File).UpdateSize(fileInfo.Model.(*model.File).GetSize() + fileInfo.Size) } +// HookChunkUploadFailed 单个分片上传失败后 +func HookChunkUploadFailed(ctx context.Context, fs *FileSystem, fileHeader fsctx.FileHeader) error { + fileInfo := fileHeader.Info() + + // 更新文件大小 + return fileInfo.Model.(*model.File).UpdateSize(fileInfo.Model.(*model.File).GetSize() - fileInfo.Size) +} + // HookChunkUploadFinished 分片上传结束后处理文件 func HookChunkUploadFinished(ctx context.Context, fs *FileSystem, fileHeader fsctx.FileHeader) error { fileInfo := fileHeader.Info() diff --git a/pkg/filesystem/hooks_test.go b/pkg/filesystem/hooks_test.go index 40f668b..a247490 100644 --- a/pkg/filesystem/hooks_test.go +++ b/pkg/filesystem/hooks_test.go @@ -723,14 +723,14 @@ func TestHookValidateCapacityWithoutIncrease(t *testing.T) { // not enough { fs.User.Group.MaxStorage = 10 - a.Error(HookValidateCapacityWithoutIncrease(ctx, fs)) + a.Error(HookValidateCapacity(ctx, fs)) a.EqualValues(10, fs.User.Storage) } // enough { fs.User.Group.MaxStorage = 11 - a.NoError(HookValidateCapacityWithoutIncrease(ctx, fs)) + a.NoError(HookValidateCapacity(ctx, fs)) a.EqualValues(10, fs.User.Storage) } } diff --git a/pkg/filesystem/upload.go b/pkg/filesystem/upload.go index a8b2716..e63731b 100644 --- a/pkg/filesystem/upload.go +++ b/pkg/filesystem/upload.go @@ -164,6 +164,23 @@ func (fs *FileSystem) CreateUploadSession(ctx context.Context, file *fsctx.FileS callbackKey := uuid.Must(uuid.NewV4()).String() fileSize := file.Size + // 创建占位的文件,同时校验文件信息 + file.Mode = fsctx.Nop + if callbackKey != "" { + file.UploadSessionID = &callbackKey + } + + fs.Use("BeforeUpload", HookValidateFile) + fs.Use("BeforeUpload", HookValidateCapacity) + if !fs.Policy.IsUploadPlaceholderWithSize() { + fs.Use("AfterUpload", HookClearFileHeaderSize) + } + + fs.Use("AfterUpload", GenericAfterUpload) + if err := fs.Upload(ctx, file); err != nil { + return nil, err + } + uploadSession := &serializer.UploadSession{ Key: callbackKey, UID: fs.User.ID, @@ -181,27 +198,6 @@ func (fs *FileSystem) CreateUploadSession(ctx context.Context, file *fsctx.FileS return nil, err } - // 创建占位的文件,同时校验文件信息 - file.Mode = fsctx.Nop - if callbackKey != "" { - file.UploadSessionID = &callbackKey - } - - fs.Use("BeforeUpload", HookValidateFile) - if !fs.Policy.IsUploadPlaceholderWithSize() { - fs.Use("BeforeUpload", HookValidateCapacityWithoutIncrease) - fs.Use("AfterUpload", HookClearFileHeaderSize) - } else { - fs.Use("BeforeUpload", HookValidateCapacity) - fs.Use("AfterValidateFailed", HookGiveBackCapacity) - fs.Use("AfterUploadFailed", HookGiveBackCapacity) - } - - fs.Use("AfterUpload", GenericAfterUpload) - if err := fs.Upload(ctx, file); err != nil { - return nil, err - } - // 创建回调会话 err = cache.Set( UploadSessionCachePrefix+callbackKey, @@ -226,12 +222,9 @@ func (fs *FileSystem) UploadFromStream(ctx context.Context, file *fsctx.FileStre fs.Use("BeforeUpload", HookValidateFile) fs.Use("BeforeUpload", HookValidateCapacity) fs.Use("AfterUploadCanceled", HookDeleteTempFile) - fs.Use("AfterUploadCanceled", HookGiveBackCapacity) fs.Use("AfterUpload", GenericAfterUpload) fs.Use("AfterUpload", HookGenerateThumb) fs.Use("AfterValidateFailed", HookDeleteTempFile) - fs.Use("AfterValidateFailed", HookGiveBackCapacity) - fs.Use("AfterUploadFailed", HookGiveBackCapacity) } fs.Lock.Unlock() diff --git a/pkg/task/import.go b/pkg/task/import.go index afdaaa2..e91bbed 100644 --- a/pkg/task/import.go +++ b/pkg/task/import.go @@ -94,10 +94,15 @@ func (job *ImportTask) Do() { } defer fs.Recycle() + fs.Policy = &policy + if err := fs.DispatchHandler(); err != nil { + job.SetErrorMsg("无法分发存储策略", err) + return + } + // 注册钩子 fs.Use("BeforeAddFile", filesystem.HookValidateFile) fs.Use("BeforeAddFile", filesystem.HookValidateCapacity) - fs.Use("AfterValidateFailed", filesystem.HookGiveBackCapacity) // 列取目录、对象 job.TaskModel.SetProgress(ListingProgress) diff --git a/pkg/webdav/webdav.go b/pkg/webdav/webdav.go index 8d6ff56..7f98dfa 100644 --- a/pkg/webdav/webdav.go +++ b/pkg/webdav/webdav.go @@ -14,7 +14,6 @@ import ( "path" "strconv" "strings" - "sync" "time" model "github.com/cloudreve/Cloudreve/v3/models" @@ -316,7 +315,6 @@ func (h *Handler) handlePut(w http.ResponseWriter, r *http.Request, fs *filesyst defer cancel() ctx = context.WithValue(ctx, fsctx.HTTPCtx, r.Context()) ctx = context.WithValue(ctx, fsctx.CancelFuncCtx, cancel) - ctx = context.WithValue(ctx, fsctx.ValidateCapacityOnceCtx, &sync.Once{}) fileSize, err := strconv.ParseUint(r.Header.Get("Content-Length"), 10, 64) if err != nil { @@ -362,13 +360,10 @@ func (h *Handler) handlePut(w http.ResponseWriter, r *http.Request, fs *filesyst fs.Use("BeforeUpload", filesystem.HookValidateFile) fs.Use("BeforeUpload", filesystem.HookValidateCapacity) fs.Use("AfterUploadCanceled", filesystem.HookDeleteTempFile) - fs.Use("AfterUploadCanceled", filesystem.HookGiveBackCapacity) fs.Use("AfterUploadCanceled", filesystem.HookCancelContext) fs.Use("AfterUpload", filesystem.GenericAfterUpload) fs.Use("AfterUpload", filesystem.HookGenerateThumb) fs.Use("AfterValidateFailed", filesystem.HookDeleteTempFile) - fs.Use("AfterValidateFailed", filesystem.HookGiveBackCapacity) - fs.Use("AfterUploadFailed", filesystem.HookGiveBackCapacity) // 禁止覆盖 fileData.Mode = fsctx.Create diff --git a/service/callback/upload.go b/service/callback/upload.go index 8900375..54bdf74 100644 --- a/service/callback/upload.go +++ b/service/callback/upload.go @@ -160,7 +160,6 @@ func ProcessCallback(service CallbackProcessService, c *gin.Context) serializer. // 添加钩子 fs.Use("BeforeAddFile", filesystem.HookValidateFile) fs.Use("BeforeAddFile", filesystem.HookValidateCapacity) - fs.Use("AfterValidateFailed", filesystem.HookGiveBackCapacity) fs.Use("AfterValidateFailed", filesystem.HookDeleteTempFile) fs.Use("BeforeAddFileFailed", filesystem.HookDeleteTempFile) diff --git a/service/explorer/upload.go b/service/explorer/upload.go index 4cfce8c..2a044b0 100644 --- a/service/explorer/upload.go +++ b/service/explorer/upload.go @@ -15,7 +15,6 @@ import ( "io/ioutil" "strconv" "strings" - "sync" "time" ) @@ -159,7 +158,6 @@ func processChunkUpload(ctx context.Context, c *gin.Context, fs *filesystem.File // 给文件系统分配钩子 fs.Use("BeforeUpload", filesystem.HookValidateCapacity) fs.Use("AfterUploadCanceled", filesystem.HookTruncateFileTo(fileData.AppendStart)) - fs.Use("AfterUploadCanceled", filesystem.HookGiveBackCapacity) fs.Use("AfterUpload", filesystem.HookChunkUploaded) if isLastChunk { fs.Use("AfterUpload", filesystem.HookChunkUploadFinished) @@ -167,11 +165,9 @@ func processChunkUpload(ctx context.Context, c *gin.Context, fs *filesystem.File fs.Use("AfterUpload", filesystem.HookDeleteUploadSession(session.Key)) } fs.Use("AfterValidateFailed", filesystem.HookTruncateFileTo(fileData.AppendStart)) - fs.Use("AfterValidateFailed", filesystem.HookGiveBackCapacity) - fs.Use("AfterUploadFailed", filesystem.HookGiveBackCapacity) + fs.Use("AfterValidateFailed", filesystem.HookChunkUploadFailed) // 执行上传 - ctx = context.WithValue(ctx, fsctx.ValidateCapacityOnceCtx, &sync.Once{}) uploadCtx := context.WithValue(ctx, fsctx.GinCtx, c) err = fs.Upload(uploadCtx, &fileData) if err != nil {