From aea3ba1499d56cb3de39da7f57c973197801af8e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=8D=83=E7=9F=B3?= Date: Fri, 15 Aug 2025 08:09:00 -0700 Subject: [PATCH] feat: add tag backup and fix bugs (#9265) * feat(label): enhance label file binding and router setup (feat/add-tag-backup) - Add `GetLabelsByFileNamesPublic` to retrieve labels using file names. - Refactor router setup for label and file binding routes. - Improve `toObjsResp` for efficient label retrieval by file names. - Comment out unnecessary user ID parameter in `toObjsResp`. * feat(label): enhance label file binding and router setup - Add `GetLabelsByFileNamesPublic` for label retrieval by file names. - Refactor router setup for label and file binding routes. - Improve `toObjsResp` for efficient label retrieval by file names. - Comment out unnecessary user ID parameter in `toObjsResp`. * refactor(db): comment out debug print in GetLabelIds (#feat/add-tag-backup) - Comment out debug print statement in GetLabelIds to clean up logs. - Enhance code readability by removing unnecessary debug output. * feat(label-file-binding): add batch creation and improve label ID handling - Introduced `CreateLabelFileBinDingBatch` API for batch label binding. - Added `collectLabelIDs` helper function to handle label ID parsing. - Enhanced label ID handling to support varied delimiters and input formats. - Refactored `CreateLabelFileBinDing` logic for improved code readability. - Updated router to include `POST /label_file_binding/create_batch`. --- internal/db/db.go | 2 +- internal/db/label_file_binding.go | 148 ++++++++++++++++++++++++-- internal/errs/role.go | 2 +- internal/model/label_file_binding.go | 2 +- internal/op/label_file_binding.go | 58 +++++++++-- server/handles/fsread.go | 20 +++- server/handles/label_file_binding.go | 149 ++++++++++++++++++++++++++- server/handles/user.go | 8 +- server/router.go | 20 +++- 9 files changed, 375 insertions(+), 34 deletions(-) diff --git a/internal/db/db.go b/internal/db/db.go index c6491dc9..0d8ab421 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -12,7 +12,7 @@ var db *gorm.DB func Init(d *gorm.DB) { db = d - err := AutoMigrate(new(model.Storage), new(model.User), new(model.Meta), new(model.SettingItem), new(model.SearchNode), new(model.TaskItem), new(model.SSHPublicKey), new(model.Role), new(model.Label), new(model.LabelFileBinDing), new(model.ObjFile)) + err := AutoMigrate(new(model.Storage), new(model.User), new(model.Meta), new(model.SettingItem), new(model.SearchNode), new(model.TaskItem), new(model.SSHPublicKey), new(model.Role), new(model.Label), new(model.LabelFileBinding), new(model.ObjFile)) if err != nil { log.Fatalf("failed migrate database: %s", err.Error()) } diff --git a/internal/db/label_file_binding.go b/internal/db/label_file_binding.go index ec722efb..4dda80f2 100644 --- a/internal/db/label_file_binding.go +++ b/internal/db/label_file_binding.go @@ -1,15 +1,18 @@ package db import ( + "fmt" "github.com/alist-org/alist/v3/internal/model" "github.com/pkg/errors" "gorm.io/gorm" + "gorm.io/gorm/clause" "time" ) // GetLabelIds Get all label_ids from database order by file_name func GetLabelIds(userId uint, fileName string) ([]uint, error) { - labelFileBinDingDB := db.Model(&model.LabelFileBinDing{}) + //fmt.Printf(">>> [GetLabelIds] userId: %d, fileName: %s\n", userId, fileName) + labelFileBinDingDB := db.Model(&model.LabelFileBinding{}) var labelIds []uint if err := labelFileBinDingDB.Where("file_name = ?", fileName).Where("user_id = ?", userId).Pluck("label_id", &labelIds).Error; err != nil { return nil, errors.WithStack(err) @@ -18,7 +21,7 @@ func GetLabelIds(userId uint, fileName string) ([]uint, error) { } func CreateLabelFileBinDing(fileName string, labelId, userId uint) error { - var labelFileBinDing model.LabelFileBinDing + var labelFileBinDing model.LabelFileBinding labelFileBinDing.UserId = userId labelFileBinDing.LabelId = labelId labelFileBinDing.FileName = fileName @@ -32,7 +35,7 @@ func CreateLabelFileBinDing(fileName string, labelId, userId uint) error { // GetLabelFileBinDingByLabelIdExists Get Label by label_id, used to del label usually func GetLabelFileBinDingByLabelIdExists(labelId, userId uint) bool { - var labelFileBinDing model.LabelFileBinDing + var labelFileBinDing model.LabelFileBinding result := db.Where("label_id = ?", labelId).Where("user_id = ?", userId).First(&labelFileBinDing) exists := !errors.Is(result.Error, gorm.ErrRecordNotFound) return exists @@ -40,17 +43,150 @@ func GetLabelFileBinDingByLabelIdExists(labelId, userId uint) bool { // DelLabelFileBinDingByFileName used to del usually func DelLabelFileBinDingByFileName(userId uint, fileName string) error { - return errors.WithStack(db.Where("file_name = ?", fileName).Where("user_id = ?", userId).Delete(model.LabelFileBinDing{}).Error) + return errors.WithStack(db.Where("file_name = ?", fileName).Where("user_id = ?", userId).Delete(model.LabelFileBinding{}).Error) } // DelLabelFileBinDingById used to del usually func DelLabelFileBinDingById(labelId, userId uint, fileName string) error { - return errors.WithStack(db.Where("label_id = ?", labelId).Where("file_name = ?", fileName).Where("user_id = ?", userId).Delete(model.LabelFileBinDing{}).Error) + return errors.WithStack(db.Where("label_id = ?", labelId).Where("file_name = ?", fileName).Where("user_id = ?", userId).Delete(model.LabelFileBinding{}).Error) } -func GetLabelFileBinDingByLabelId(labelIds []uint, userId uint) (result []model.LabelFileBinDing, err error) { +func GetLabelFileBinDingByLabelId(labelIds []uint, userId uint) (result []model.LabelFileBinding, err error) { if err := db.Where("label_id in (?)", labelIds).Where("user_id = ?", userId).Find(&result).Error; err != nil { return nil, errors.WithStack(err) } return result, nil } + +func GetLabelBindingsByFileNamesPublic(fileNames []string) (map[string][]uint, error) { + var binds []model.LabelFileBinding + if err := db.Where("file_name IN ?", fileNames).Find(&binds).Error; err != nil { + return nil, errors.WithStack(err) + } + out := make(map[string][]uint, len(fileNames)) + seen := make(map[string]struct{}, len(binds)) + for _, b := range binds { + key := fmt.Sprintf("%s-%d", b.FileName, b.LabelId) + if _, ok := seen[key]; ok { + continue + } + seen[key] = struct{}{} + out[b.FileName] = append(out[b.FileName], b.LabelId) + } + return out, nil +} + +func GetLabelsByFileNamesPublic(fileNames []string) (map[string][]model.Label, error) { + bindMap, err := GetLabelBindingsByFileNamesPublic(fileNames) + if err != nil { + return nil, err + } + + idSet := make(map[uint]struct{}) + for _, ids := range bindMap { + for _, id := range ids { + idSet[id] = struct{}{} + } + } + if len(idSet) == 0 { + return make(map[string][]model.Label, 0), nil + } + allIDs := make([]uint, 0, len(idSet)) + for id := range idSet { + allIDs = append(allIDs, id) + } + labels, err := GetLabelByIds(allIDs) // 你已有的函数 + if err != nil { + return nil, err + } + + labelByID := make(map[uint]model.Label, len(labels)) + for _, l := range labels { + labelByID[l.ID] = l + } + + out := make(map[string][]model.Label, len(bindMap)) + for fname, ids := range bindMap { + for _, id := range ids { + if lab, ok := labelByID[id]; ok { + out[fname] = append(out[fname], lab) + } + } + } + return out, nil +} + +func ListLabelFileBinDing(userId uint, labelIDs []uint, fileName string, page, pageSize int) ([]model.LabelFileBinding, int64, error) { + q := db.Model(&model.LabelFileBinding{}).Where("user_id = ?", userId) + + if len(labelIDs) > 0 { + q = q.Where("label_id IN ?", labelIDs) + } + if fileName != "" { + q = q.Where("file_name LIKE ?", "%"+fileName+"%") + } + + var total int64 + if err := q.Count(&total).Error; err != nil { + return nil, 0, errors.WithStack(err) + } + + var rows []model.LabelFileBinding + if err := q. + Order("id DESC"). + Offset((page - 1) * pageSize). + Limit(pageSize). + Find(&rows).Error; err != nil { + return nil, 0, errors.WithStack(err) + } + return rows, total, nil +} + +func RestoreLabelFileBindings(bindings []model.LabelFileBinding, keepIDs bool, override bool) error { + if len(bindings) == 0 { + return nil + } + tx := db.Begin() + + if override { + type key struct { + uid uint + name string + } + toDel := make(map[key]struct{}, len(bindings)) + for i := range bindings { + k := key{uid: bindings[i].UserId, name: bindings[i].FileName} + toDel[k] = struct{}{} + } + for k := range toDel { + if err := tx.Where("user_id = ? AND file_name = ?", k.uid, k.name). + Delete(&model.LabelFileBinding{}).Error; err != nil { + tx.Rollback() + return errors.WithStack(err) + } + } + } + + for i := range bindings { + b := bindings[i] + if !keepIDs { + b.ID = 0 + } + if b.CreateTime.IsZero() { + b.CreateTime = time.Now() + } + if override { + if err := tx.Create(&b).Error; err != nil { + tx.Rollback() + return errors.WithStack(err) + } + } else { + if err := tx.Clauses(clause.OnConflict{DoNothing: true}).Create(&b).Error; err != nil { + tx.Rollback() + return errors.WithStack(err) + } + } + } + + return errors.WithStack(tx.Commit().Error) +} diff --git a/internal/errs/role.go b/internal/errs/role.go index fbd67404..a818ea21 100644 --- a/internal/errs/role.go +++ b/internal/errs/role.go @@ -3,5 +3,5 @@ package errs import "errors" var ( - ErrChangeDefaultRole = errors.New("cannot modify admin or guest role") + ErrChangeDefaultRole = errors.New("cannot modify admin role") ) diff --git a/internal/model/label_file_binding.go b/internal/model/label_file_binding.go index 3f9ea3b2..af57fed4 100644 --- a/internal/model/label_file_binding.go +++ b/internal/model/label_file_binding.go @@ -2,7 +2,7 @@ package model import "time" -type LabelFileBinDing struct { +type LabelFileBinding struct { ID uint `json:"id" gorm:"primaryKey"` // unique key UserId uint `json:"user_id"` // use to user_id LabelId uint `json:"label_id"` // use to label_id diff --git a/internal/op/label_file_binding.go b/internal/op/label_file_binding.go index 79137ed3..2802f0c0 100644 --- a/internal/op/label_file_binding.go +++ b/internal/op/label_file_binding.go @@ -23,6 +23,7 @@ type CreateLabelFileBinDingReq struct { Type int `json:"type"` HashInfoStr string `json:"hashinfo"` LabelIds string `json:"label_ids"` + LabelIDs []uint64 `json:"labelIdList"` } type ObjLabelResp struct { @@ -54,23 +55,29 @@ func GetLabelByFileName(userId uint, fileName string) ([]model.Label, error) { return labels, nil } +func GetLabelsByFileNamesPublic(fileNames []string) (map[string][]model.Label, error) { + return db.GetLabelsByFileNamesPublic(fileNames) +} + func CreateLabelFileBinDing(req CreateLabelFileBinDingReq, userId uint) error { if err := db.DelLabelFileBinDingByFileName(userId, req.Name); err != nil { return errors.WithMessage(err, "failed del label_file_bin_ding in database") } - if req.LabelIds == "" { + + ids, err := collectLabelIDs(req) + if err != nil { + return err + } + if len(ids) == 0 { return nil } - labelMap := strings.Split(req.LabelIds, ",") - for _, value := range labelMap { - labelId, err := strconv.ParseUint(value, 10, 64) - if err != nil { - return fmt.Errorf("invalid label ID '%s': %v", value, err) - } - if err = db.CreateLabelFileBinDing(req.Name, uint(labelId), userId); err != nil { + + for _, id := range ids { + if err = db.CreateLabelFileBinDing(req.Name, uint(id), userId); err != nil { return errors.WithMessage(err, "failed labels in database") } } + if !db.GetFileByNameExists(req.Name) { objFile := model.ObjFile{ Id: req.Id, @@ -86,8 +93,7 @@ func CreateLabelFileBinDing(req CreateLabelFileBinDingReq, userId uint) error { Type: req.Type, HashInfoStr: req.HashInfoStr, } - err := db.CreateObjFile(objFile) - if err != nil { + if err := db.CreateObjFile(objFile); err != nil { return errors.WithMessage(err, "failed file in database") } } @@ -97,7 +103,7 @@ func CreateLabelFileBinDing(req CreateLabelFileBinDingReq, userId uint) error { func GetFileByLabel(userId uint, labelId string) (result []ObjLabelResp, err error) { labelMap := strings.Split(labelId, ",") var labelIds []uint - var labelsFile []model.LabelFileBinDing + var labelsFile []model.LabelFileBinding var labels []model.Label var labelsFileMap = make(map[string][]model.Label) var labelsMap = make(map[uint]model.Label) @@ -157,3 +163,33 @@ func StringSliceToUintSlice(strSlice []string) ([]uint, error) { } return uintSlice, nil } + +func RestoreLabelFileBindings(bindings []model.LabelFileBinding, keepIDs bool, override bool) error { + return db.RestoreLabelFileBindings(bindings, keepIDs, override) +} + +func collectLabelIDs(req CreateLabelFileBinDingReq) ([]uint64, error) { + if len(req.LabelIDs) > 0 { + return req.LabelIDs, nil + } + s := strings.TrimSpace(req.LabelIds) + if s == "" { + return nil, nil + } + replacer := strings.NewReplacer(",", ",", "、", ",", ";", ",", ";", ",") + s = replacer.Replace(s) + parts := strings.Split(s, ",") + ids := make([]uint64, 0, len(parts)) + for _, p := range parts { + p = strings.TrimSpace(p) + if p == "" { + continue + } + id, err := strconv.ParseUint(p, 10, 64) + if err != nil { + return nil, fmt.Errorf("invalid label ID '%s': %v", p, err) + } + ids = append(ids, id) + } + return ids, nil +} diff --git a/server/handles/fsread.go b/server/handles/fsread.go index b49f0b64..cc403c4a 100644 --- a/server/handles/fsread.go +++ b/server/handles/fsread.go @@ -114,7 +114,7 @@ func FsList(c *gin.Context) { provider = storage.GetStorage().Driver } common.SuccessResp(c, FsListResp{ - Content: toObjsResp(objs, reqPath, isEncrypt(meta, reqPath), user.ID), + Content: toObjsResp(objs, reqPath, isEncrypt(meta, reqPath)), Total: int64(total), Readme: getReadme(meta, reqPath), Header: getHeader(meta, reqPath), @@ -224,12 +224,22 @@ func pagination(objs []model.Obj, req *model.PageReq) (int, []model.Obj) { return total, objs[start:end] } -func toObjsResp(objs []model.Obj, parent string, encrypt bool, userId uint) []ObjLabelResp { +func toObjsResp(objs []model.Obj, parent string, encrypt bool) []ObjLabelResp { var resp []ObjLabelResp + + names := make([]string, 0, len(objs)) + for _, obj := range objs { + if !obj.IsDir() { + names = append(names, obj.GetName()) + } + } + + labelsByName, _ := op.GetLabelsByFileNamesPublic(names) + for _, obj := range objs { var labels []model.Label - if obj.IsDir() == false { - labels, _ = op.GetLabelByFileName(userId, obj.GetName()) + if !obj.IsDir() { + labels = labelsByName[obj.GetName()] } thumb, _ := model.GetThumb(obj) resp = append(resp, ObjLabelResp{ @@ -369,7 +379,7 @@ func FsGet(c *gin.Context) { Readme: getReadme(meta, reqPath), Header: getHeader(meta, reqPath), Provider: provider, - Related: toObjsResp(related, parentPath, isEncrypt(parentMeta, parentPath), user.ID), + Related: toObjsResp(related, parentPath, isEncrypt(parentMeta, parentPath)), }) } diff --git a/server/handles/label_file_binding.go b/server/handles/label_file_binding.go index 78af929b..04f0c105 100644 --- a/server/handles/label_file_binding.go +++ b/server/handles/label_file_binding.go @@ -8,7 +8,9 @@ import ( "github.com/alist-org/alist/v3/internal/op" "github.com/alist-org/alist/v3/server/common" "github.com/gin-gonic/gin" + "net/url" "strconv" + "strings" ) type DelLabelFileBinDingReq struct { @@ -16,18 +18,36 @@ type DelLabelFileBinDingReq struct { LabelId string `json:"label_id"` } +type pageResp[T any] struct { + Content []T `json:"content"` + Total int64 `json:"total"` +} + +type restoreLabelBindingsReq struct { + KeepIDs bool `json:"keep_ids"` + Override bool `json:"override"` + Bindings []model.LabelFileBinding `json:"bindings"` +} + func GetLabelByFileName(c *gin.Context) { fileName := c.Query("file_name") if fileName == "" { common.ErrorResp(c, errors.New("file_name must not empty"), 400) return } + decodedFileName, err := url.QueryUnescape(fileName) + if err != nil { + common.ErrorResp(c, errors.New("invalid file_name"), 400) + return + } + fmt.Println(">>> 原始 fileName:", fileName) + fmt.Println(">>> 解码后 fileName:", decodedFileName) userObj, ok := c.Value("user").(*model.User) if !ok { common.ErrorStrResp(c, "user invalid", 401) return } - labels, err := op.GetLabelByFileName(userObj.ID, fileName) + labels, err := op.GetLabelByFileName(userObj.ID, decodedFileName) if err != nil { common.ErrorResp(c, err, 500, true) return @@ -101,3 +121,130 @@ func GetFileByLabel(c *gin.Context) { } common.SuccessResp(c, fileList) } + +func ListLabelFileBinding(c *gin.Context) { + userObj, ok := c.Value("user").(*model.User) + if !ok { + common.ErrorStrResp(c, "user invalid", 401) + return + } + + pageStr := c.DefaultQuery("page", "1") + sizeStr := c.DefaultQuery("page_size", "50") + page, err := strconv.Atoi(pageStr) + if err != nil || page <= 0 { + page = 1 + } + pageSize, err := strconv.Atoi(sizeStr) + if err != nil || pageSize <= 0 || pageSize > 200 { + pageSize = 50 + } + + fileName := c.Query("file_name") + labelIDStr := c.Query("label_id") + var labelIDs []uint + if labelIDStr != "" { + parts := strings.Split(labelIDStr, ",") + for _, p := range parts { + if p == "" { + continue + } + id64, err := strconv.ParseUint(strings.TrimSpace(p), 10, 64) + if err != nil { + common.ErrorResp(c, fmt.Errorf("invalid label_id '%s': %v", p, err), 400) + return + } + labelIDs = append(labelIDs, uint(id64)) + } + } + + list, total, err := db.ListLabelFileBinDing(userObj.ID, labelIDs, fileName, page, pageSize) + if err != nil { + common.ErrorResp(c, err, 500, true) + return + } + common.SuccessResp(c, pageResp[model.LabelFileBinding]{ + Content: list, + Total: total, + }) +} + +func RestoreLabelFileBinding(c *gin.Context) { + var req restoreLabelBindingsReq + if err := c.ShouldBindJSON(&req); err != nil { + common.ErrorResp(c, err, 400) + return + } + if len(req.Bindings) == 0 { + common.ErrorStrResp(c, "empty bindings", 400) + return + } + + if u, ok := c.Value("user").(*model.User); ok { + for i := range req.Bindings { + if req.Bindings[i].UserId == 0 { + req.Bindings[i].UserId = u.ID + } + } + } + + for i := range req.Bindings { + b := req.Bindings[i] + if b.UserId == 0 || b.LabelId == 0 || strings.TrimSpace(b.FileName) == "" { + common.ErrorStrResp(c, "invalid binding: user_id/label_id/file_name required", 400) + return + } + } + + if err := op.RestoreLabelFileBindings(req.Bindings, req.KeepIDs, req.Override); err != nil { + common.ErrorResp(c, err, 500, true) + return + } + common.SuccessResp(c, gin.H{ + "msg": fmt.Sprintf("restored %d rows", len(req.Bindings)), + }) +} + +func CreateLabelFileBinDingBatch(c *gin.Context) { + var req struct { + Items []op.CreateLabelFileBinDingReq `json:"items" binding:"required"` + } + if err := c.ShouldBindJSON(&req); err != nil || len(req.Items) == 0 { + common.ErrorResp(c, err, 400) + return + } + + userObj, ok := c.Value("user").(*model.User) + if !ok { + common.ErrorStrResp(c, "user invalid", 401) + return + } + + type perResult struct { + Name string `json:"name"` + Ok bool `json:"ok"` + ErrMsg string `json:"errMsg,omitempty"` + } + results := make([]perResult, 0, len(req.Items)) + succeed := 0 + + for _, item := range req.Items { + if item.IsDir { + results = append(results, perResult{Name: item.Name, Ok: false, ErrMsg: "Unable to bind folder"}) + continue + } + if err := op.CreateLabelFileBinDing(item, userObj.ID); err != nil { + results = append(results, perResult{Name: item.Name, Ok: false, ErrMsg: err.Error()}) + continue + } + succeed++ + results = append(results, perResult{Name: item.Name, Ok: true}) + } + + common.SuccessResp(c, gin.H{ + "total": len(req.Items), + "succeed": succeed, + "failed": len(req.Items) - succeed, + "results": results, + }) +} diff --git a/server/handles/user.go b/server/handles/user.go index d5eebba4..b4c152c5 100644 --- a/server/handles/user.go +++ b/server/handles/user.go @@ -67,10 +67,10 @@ func UpdateUser(c *gin.Context) { common.ErrorStrResp(c, "cannot change role of admin user", 403) return } - if user.Username != req.Username { - common.ErrorStrResp(c, "cannot change username of admin user", 403) - return - } + //if user.Username != req.Username { + // common.ErrorStrResp(c, "cannot change username of admin user", 403) + // return + //} } if req.Password == "" { diff --git a/server/router.go b/server/router.go index bf43a625..72546f4e 100644 --- a/server/router.go +++ b/server/router.go @@ -92,6 +92,8 @@ func Init(e *gin.Engine) { _fs(auth.Group("/fs")) _task(auth.Group("/task", middlewares.AuthNotGuest)) + _label(auth.Group("/label")) + _labelFileBinding(auth.Group("/label_file_binding")) admin(auth.Group("/admin", middlewares.AuthAdmin)) if flags.Debug || flags.Dev { debug(g.Group("/debug")) @@ -170,17 +172,17 @@ func admin(g *gin.RouterGroup) { index.GET("/progress", middlewares.SearchIndex, handles.GetProgress) label := g.Group("/label") - label.GET("/list", handles.ListLabel) - label.GET("/get", handles.GetLabel) label.POST("/create", handles.CreateLabel) label.POST("/update", handles.UpdateLabel) label.POST("/delete", handles.DeleteLabel) labelFileBinding := g.Group("/label_file_binding") - labelFileBinding.GET("/get", handles.GetLabelByFileName) - labelFileBinding.GET("/get_file_by_label", handles.GetFileByLabel) + labelFileBinding.GET("/list", handles.ListLabelFileBinding) labelFileBinding.POST("/create", handles.CreateLabelFileBinDing) + labelFileBinding.POST("/create_batch", handles.CreateLabelFileBinDingBatch) labelFileBinding.POST("/delete", handles.DelLabelByFileName) + labelFileBinding.POST("/restore", handles.RestoreLabelFileBinding) + } func _fs(g *gin.RouterGroup) { @@ -216,6 +218,16 @@ func _task(g *gin.RouterGroup) { handles.SetupTaskRoute(g) } +func _label(g *gin.RouterGroup) { + g.GET("/list", handles.ListLabel) + g.GET("/get", handles.GetLabel) +} + +func _labelFileBinding(g *gin.RouterGroup) { + g.GET("/get", handles.GetLabelByFileName) + g.GET("/get_file_by_label", handles.GetFileByLabel) +} + func Cors(r *gin.Engine) { config := cors.DefaultConfig() // config.AllowAllOrigins = true