From ca07a94d417e943062f77dd29547dcc92aa42229 Mon Sep 17 00:00:00 2001 From: HFO4 <912394456@qq.com> Date: Sat, 14 Dec 2019 14:28:01 +0800 Subject: [PATCH] Test: file compress / download --- models/file.go | 6 ++ models/file_test.go | 3 + models/folder_test.go | 3 + models/migration.go | 5 +- pkg/auth/hmac_test.go | 12 +++ pkg/cache/driver_test.go | 9 +++ pkg/filesystem/archive_test.go | 59 ++++++++++++++ pkg/filesystem/file.go | 2 +- pkg/filesystem/file_test.go | 108 ++++++++++++++++++++++++++ pkg/filesystem/local/handller_test.go | 33 ++++++++ pkg/filesystem/manage_test.go | 108 ++++++++++++++++++++++++++ pkg/filesystem/tests/file1.txt | 0 pkg/filesystem/tests/file2.txt | 0 routers/controllers/file.go | 19 +++++ routers/router.go | 2 + 15 files changed, 367 insertions(+), 2 deletions(-) create mode 100644 pkg/filesystem/archive_test.go create mode 100644 pkg/filesystem/tests/file1.txt create mode 100644 pkg/filesystem/tests/file2.txt diff --git a/models/file.go b/models/file.go index e9fe430..21f38da 100644 --- a/models/file.go +++ b/models/file.go @@ -1,6 +1,7 @@ package model import ( + "encoding/gob" "github.com/HFO4/cloudreve/pkg/util" "github.com/jinzhu/gorm" "path" @@ -25,6 +26,11 @@ type File struct { Position string `gorm:"-"` } +func init() { + // 注册缓存用到的复杂结构 + gob.Register(File{}) +} + // Create 创建文件记录 func (file *File) Create() (uint, error) { if err := DB.Create(file).Error; err != nil { diff --git a/models/file_test.go b/models/file_test.go index 99d5625..f992dd7 100644 --- a/models/file_test.go +++ b/models/file_test.go @@ -64,6 +64,8 @@ func TestFolder_GetChildFiles(t *testing.T) { Model: gorm.Model{ ID: 1, }, + Position: "/123", + Name: "456", } // 找不到 @@ -78,6 +80,7 @@ func TestFolder_GetChildFiles(t *testing.T) { files, err = folder.GetChildFiles() asserts.NoError(err) asserts.Len(files, 2) + asserts.Equal("/123/456", files[0].Position) asserts.NoError(mock.ExpectationsWereMet()) } diff --git a/models/folder_test.go b/models/folder_test.go index 65e99a5..2041f89 100644 --- a/models/folder_test.go +++ b/models/folder_test.go @@ -73,6 +73,8 @@ func TestFolder_GetChildFolder(t *testing.T) { Model: gorm.Model{ ID: 1, }, + Position: "/123", + Name: "456", } // 找不到 @@ -87,6 +89,7 @@ func TestFolder_GetChildFolder(t *testing.T) { files, err = folder.GetChildFolder() asserts.NoError(err) asserts.Len(files, 2) + asserts.Equal("/123/456", files[0].Position) asserts.NoError(mock.ExpectationsWereMet()) } diff --git a/models/migration.go b/models/migration.go index fab53bc..9a2c5be 100644 --- a/models/migration.go +++ b/models/migration.go @@ -67,6 +67,9 @@ func addDefaultPolicy() { DirNameRule: "uploads/{uid}/{path}", FileNameRule: "{uid}_{randomkey8}_{originname}", IsOriginLinkEnable: false, + OptionsSerialized: PolicyOption{ + FileType: []string{}, + }, } if err := DB.Create(&defaultPolicy).Error; err != nil { util.Log().Panic("无法创建初始存储策略, %s", err) @@ -76,7 +79,7 @@ func addDefaultPolicy() { func addDefaultSettings() { defaultSettings := []Setting{ - {Name: "siteURL", Value: `http://lite.aoaoao.me/`, Type: "basic"}, + {Name: "siteURL", Value: ``, Type: "basic"}, {Name: "siteName", Value: `Cloudreve`, Type: "basic"}, {Name: "siteStatus", Value: `open`, Type: "basic"}, {Name: "regStatus", Value: `0`, Type: "register"}, diff --git a/pkg/auth/hmac_test.go b/pkg/auth/hmac_test.go index 376a9c1..641a6df 100644 --- a/pkg/auth/hmac_test.go +++ b/pkg/auth/hmac_test.go @@ -2,6 +2,7 @@ package auth import ( "database/sql" + "fmt" "github.com/DATA-DOG/go-sqlmock" model "github.com/HFO4/cloudreve/models" "github.com/HFO4/cloudreve/pkg/util" @@ -9,6 +10,7 @@ import ( "github.com/jinzhu/gorm" "github.com/stretchr/testify/assert" "testing" + "time" ) var mock sqlmock.Sqlmock @@ -64,6 +66,16 @@ func TestHMACAuth_Check(t *testing.T) { sign := auth.Sign("content", 1) asserts.Error(auth.Check("content", sign+":")) } + + // 过期日期格式错误 + { + asserts.Error(auth.Check("content", "ErrAuthFailed:ErrAuthFailed")) + } + + // 签名有误 + { + asserts.Error(auth.Check("content", fmt.Sprintf("sign:%d", time.Now().Unix()+10))) + } } func TestInit(t *testing.T) { diff --git a/pkg/cache/driver_test.go b/pkg/cache/driver_test.go index a0d0710..d30a67f 100644 --- a/pkg/cache/driver_test.go +++ b/pkg/cache/driver_test.go @@ -23,6 +23,15 @@ func TestGet(t *testing.T) { asserts.False(ok) } +func TestDeletes(t *testing.T) { + asserts := assert.New(t) + asserts.NoError(Set("123", "321", -1)) + err := Deletes([]string{"123"}, "") + asserts.NoError(err) + _, exist := Get("123") + asserts.False(exist) +} + func TestGetSettings(t *testing.T) { asserts := assert.New(t) asserts.NoError(Set("test_1", "1", -1)) diff --git a/pkg/filesystem/archive_test.go b/pkg/filesystem/archive_test.go new file mode 100644 index 0000000..82e5a2f --- /dev/null +++ b/pkg/filesystem/archive_test.go @@ -0,0 +1,59 @@ +package filesystem + +import ( + "context" + "github.com/DATA-DOG/go-sqlmock" + model "github.com/HFO4/cloudreve/models" + "github.com/HFO4/cloudreve/pkg/cache" + "github.com/jinzhu/gorm" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestFileSystem_Compress(t *testing.T) { + asserts := assert.New(t) + ctx := context.Background() + fs := FileSystem{ + User: &model.User{Model: gorm.Model{ID: 1}}, + } + + // 成功 + { + // 查找压缩父目录 + mock.ExpectQuery("SELECT(.+)folders(.+)"). + WithArgs(1, 1). + WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(1, "parent")) + // 查找顶级待压缩文件 + mock.ExpectQuery("SELECT(.+)files(.+)"). + WithArgs(1, 1). + WillReturnRows( + sqlmock.NewRows( + []string{"id", "name", "source_name", "policy_id"}). + AddRow(1, "1.txt", "tests/file1.txt", 1), + ) + asserts.NoError(cache.Set("setting_temp_path", "tests", -1)) + // 查找父目录子文件 + mock.ExpectQuery("SELECT(.+)files(.+)"). + WithArgs(1). + WillReturnRows(sqlmock.NewRows([]string{"id", "name", "source_name", "policy_id"})) + // 查找子目录 + mock.ExpectQuery("SELECT(.+)folders(.+)"). + WithArgs(1). + WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(2, "sub")) + // 查找子目录子文件 + mock.ExpectQuery("SELECT(.+)files(.+)"). + WithArgs(2). + WillReturnRows( + sqlmock.NewRows([]string{"id", "name", "source_name", "policy_id"}). + AddRow(2, "2.txt", "tests/file2.txt", 1), + ) + // 查找上传策略 + asserts.NoError(cache.Set("policy_1", model.Policy{Type: "local"}, -1)) + + zipFile, err := fs.Compress(ctx, []uint{1}, []uint{1}) + asserts.NoError(err) + asserts.NotEmpty(zipFile) + asserts.Contains(zipFile, "archive_") + asserts.Contains(zipFile, "tests") + } +} diff --git a/pkg/filesystem/file.go b/pkg/filesystem/file.go index bd53422..1eb33f7 100644 --- a/pkg/filesystem/file.go +++ b/pkg/filesystem/file.go @@ -208,7 +208,7 @@ func (fs *FileSystem) GetDownloadURL(ctx context.Context, path string) (string, ttl, ) if err != nil { - return "", serializer.NewError(serializer.CodeNotSet, "无法获取下载地址", err) + return "", err } return source, nil diff --git a/pkg/filesystem/file_test.go b/pkg/filesystem/file_test.go index 7d0265c..f858c5b 100644 --- a/pkg/filesystem/file_test.go +++ b/pkg/filesystem/file_test.go @@ -380,3 +380,111 @@ func TestFileSystem_GetSource(t *testing.T) { fs.CleanTargets() } } + +func TestFileSystem_GetDownloadURL(t *testing.T) { + asserts := assert.New(t) + ctx := context.Background() + fs := FileSystem{ + User: &model.User{Model: gorm.Model{ID: 1}}, + } + auth.General = auth.HMACAuth{SecretKey: []byte("123")} + + // 正常 + { + err := cache.Deletes([]string{"siteURL"}, "setting_") + err = cache.Deletes([]string{"35"}, "policy_") + err = cache.Deletes([]string{"download_timeout"}, "setting_") + asserts.NoError(err) + // 查找文件 + mock.ExpectQuery("SELECT(.+)"). + WithArgs(1). + WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) + mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "policy_id"}).AddRow(1, "1.txt", 1)) + // 查找上传策略 + mock.ExpectQuery("SELECT(.+)"). + WillReturnRows( + sqlmock.NewRows([]string{"id", "type", "is_origin_link_enable"}). + AddRow(35, "local", true), + ) + // 相关设置 + mock.ExpectQuery("SELECT(.+)").WithArgs("siteURL").WillReturnRows(sqlmock.NewRows([]string{"id", "value"}).AddRow(1, "https://cloudreve.org")) + mock.ExpectQuery("SELECT(.+)").WithArgs("download_timeout").WillReturnRows(sqlmock.NewRows([]string{"id", "value"}).AddRow(1, "20")) + + downloadURL, err := fs.GetDownloadURL(ctx, "/1.txt") + asserts.NoError(mock.ExpectationsWereMet()) + asserts.NoError(err) + asserts.NotEmpty(downloadURL) + fs.CleanTargets() + } + + // 文件不存在 + { + err := cache.Deletes([]string{"siteURL"}, "setting_") + err = cache.Deletes([]string{"35"}, "policy_") + err = cache.Deletes([]string{"download_timeout"}, "setting_") + asserts.NoError(err) + // 查找文件 + mock.ExpectQuery("SELECT(.+)"). + WithArgs(1). + WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) + mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "policy_id"})) + + downloadURL, err := fs.GetDownloadURL(ctx, "/1.txt") + asserts.NoError(mock.ExpectationsWereMet()) + asserts.Error(err) + asserts.Empty(downloadURL) + fs.CleanTargets() + } + + // 未知存储策略 + { + err := cache.Deletes([]string{"siteURL"}, "setting_") + err = cache.Deletes([]string{"35"}, "policy_") + err = cache.Deletes([]string{"download_timeout"}, "setting_") + asserts.NoError(err) + // 查找文件 + mock.ExpectQuery("SELECT(.+)"). + WithArgs(1). + WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) + mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "policy_id"}).AddRow(1, "1.txt", 1)) + // 查找上传策略 + mock.ExpectQuery("SELECT(.+)"). + WillReturnRows( + sqlmock.NewRows([]string{"id", "type", "is_origin_link_enable"}). + AddRow(35, "unknown", true), + ) + + downloadURL, err := fs.GetDownloadURL(ctx, "/1.txt") + asserts.NoError(mock.ExpectationsWereMet()) + asserts.Error(err) + asserts.Empty(downloadURL) + fs.CleanTargets() + } +} + +func TestFileSystem_GetPhysicalFileContent(t *testing.T) { + asserts := assert.New(t) + ctx := context.Background() + fs := FileSystem{ + User: &model.User{}, + } + + // 文件不存在 + { + rs, err := fs.GetPhysicalFileContent(ctx, "not_exist.txt") + asserts.Error(err) + asserts.Nil(rs) + } + + // 成功 + { + testFile, err := os.Create("GetPhysicalFileContent.txt") + asserts.NoError(err) + asserts.NoError(testFile.Close()) + + rs, err := fs.GetPhysicalFileContent(ctx, "GetPhysicalFileContent.txt") + asserts.NoError(err) + asserts.NoError(rs.Close()) + asserts.NotNil(rs) + } +} diff --git a/pkg/filesystem/local/handller_test.go b/pkg/filesystem/local/handller_test.go index 258c02f..6f4ddab 100644 --- a/pkg/filesystem/local/handller_test.go +++ b/pkg/filesystem/local/handller_test.go @@ -151,3 +151,36 @@ func TestHandler_Source(t *testing.T) { asserts.Empty(sourceURL) } } + +func TestHandler_GetDownloadURL(t *testing.T) { + asserts := assert.New(t) + handler := Handler{} + ctx := context.Background() + auth.General = auth.HMACAuth{SecretKey: []byte("test")} + + // 成功 + { + file := model.File{ + Model: gorm.Model{ + ID: 1, + }, + Name: "test.jpg", + } + ctx := context.WithValue(ctx, fsctx.FileModelCtx, file) + baseURL, err := url.Parse("https://cloudreve.org") + asserts.NoError(err) + downloadURL, err := handler.GetDownloadURL(ctx, "", *baseURL, 10) + asserts.NoError(err) + asserts.Contains(downloadURL, "sign=") + asserts.Contains(downloadURL, "https://cloudreve.org") + } + + // 无法获取上下文 + { + baseURL, err := url.Parse("https://cloudreve.org") + asserts.NoError(err) + downloadURL, err := handler.GetDownloadURL(ctx, "", *baseURL, 10) + asserts.Error(err) + asserts.Empty(downloadURL) + } +} diff --git a/pkg/filesystem/manage_test.go b/pkg/filesystem/manage_test.go index dc2399b..68a7d89 100644 --- a/pkg/filesystem/manage_test.go +++ b/pkg/filesystem/manage_test.go @@ -473,3 +473,111 @@ func TestFileSystem_Move(t *testing.T) { asserts.NoError(mock.ExpectationsWereMet()) } } + +func TestFileSystem_Rename(t *testing.T) { + asserts := assert.New(t) + fs := &FileSystem{User: &model.User{ + Model: gorm.Model{ + ID: 1, + }, + }} + ctx := context.Background() + + // 重命名文件 成功 + { + mock.ExpectQuery("SELECT(.+)files(.+)"). + WithArgs(10, 1). + WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(10, "old.text")) + mock.ExpectBegin() + mock.ExpectExec("UPDATE(.+)files(.+)"). + WithArgs("new.txt", sqlmock.AnyArg(), 10). + WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() + err := fs.Rename(ctx, []uint{}, []uint{10}, "new.txt") + asserts.NoError(mock.ExpectationsWereMet()) + asserts.NoError(err) + } + + // 重命名文件 不存在 + { + mock.ExpectQuery("SELECT(.+)files(.+)"). + WithArgs(10, 1). + WillReturnRows(sqlmock.NewRows([]string{"id", "name"})) + err := fs.Rename(ctx, []uint{}, []uint{10}, "new.txt") + asserts.NoError(mock.ExpectationsWereMet()) + asserts.Error(err) + asserts.Equal(ErrPathNotExist, err) + } + + // 重命名文件 失败 + { + mock.ExpectQuery("SELECT(.+)files(.+)"). + WithArgs(10, 1). + WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(10, "old.text")) + mock.ExpectBegin() + mock.ExpectExec("UPDATE(.+)files(.+)"). + WithArgs("new.txt", sqlmock.AnyArg(), 10). + WillReturnError(errors.New("error")) + mock.ExpectRollback() + err := fs.Rename(ctx, []uint{}, []uint{10}, "new.txt") + asserts.NoError(mock.ExpectationsWereMet()) + asserts.Error(err) + asserts.Equal(ErrFileExisted, err) + } + + // 重命名目录 成功 + { + mock.ExpectQuery("SELECT(.+)folders(.+)"). + WithArgs(10, 1). + WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(10, "old")) + mock.ExpectBegin() + mock.ExpectExec("UPDATE(.+)folders(.+)"). + WithArgs("new", sqlmock.AnyArg(), 10). + WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() + err := fs.Rename(ctx, []uint{10}, []uint{}, "new") + asserts.NoError(mock.ExpectationsWereMet()) + asserts.NoError(err) + } + + // 重命名目录 不存在 + { + mock.ExpectQuery("SELECT(.+)folders(.+)"). + WithArgs(10, 1). + WillReturnRows(sqlmock.NewRows([]string{"id", "name"})) + err := fs.Rename(ctx, []uint{10}, []uint{}, "new") + asserts.NoError(mock.ExpectationsWereMet()) + asserts.Error(err) + asserts.Equal(ErrPathNotExist, err) + } + + // 重命名目录 失败 + { + mock.ExpectQuery("SELECT(.+)folders(.+)"). + WithArgs(10, 1). + WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(10, "old")) + mock.ExpectBegin() + mock.ExpectExec("UPDATE(.+)folders(.+)"). + WithArgs("new", sqlmock.AnyArg(), 10). + WillReturnError(errors.New("error")) + mock.ExpectRollback() + err := fs.Rename(ctx, []uint{10}, []uint{}, "new") + asserts.NoError(mock.ExpectationsWereMet()) + asserts.Error(err) + asserts.Equal(ErrFileExisted, err) + } + + // 未选中任何对象 + { + err := fs.Rename(ctx, []uint{}, []uint{}, "new") + asserts.Error(err) + asserts.Equal(ErrPathNotExist, err) + } + + // 新名字不合法 + { + err := fs.Rename(ctx, []uint{10}, []uint{}, "ne/w") + asserts.Error(err) + asserts.Equal(ErrIllegalObjectName, err) + } +} diff --git a/pkg/filesystem/tests/file1.txt b/pkg/filesystem/tests/file1.txt new file mode 100644 index 0000000..e69de29 diff --git a/pkg/filesystem/tests/file2.txt b/pkg/filesystem/tests/file2.txt new file mode 100644 index 0000000..e69de29 diff --git a/routers/controllers/file.go b/routers/controllers/file.go index e9d651e..3f6fcd4 100644 --- a/routers/controllers/file.go +++ b/routers/controllers/file.go @@ -137,6 +137,25 @@ func Thumb(c *gin.Context) { } +// RedirectToDownload 创建下载会话并重定向至下载地址 +func RedirectToDownload(c *gin.Context) { + // 创建上下文 + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + var service explorer.FileDownloadCreateService + if err := c.ShouldBindUri(&service); err == nil { + res := service.CreateDownloadSession(ctx, c) + if res.Code == 0 { + c.Redirect(301, res.Data.(string)) + return + } + c.JSON(200, res) + } else { + c.JSON(200, ErrorResponse(err)) + } +} + // CreateDownloadSession 创建文件下载会话 func CreateDownloadSession(c *gin.Context) { // 创建上下文 diff --git a/routers/router.go b/routers/router.go index 9ae811e..928fa68 100644 --- a/routers/router.go +++ b/routers/router.go @@ -106,6 +106,8 @@ func InitRouter() *gin.Engine { file.POST("upload", controllers.FileUploadStream) // 创建文件下载会话 file.PUT("download/*path", controllers.CreateDownloadSession) + // 创建文件下载并重定向到下载地址 + file.GET("redirect/*path", controllers.RedirectToDownload) // 获取缩略图 file.GET("thumb/:id", controllers.Thumb) // 取得文件外链