mirror of https://github.com/cloudreve/Cloudreve
Feat: basic file validator
parent
003274162b
commit
79caf635f9
|
@ -54,6 +54,24 @@ type UserOption struct {
|
||||||
WebDAVKey string `json:"webdav_key"`
|
WebDAVKey string `json:"webdav_key"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DeductionCapacity 扣除用户容量配额
|
||||||
|
func (user *User) DeductionCapacity(size uint64) bool {
|
||||||
|
if size <= user.GetRemainingCapacity() {
|
||||||
|
user.Storage += size
|
||||||
|
DB.Save(user)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRemainingCapacity 获取剩余配额
|
||||||
|
func (user *User) GetRemainingCapacity() uint64 {
|
||||||
|
if user.Group.MaxStorage <= user.Storage {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return user.Group.MaxStorage - user.Storage
|
||||||
|
}
|
||||||
|
|
||||||
// GetPolicyID 获取用户当前的上传策略ID
|
// GetPolicyID 获取用户当前的上传策略ID
|
||||||
func (user *User) GetPolicyID() uint {
|
func (user *User) GetPolicyID() uint {
|
||||||
// 用户未指定时,返回可用的第一个
|
// 用户未指定时,返回可用的第一个
|
||||||
|
|
|
@ -164,3 +164,55 @@ func TestUser_GetPolicyID(t *testing.T) {
|
||||||
asserts.Equal(testCase.expected, newUser.GetPolicyID(), "测试用例 #%d 未通过", key)
|
asserts.Equal(testCase.expected, newUser.GetPolicyID(), "测试用例 #%d 未通过", key)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestUser_GetRemainingCapacity(t *testing.T) {
|
||||||
|
asserts := assert.New(t)
|
||||||
|
newUser := NewUser()
|
||||||
|
|
||||||
|
newUser.Group.MaxStorage = 100
|
||||||
|
asserts.Equal(uint64(100), newUser.GetRemainingCapacity())
|
||||||
|
|
||||||
|
newUser.Group.MaxStorage = 100
|
||||||
|
newUser.Storage = 1
|
||||||
|
asserts.Equal(uint64(99), newUser.GetRemainingCapacity())
|
||||||
|
|
||||||
|
newUser.Group.MaxStorage = 100
|
||||||
|
newUser.Storage = 100
|
||||||
|
asserts.Equal(uint64(0), newUser.GetRemainingCapacity())
|
||||||
|
|
||||||
|
newUser.Group.MaxStorage = 100
|
||||||
|
newUser.Storage = 200
|
||||||
|
asserts.Equal(uint64(0), newUser.GetRemainingCapacity())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUser_DeductionCapacity(t *testing.T) {
|
||||||
|
asserts := assert.New(t)
|
||||||
|
|
||||||
|
userRows := sqlmock.NewRows([]string{"id", "deleted_at", "storage", "options", "group_id"}).
|
||||||
|
AddRow(1, nil, 0, "{}", 1)
|
||||||
|
mock.ExpectQuery("^SELECT (.+)").WillReturnRows(userRows)
|
||||||
|
groupRows := sqlmock.NewRows([]string{"id", "name", "policies"}).
|
||||||
|
AddRow(1, "管理员", "[1]")
|
||||||
|
mock.ExpectQuery("^SELECT (.+)").WillReturnRows(groupRows)
|
||||||
|
|
||||||
|
policyRows := sqlmock.NewRows([]string{"id", "name"}).
|
||||||
|
AddRow(1, "默认上传策略")
|
||||||
|
mock.ExpectQuery("^SELECT (.+)").WillReturnRows(policyRows)
|
||||||
|
|
||||||
|
newUser, err := GetUserByID(1)
|
||||||
|
newUser.Group.MaxStorage = 100
|
||||||
|
asserts.NoError(err)
|
||||||
|
asserts.NoError(mock.ExpectationsWereMet())
|
||||||
|
|
||||||
|
asserts.Equal(false, newUser.DeductionCapacity(101))
|
||||||
|
asserts.Equal(uint64(0), newUser.Storage)
|
||||||
|
|
||||||
|
asserts.Equal(true, newUser.DeductionCapacity(1))
|
||||||
|
asserts.Equal(uint64(1), newUser.Storage)
|
||||||
|
|
||||||
|
asserts.Equal(true, newUser.DeductionCapacity(99))
|
||||||
|
asserts.Equal(uint64(100), newUser.Storage)
|
||||||
|
|
||||||
|
asserts.Equal(false, newUser.DeductionCapacity(1))
|
||||||
|
asserts.Equal(uint64(100), newUser.Storage)
|
||||||
|
}
|
||||||
|
|
|
@ -11,18 +11,37 @@ type FileData interface {
|
||||||
io.Closer
|
io.Closer
|
||||||
GetSize() uint64
|
GetSize() uint64
|
||||||
GetMIMEType() string
|
GetMIMEType() string
|
||||||
|
GetFileName() string
|
||||||
}
|
}
|
||||||
|
|
||||||
// FileSystem 管理文件的文件系统
|
// FileSystem 管理文件的文件系统
|
||||||
type FileSystem struct {
|
type FileSystem struct {
|
||||||
// 文件系统所有者
|
/*
|
||||||
|
文件系统所有者
|
||||||
|
*/
|
||||||
User *model.User
|
User *model.User
|
||||||
|
|
||||||
// 文件系统处理适配器
|
/*
|
||||||
|
钩子函数
|
||||||
|
*/
|
||||||
|
// 上传文件前
|
||||||
|
BeforeUpload func(fs *FileSystem, file FileData) error
|
||||||
|
// 上传文件后
|
||||||
|
AfterUpload func(fs *FileSystem) error
|
||||||
|
// 文件验证失败后
|
||||||
|
ValidateFailed func(fs *FileSystem) error
|
||||||
|
|
||||||
|
/*
|
||||||
|
文件系统处理适配器
|
||||||
|
*/
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Upload 上传文件
|
// Upload 上传文件
|
||||||
func (fs *FileSystem) Upload(File FileData) (err error) {
|
func (fs *FileSystem) Upload(file FileData) (err error) {
|
||||||
|
err = fs.BeforeUpload(fs, file)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,22 @@
|
||||||
|
package filesystem
|
||||||
|
|
||||||
|
import "errors"
|
||||||
|
|
||||||
|
// GenericBeforeUpload 通用上传前处理钩子,包含数据库操作
|
||||||
|
func GenericBeforeUpload(fs *FileSystem, file FileData) error {
|
||||||
|
// 验证单文件尺寸
|
||||||
|
if !fs.ValidateFileSize(file.GetSize()) {
|
||||||
|
return errors.New("单个文件尺寸太大")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证并扣除容量
|
||||||
|
if !fs.ValidateCapacity(file.GetSize()) {
|
||||||
|
return errors.New("容量空间不足")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证扩展名
|
||||||
|
if !fs.ValidateExtension(file.GetFileName()) {
|
||||||
|
return errors.New("不允许上传此类型的文件")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
|
@ -6,6 +6,7 @@ import "mime/multipart"
|
||||||
type FileData struct {
|
type FileData struct {
|
||||||
File multipart.File
|
File multipart.File
|
||||||
Size uint64
|
Size uint64
|
||||||
|
Name string
|
||||||
MIMEType string
|
MIMEType string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -24,3 +25,7 @@ func (file FileData) GetSize() uint64 {
|
||||||
func (file FileData) Close() error {
|
func (file FileData) Close() error {
|
||||||
return file.Close()
|
return file.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (file FileData) GetFileName() string {
|
||||||
|
return file.Name
|
||||||
|
}
|
||||||
|
|
|
@ -0,0 +1,40 @@
|
||||||
|
package filesystem
|
||||||
|
|
||||||
|
import (
|
||||||
|
"cloudreve/pkg/util"
|
||||||
|
"path/filepath"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ValidateFileSize 验证上传的文件大小是否超出限制
|
||||||
|
func (fs *FileSystem) ValidateFileSize(size uint64) bool {
|
||||||
|
return size <= fs.User.Policy.MaxSize
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateCapacity 验证并扣除用户容量
|
||||||
|
func (fs *FileSystem) ValidateCapacity(size uint64) bool {
|
||||||
|
if fs.User.DeductionCapacity(size) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateExtension 验证文件扩展名
|
||||||
|
func (fs *FileSystem) ValidateExtension(fileName string) bool {
|
||||||
|
// 不需要验证
|
||||||
|
if len(fs.User.Policy.OptionsSerialized.FileType) == 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
ext := filepath.Ext(fileName)
|
||||||
|
|
||||||
|
// 无扩展名时
|
||||||
|
if len(ext) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if util.ContainsString(fs.User.Policy.OptionsSerialized.FileType, ext[1:]) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
|
@ -19,6 +19,8 @@ const (
|
||||||
CodeCheckLogin = 401
|
CodeCheckLogin = 401
|
||||||
// CodeNoRightErr 未授权访问
|
// CodeNoRightErr 未授权访问
|
||||||
CodeNoRightErr = 403
|
CodeNoRightErr = 403
|
||||||
|
// CodeUploadFailed 上传出错
|
||||||
|
CodeUploadFailed = 4001
|
||||||
// CodeDBError 数据库操作失败
|
// CodeDBError 数据库操作失败
|
||||||
CodeDBError = 50001
|
CodeDBError = 50001
|
||||||
// CodeEncryptError 加密失败
|
// CodeEncryptError 加密失败
|
||||||
|
|
|
@ -24,3 +24,13 @@ func ContainsUint(s []uint, e uint) bool {
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ContainsString 返回list中是否包含
|
||||||
|
func ContainsString(s []string, e string) bool {
|
||||||
|
for _, a := range s {
|
||||||
|
if a == e {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
|
@ -28,14 +28,18 @@ func (service *UploadService) Upload(c *gin.Context) serializer.Response {
|
||||||
MIMEType: service.File.Header.Get("Content-Type"),
|
MIMEType: service.File.Header.Get("Content-Type"),
|
||||||
File: file,
|
File: file,
|
||||||
Size: uint64(service.File.Size),
|
Size: uint64(service.File.Size),
|
||||||
|
Name: service.Name,
|
||||||
}
|
}
|
||||||
|
|
||||||
user, _ := c.Get("user")
|
user, _ := c.Get("user")
|
||||||
|
|
||||||
fs := filesystem.FileSystem{
|
fs := filesystem.FileSystem{
|
||||||
|
BeforeUpload: filesystem.GenericBeforeUpload,
|
||||||
User: user.(*model.User),
|
User: user.(*model.User),
|
||||||
}
|
}
|
||||||
|
|
||||||
err = fs.Upload(fileData)
|
err = fs.Upload(fileData)
|
||||||
|
if err != nil {
|
||||||
|
return serializer.Err(serializer.CodeUploadFailed, err.Error(), err)
|
||||||
|
}
|
||||||
|
|
||||||
return serializer.Response{
|
return serializer.Response{
|
||||||
Code: 0,
|
Code: 0,
|
||||||
|
|
Loading…
Reference in New Issue