diff --git a/cmd/admin.go b/cmd/admin.go index 6f61ccab..62ca0f67 100644 --- a/cmd/admin.go +++ b/cmd/admin.go @@ -4,7 +4,7 @@ Copyright © 2022 NAME HERE package cmd import ( - "github.com/alist-org/alist/v3/internal/db" + "github.com/alist-org/alist/v3/internal/op" "github.com/alist-org/alist/v3/pkg/utils" "github.com/spf13/cobra" ) @@ -16,7 +16,7 @@ var passwordCmd = &cobra.Command{ Short: "Show admin user's info", Run: func(cmd *cobra.Command, args []string) { Init() - admin, err := db.GetAdmin() + admin, err := op.GetAdmin() if err != nil { utils.Log.Errorf("failed get admin user: %+v", err) } else { diff --git a/cmd/cancel2FA.go b/cmd/cancel2FA.go index d90a398e..46f7e81d 100644 --- a/cmd/cancel2FA.go +++ b/cmd/cancel2FA.go @@ -4,7 +4,7 @@ Copyright © 2022 NAME HERE package cmd import ( - "github.com/alist-org/alist/v3/internal/db" + "github.com/alist-org/alist/v3/internal/op" "github.com/alist-org/alist/v3/pkg/utils" "github.com/spf13/cobra" ) @@ -15,11 +15,11 @@ var cancel2FACmd = &cobra.Command{ Short: "Delete 2FA of admin user", Run: func(cmd *cobra.Command, args []string) { Init() - admin, err := db.GetAdmin() + admin, err := op.GetAdmin() if err != nil { utils.Log.Errorf("failed to get admin user: %+v", err) } else { - err := db.Cancel2FAByUser(admin) + err := op.Cancel2FAByUser(admin) if err != nil { utils.Log.Errorf("failed to cancel 2FA: %+v", err) } diff --git a/internal/bootstrap/data/user.go b/internal/bootstrap/data/user.go index ac6ca1cd..04018ee0 100644 --- a/internal/bootstrap/data/user.go +++ b/internal/bootstrap/data/user.go @@ -6,6 +6,7 @@ import ( "github.com/alist-org/alist/v3/cmd/flags" "github.com/alist-org/alist/v3/internal/db" "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/internal/op" "github.com/alist-org/alist/v3/pkg/utils" "github.com/alist-org/alist/v3/pkg/utils/random" "github.com/pkg/errors" @@ -13,7 +14,7 @@ import ( ) func initUser() { - admin, err := db.GetAdmin() + admin, err := op.GetAdmin() adminPassword := random.String(8) envpass := os.Getenv("ALIST_ADMIN_PASSWORD") if flags.Dev { @@ -29,7 +30,7 @@ func initUser() { Role: model.ADMIN, BasePath: "/", } - if err := db.CreateUser(admin); err != nil { + if err := op.CreateUser(admin); err != nil { panic(err) } else { utils.Log.Infof("Successfully created the admin user and the initial password is: %s", admin.Password) @@ -38,7 +39,7 @@ func initUser() { panic(err) } } - guest, err := db.GetGuest() + guest, err := op.GetGuest() if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { guest = &model.User{ diff --git a/internal/db/user.go b/internal/db/user.go index 04bdfbaf..cc2d817f 100644 --- a/internal/db/user.go +++ b/internal/db/user.go @@ -1,61 +1,24 @@ package db import ( - "time" - - "github.com/Xhofe/go-cache" - "github.com/alist-org/alist/v3/internal/errs" "github.com/alist-org/alist/v3/internal/model" - "github.com/alist-org/alist/v3/pkg/singleflight" "github.com/pkg/errors" ) -var userCache = cache.NewMemCache(cache.WithShards[*model.User](2)) -var userG singleflight.Group[*model.User] -var guest *model.User -var admin *model.User - -func GetAdmin() (*model.User, error) { - if admin != nil { - return admin, nil - } - user := model.User{Role: model.ADMIN} +func GetUserByRole(role int) (*model.User, error) { + user := model.User{Role: role} if err := db.Where(user).Take(&user).Error; err != nil { return nil, err } - admin = &user - return &user, nil -} - -func GetGuest() (*model.User, error) { - if guest != nil { - return guest, nil - } - user := model.User{Role: model.GUEST} - if err := db.Where(user).Take(&user).Error; err != nil { - return nil, err - } - guest = &user return &user, nil } func GetUserByName(username string) (*model.User, error) { - if username == "" { - return nil, errors.WithStack(errs.EmptyUsername) + user := model.User{Username: username} + if err := db.Where(user).First(&user).Error; err != nil { + return nil, errors.Wrapf(err, "failed find user") } - user, ok := userCache.Get(username) - if ok { - return user, nil - } - user, err, _ := userG.Do(username, func() (*model.User, error) { - user := model.User{Username: username} - if err := db.Where(user).First(&user).Error; err != nil { - return nil, errors.Wrapf(err, "failed find user") - } - userCache.Set(username, &user, cache.WithEx[*model.User](time.Hour)) - return &user, nil - }) - return user, err + return &user, nil } func GetUserById(id uint) (*model.User, error) { @@ -71,40 +34,14 @@ func CreateUser(u *model.User) error { } func UpdateUser(u *model.User) error { - old, err := GetUserById(u.ID) - if err != nil { - return err - } - userCache.Del(old.Username) - if u.IsGuest() { - guest = nil - } - if u.IsAdmin() { - admin = nil - } return errors.WithStack(db.Save(u).Error) } -func Cancel2FAByUser(u *model.User) error { - u.OtpSecret = "" - return errors.WithStack(UpdateUser(u)) -} - -func Cancel2FAById(id uint) error { - user, err := GetUserById(id) - if err != nil { - return err - } - return Cancel2FAByUser(user) -} - -func GetUsers(pageIndex, pageSize int) ([]model.User, int64, error) { +func GetUsers(pageIndex, pageSize int) (users []model.User, count int64, err error) { userDB := db.Model(&model.User{}) - var count int64 if err := userDB.Count(&count).Error; err != nil { return nil, 0, errors.Wrapf(err, "failed get users count") } - var users []model.User if err := userDB.Offset((pageIndex - 1) * pageSize).Limit(pageSize).Find(&users).Error; err != nil { return nil, 0, errors.Wrapf(err, "failed get find users") } @@ -112,13 +49,5 @@ func GetUsers(pageIndex, pageSize int) ([]model.User, int64, error) { } func DeleteUserById(id uint) error { - old, err := GetUserById(id) - if err != nil { - return err - } - if old.IsAdmin() || old.IsGuest() { - return errors.WithStack(errs.DeleteAdminOrGuest) - } - userCache.Del(old.Username) return errors.WithStack(db.Delete(&model.User{}, id).Error) } diff --git a/internal/op/user.go b/internal/op/user.go new file mode 100644 index 00000000..e98a378f --- /dev/null +++ b/internal/op/user.go @@ -0,0 +1,114 @@ +package op + +import ( + "time" + + "github.com/Xhofe/go-cache" + "github.com/alist-org/alist/v3/internal/db" + "github.com/alist-org/alist/v3/internal/errs" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/pkg/singleflight" + "github.com/alist-org/alist/v3/pkg/utils" +) + +var userCache = cache.NewMemCache(cache.WithShards[*model.User](2)) +var userG singleflight.Group[*model.User] +var guestUser *model.User +var adminUser *model.User + +func GetAdmin() (*model.User, error) { + if adminUser == nil { + user, err := db.GetUserByRole(model.ADMIN) + if err != nil { + return nil, err + } + adminUser = user + } + return adminUser, nil +} + +func GetGuest() (*model.User, error) { + if guestUser == nil { + user, err := db.GetUserByRole(model.GUEST) + if err != nil { + return nil, err + } + guestUser = user + } + return guestUser, nil +} + +func GetUserByRole(role int) (*model.User, error) { + return db.GetUserByRole(role) +} + +func GetUserByName(username string) (*model.User, error) { + if username == "" { + return nil, errs.EmptyUsername + } + if user, ok := userCache.Get(username); ok { + return user, nil + } + user, err, _ := userG.Do(username, func() (*model.User, error) { + _user, err := db.GetUserByName(username) + if err != nil { + return nil, err + } + userCache.Set(username, _user, cache.WithEx[*model.User](time.Hour)) + return _user, nil + }) + return user, err +} + +func GetUserById(id uint) (*model.User, error) { + return db.GetUserById(id) +} + +func GetUsers(pageIndex, pageSize int) (users []model.User, count int64, err error) { + return db.GetUsers(pageIndex, pageSize) +} + +func CreateUser(u *model.User) error { + u.BasePath = utils.FixAndCleanPath(u.BasePath) + return db.CreateUser(u) +} + +func DeleteUserById(id uint) error { + old, err := db.GetUserById(id) + if err != nil { + return err + } + if old.IsAdmin() || old.IsGuest() { + return errs.DeleteAdminOrGuest + } + return db.DeleteUserById(id) +} + +func UpdateUser(u *model.User) error { + old, err := db.GetUserById(u.ID) + if err != nil { + return err + } + if u.IsAdmin() { + adminUser = nil + } + if u.IsGuest() { + guestUser = nil + } + userCache.Del(old.Username) + u.BasePath = utils.FixAndCleanPath(u.BasePath) + return db.UpdateUser(u) +} + +func Cancel2FAByUser(u *model.User) error { + u.OtpSecret = "" + return db.UpdateUser(u) +} + +func Cancel2FAById(id uint) error { + user, err := db.GetUserById(id) + if err != nil { + return err + } + return Cancel2FAByUser(user) +} diff --git a/internal/search/build.go b/internal/search/build.go index 19c4a806..53b944f7 100644 --- a/internal/search/build.go +++ b/internal/search/build.go @@ -8,7 +8,6 @@ import ( "sync/atomic" "time" - "github.com/alist-org/alist/v3/internal/db" "github.com/alist-org/alist/v3/internal/fs" "github.com/alist-org/alist/v3/internal/model" "github.com/alist-org/alist/v3/internal/op" @@ -105,7 +104,7 @@ func BuildIndex(ctx context.Context, indexPaths, ignorePaths []string, maxDepth Quit <- struct{}{} } }() - admin, err := db.GetAdmin() + admin, err := op.GetAdmin() if err != nil { return err } diff --git a/server/handles/auth.go b/server/handles/auth.go index 0ecd52d8..920b7d75 100644 --- a/server/handles/auth.go +++ b/server/handles/auth.go @@ -7,8 +7,8 @@ import ( "time" "github.com/Xhofe/go-cache" - "github.com/alist-org/alist/v3/internal/db" "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/internal/op" "github.com/alist-org/alist/v3/server/common" "github.com/gin-gonic/gin" "github.com/pquerna/otp/totp" @@ -41,7 +41,7 @@ func Login(c *gin.Context) { common.ErrorResp(c, err, 400) return } - user, err := db.GetUserByName(req.Username) + user, err := op.GetUserByName(req.Username) if err != nil { common.ErrorResp(c, err, 400) loginCache.Set(ip, count+1) @@ -101,7 +101,7 @@ func UpdateCurrent(c *gin.Context) { if req.Password != "" { user.Password = req.Password } - if err := db.UpdateUser(user); err != nil { + if err := op.UpdateUser(user); err != nil { common.ErrorResp(c, err, 500) } else { common.SuccessResp(c) @@ -158,7 +158,7 @@ func Verify2FA(c *gin.Context) { return } user.OtpSecret = req.Secret - if err := db.UpdateUser(user); err != nil { + if err := op.UpdateUser(user); err != nil { common.ErrorResp(c, err, 500) } else { common.SuccessResp(c) diff --git a/server/handles/user.go b/server/handles/user.go index cbee8f33..22d4e87f 100644 --- a/server/handles/user.go +++ b/server/handles/user.go @@ -3,8 +3,8 @@ package handles import ( "strconv" - "github.com/alist-org/alist/v3/internal/db" "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/internal/op" "github.com/alist-org/alist/v3/server/common" "github.com/gin-gonic/gin" log "github.com/sirupsen/logrus" @@ -18,7 +18,7 @@ func ListUsers(c *gin.Context) { } req.Validate() log.Debugf("%+v", req) - users, total, err := db.GetUsers(req.Page, req.PerPage) + users, total, err := op.GetUsers(req.Page, req.PerPage) if err != nil { common.ErrorResp(c, err, 500, true) return @@ -39,7 +39,7 @@ func CreateUser(c *gin.Context) { common.ErrorStrResp(c, "admin or guest user can not be created", 400, true) return } - if err := db.CreateUser(&req); err != nil { + if err := op.CreateUser(&req); err != nil { common.ErrorResp(c, err, 500, true) } else { common.SuccessResp(c) @@ -52,7 +52,7 @@ func UpdateUser(c *gin.Context) { common.ErrorResp(c, err, 400) return } - user, err := db.GetUserById(req.ID) + user, err := op.GetUserById(req.ID) if err != nil { common.ErrorResp(c, err, 500) return @@ -67,7 +67,7 @@ func UpdateUser(c *gin.Context) { if req.OtpSecret == "" { req.OtpSecret = user.OtpSecret } - if err := db.UpdateUser(&req); err != nil { + if err := op.UpdateUser(&req); err != nil { common.ErrorResp(c, err, 500) } else { common.SuccessResp(c) @@ -81,7 +81,7 @@ func DeleteUser(c *gin.Context) { common.ErrorResp(c, err, 400) return } - if err := db.DeleteUserById(uint(id)); err != nil { + if err := op.DeleteUserById(uint(id)); err != nil { common.ErrorResp(c, err, 500) return } @@ -95,7 +95,7 @@ func GetUser(c *gin.Context) { common.ErrorResp(c, err, 400) return } - user, err := db.GetUserById(uint(id)) + user, err := op.GetUserById(uint(id)) if err != nil { common.ErrorResp(c, err, 500, true) return @@ -110,7 +110,7 @@ func Cancel2FAById(c *gin.Context) { common.ErrorResp(c, err, 400) return } - if err := db.Cancel2FAById(uint(id)); err != nil { + if err := op.Cancel2FAById(uint(id)); err != nil { common.ErrorResp(c, err, 500) return } diff --git a/server/middlewares/auth.go b/server/middlewares/auth.go index c22f2825..7ad40425 100644 --- a/server/middlewares/auth.go +++ b/server/middlewares/auth.go @@ -2,8 +2,8 @@ package middlewares import ( "github.com/alist-org/alist/v3/internal/conf" - "github.com/alist-org/alist/v3/internal/db" "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/internal/op" "github.com/alist-org/alist/v3/internal/setting" "github.com/alist-org/alist/v3/server/common" "github.com/gin-gonic/gin" @@ -15,7 +15,7 @@ import ( func Auth(c *gin.Context) { token := c.GetHeader("Authorization") if token == setting.GetStr(conf.Token) { - admin, err := db.GetAdmin() + admin, err := op.GetAdmin() if err != nil { common.ErrorResp(c, err, 500) c.Abort() @@ -27,7 +27,7 @@ func Auth(c *gin.Context) { return } if token == "" { - guest, err := db.GetGuest() + guest, err := op.GetGuest() if err != nil { common.ErrorResp(c, err, 500) c.Abort() @@ -44,7 +44,7 @@ func Auth(c *gin.Context) { c.Abort() return } - user, err := db.GetUserByName(userClaims.Username) + user, err := op.GetUserByName(userClaims.Username) if err != nil { common.ErrorResp(c, err, 401) c.Abort() diff --git a/server/webdav.go b/server/webdav.go index ddd95028..6c97266e 100644 --- a/server/webdav.go +++ b/server/webdav.go @@ -4,8 +4,8 @@ import ( "context" "net/http" - "github.com/alist-org/alist/v3/internal/db" "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/internal/op" "github.com/alist-org/alist/v3/pkg/utils" "github.com/alist-org/alist/v3/server/webdav" "github.com/gin-gonic/gin" @@ -45,7 +45,7 @@ func ServeWebDAV(c *gin.Context) { } func WebDAVAuth(c *gin.Context) { - guest, _ := db.GetGuest() + guest, _ := op.GetGuest() username, password, ok := c.Request.BasicAuth() if !ok { if c.Request.Method == "OPTIONS" { @@ -58,7 +58,7 @@ func WebDAVAuth(c *gin.Context) { c.Abort() return } - user, err := db.GetUserByName(username) + user, err := op.GetUserByName(username) if err != nil || user.ValidatePassword(password) != nil { if c.Request.Method == "OPTIONS" { c.Set("user", guest)