refactor:separate the user method from the db package to the op package

pull/2747/head
foxxorcat 2022-12-18 12:22:58 +08:00
parent 8a133d70a9
commit fd45fd431d
10 changed files with 150 additions and 107 deletions

View File

@ -4,7 +4,7 @@ Copyright © 2022 NAME HERE <EMAIL ADDRESS>
package cmd
import (
"github.com/alist-org/alist/v3/internal/db"
"github.com/alist-org/alist/v3/internal/op"
"github.com/alist-org/alist/v3/pkg/utils"
"github.com/spf13/cobra"
)
@ -16,7 +16,7 @@ var passwordCmd = &cobra.Command{
Short: "Show admin user's info",
Run: func(cmd *cobra.Command, args []string) {
Init()
admin, err := db.GetAdmin()
admin, err := op.GetAdmin()
if err != nil {
utils.Log.Errorf("failed get admin user: %+v", err)
} else {

View File

@ -4,7 +4,7 @@ Copyright © 2022 NAME HERE <EMAIL ADDRESS>
package cmd
import (
"github.com/alist-org/alist/v3/internal/db"
"github.com/alist-org/alist/v3/internal/op"
"github.com/alist-org/alist/v3/pkg/utils"
"github.com/spf13/cobra"
)
@ -15,11 +15,11 @@ var cancel2FACmd = &cobra.Command{
Short: "Delete 2FA of admin user",
Run: func(cmd *cobra.Command, args []string) {
Init()
admin, err := db.GetAdmin()
admin, err := op.GetAdmin()
if err != nil {
utils.Log.Errorf("failed to get admin user: %+v", err)
} else {
err := db.Cancel2FAByUser(admin)
err := op.Cancel2FAByUser(admin)
if err != nil {
utils.Log.Errorf("failed to cancel 2FA: %+v", err)
}

View File

@ -6,6 +6,7 @@ import (
"github.com/alist-org/alist/v3/cmd/flags"
"github.com/alist-org/alist/v3/internal/db"
"github.com/alist-org/alist/v3/internal/model"
"github.com/alist-org/alist/v3/internal/op"
"github.com/alist-org/alist/v3/pkg/utils"
"github.com/alist-org/alist/v3/pkg/utils/random"
"github.com/pkg/errors"
@ -13,7 +14,7 @@ import (
)
func initUser() {
admin, err := db.GetAdmin()
admin, err := op.GetAdmin()
adminPassword := random.String(8)
envpass := os.Getenv("ALIST_ADMIN_PASSWORD")
if flags.Dev {
@ -29,7 +30,7 @@ func initUser() {
Role: model.ADMIN,
BasePath: "/",
}
if err := db.CreateUser(admin); err != nil {
if err := op.CreateUser(admin); err != nil {
panic(err)
} else {
utils.Log.Infof("Successfully created the admin user and the initial password is: %s", admin.Password)
@ -38,7 +39,7 @@ func initUser() {
panic(err)
}
}
guest, err := db.GetGuest()
guest, err := op.GetGuest()
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
guest = &model.User{

View File

@ -1,61 +1,24 @@
package db
import (
"time"
"github.com/Xhofe/go-cache"
"github.com/alist-org/alist/v3/internal/errs"
"github.com/alist-org/alist/v3/internal/model"
"github.com/alist-org/alist/v3/pkg/singleflight"
"github.com/pkg/errors"
)
var userCache = cache.NewMemCache(cache.WithShards[*model.User](2))
var userG singleflight.Group[*model.User]
var guest *model.User
var admin *model.User
func GetAdmin() (*model.User, error) {
if admin != nil {
return admin, nil
}
user := model.User{Role: model.ADMIN}
func GetUserByRole(role int) (*model.User, error) {
user := model.User{Role: role}
if err := db.Where(user).Take(&user).Error; err != nil {
return nil, err
}
admin = &user
return &user, nil
}
func GetGuest() (*model.User, error) {
if guest != nil {
return guest, nil
}
user := model.User{Role: model.GUEST}
if err := db.Where(user).Take(&user).Error; err != nil {
return nil, err
}
guest = &user
return &user, nil
}
func GetUserByName(username string) (*model.User, error) {
if username == "" {
return nil, errors.WithStack(errs.EmptyUsername)
user := model.User{Username: username}
if err := db.Where(user).First(&user).Error; err != nil {
return nil, errors.Wrapf(err, "failed find user")
}
user, ok := userCache.Get(username)
if ok {
return user, nil
}
user, err, _ := userG.Do(username, func() (*model.User, error) {
user := model.User{Username: username}
if err := db.Where(user).First(&user).Error; err != nil {
return nil, errors.Wrapf(err, "failed find user")
}
userCache.Set(username, &user, cache.WithEx[*model.User](time.Hour))
return &user, nil
})
return user, err
return &user, nil
}
func GetUserById(id uint) (*model.User, error) {
@ -71,40 +34,14 @@ func CreateUser(u *model.User) error {
}
func UpdateUser(u *model.User) error {
old, err := GetUserById(u.ID)
if err != nil {
return err
}
userCache.Del(old.Username)
if u.IsGuest() {
guest = nil
}
if u.IsAdmin() {
admin = nil
}
return errors.WithStack(db.Save(u).Error)
}
func Cancel2FAByUser(u *model.User) error {
u.OtpSecret = ""
return errors.WithStack(UpdateUser(u))
}
func Cancel2FAById(id uint) error {
user, err := GetUserById(id)
if err != nil {
return err
}
return Cancel2FAByUser(user)
}
func GetUsers(pageIndex, pageSize int) ([]model.User, int64, error) {
func GetUsers(pageIndex, pageSize int) (users []model.User, count int64, err error) {
userDB := db.Model(&model.User{})
var count int64
if err := userDB.Count(&count).Error; err != nil {
return nil, 0, errors.Wrapf(err, "failed get users count")
}
var users []model.User
if err := userDB.Offset((pageIndex - 1) * pageSize).Limit(pageSize).Find(&users).Error; err != nil {
return nil, 0, errors.Wrapf(err, "failed get find users")
}
@ -112,13 +49,5 @@ func GetUsers(pageIndex, pageSize int) ([]model.User, int64, error) {
}
func DeleteUserById(id uint) error {
old, err := GetUserById(id)
if err != nil {
return err
}
if old.IsAdmin() || old.IsGuest() {
return errors.WithStack(errs.DeleteAdminOrGuest)
}
userCache.Del(old.Username)
return errors.WithStack(db.Delete(&model.User{}, id).Error)
}

114
internal/op/user.go Normal file
View File

@ -0,0 +1,114 @@
package op
import (
"time"
"github.com/Xhofe/go-cache"
"github.com/alist-org/alist/v3/internal/db"
"github.com/alist-org/alist/v3/internal/errs"
"github.com/alist-org/alist/v3/internal/model"
"github.com/alist-org/alist/v3/pkg/singleflight"
"github.com/alist-org/alist/v3/pkg/utils"
)
var userCache = cache.NewMemCache(cache.WithShards[*model.User](2))
var userG singleflight.Group[*model.User]
var guestUser *model.User
var adminUser *model.User
func GetAdmin() (*model.User, error) {
if adminUser == nil {
user, err := db.GetUserByRole(model.ADMIN)
if err != nil {
return nil, err
}
adminUser = user
}
return adminUser, nil
}
func GetGuest() (*model.User, error) {
if guestUser == nil {
user, err := db.GetUserByRole(model.GUEST)
if err != nil {
return nil, err
}
guestUser = user
}
return guestUser, nil
}
func GetUserByRole(role int) (*model.User, error) {
return db.GetUserByRole(role)
}
func GetUserByName(username string) (*model.User, error) {
if username == "" {
return nil, errs.EmptyUsername
}
if user, ok := userCache.Get(username); ok {
return user, nil
}
user, err, _ := userG.Do(username, func() (*model.User, error) {
_user, err := db.GetUserByName(username)
if err != nil {
return nil, err
}
userCache.Set(username, _user, cache.WithEx[*model.User](time.Hour))
return _user, nil
})
return user, err
}
func GetUserById(id uint) (*model.User, error) {
return db.GetUserById(id)
}
func GetUsers(pageIndex, pageSize int) (users []model.User, count int64, err error) {
return db.GetUsers(pageIndex, pageSize)
}
func CreateUser(u *model.User) error {
u.BasePath = utils.FixAndCleanPath(u.BasePath)
return db.CreateUser(u)
}
func DeleteUserById(id uint) error {
old, err := db.GetUserById(id)
if err != nil {
return err
}
if old.IsAdmin() || old.IsGuest() {
return errs.DeleteAdminOrGuest
}
return db.DeleteUserById(id)
}
func UpdateUser(u *model.User) error {
old, err := db.GetUserById(u.ID)
if err != nil {
return err
}
if u.IsAdmin() {
adminUser = nil
}
if u.IsGuest() {
guestUser = nil
}
userCache.Del(old.Username)
u.BasePath = utils.FixAndCleanPath(u.BasePath)
return db.UpdateUser(u)
}
func Cancel2FAByUser(u *model.User) error {
u.OtpSecret = ""
return db.UpdateUser(u)
}
func Cancel2FAById(id uint) error {
user, err := db.GetUserById(id)
if err != nil {
return err
}
return Cancel2FAByUser(user)
}

View File

@ -8,7 +8,6 @@ import (
"sync/atomic"
"time"
"github.com/alist-org/alist/v3/internal/db"
"github.com/alist-org/alist/v3/internal/fs"
"github.com/alist-org/alist/v3/internal/model"
"github.com/alist-org/alist/v3/internal/op"
@ -105,7 +104,7 @@ func BuildIndex(ctx context.Context, indexPaths, ignorePaths []string, maxDepth
Quit <- struct{}{}
}
}()
admin, err := db.GetAdmin()
admin, err := op.GetAdmin()
if err != nil {
return err
}

View File

@ -7,8 +7,8 @@ import (
"time"
"github.com/Xhofe/go-cache"
"github.com/alist-org/alist/v3/internal/db"
"github.com/alist-org/alist/v3/internal/model"
"github.com/alist-org/alist/v3/internal/op"
"github.com/alist-org/alist/v3/server/common"
"github.com/gin-gonic/gin"
"github.com/pquerna/otp/totp"
@ -41,7 +41,7 @@ func Login(c *gin.Context) {
common.ErrorResp(c, err, 400)
return
}
user, err := db.GetUserByName(req.Username)
user, err := op.GetUserByName(req.Username)
if err != nil {
common.ErrorResp(c, err, 400)
loginCache.Set(ip, count+1)
@ -101,7 +101,7 @@ func UpdateCurrent(c *gin.Context) {
if req.Password != "" {
user.Password = req.Password
}
if err := db.UpdateUser(user); err != nil {
if err := op.UpdateUser(user); err != nil {
common.ErrorResp(c, err, 500)
} else {
common.SuccessResp(c)
@ -158,7 +158,7 @@ func Verify2FA(c *gin.Context) {
return
}
user.OtpSecret = req.Secret
if err := db.UpdateUser(user); err != nil {
if err := op.UpdateUser(user); err != nil {
common.ErrorResp(c, err, 500)
} else {
common.SuccessResp(c)

View File

@ -3,8 +3,8 @@ package handles
import (
"strconv"
"github.com/alist-org/alist/v3/internal/db"
"github.com/alist-org/alist/v3/internal/model"
"github.com/alist-org/alist/v3/internal/op"
"github.com/alist-org/alist/v3/server/common"
"github.com/gin-gonic/gin"
log "github.com/sirupsen/logrus"
@ -18,7 +18,7 @@ func ListUsers(c *gin.Context) {
}
req.Validate()
log.Debugf("%+v", req)
users, total, err := db.GetUsers(req.Page, req.PerPage)
users, total, err := op.GetUsers(req.Page, req.PerPage)
if err != nil {
common.ErrorResp(c, err, 500, true)
return
@ -39,7 +39,7 @@ func CreateUser(c *gin.Context) {
common.ErrorStrResp(c, "admin or guest user can not be created", 400, true)
return
}
if err := db.CreateUser(&req); err != nil {
if err := op.CreateUser(&req); err != nil {
common.ErrorResp(c, err, 500, true)
} else {
common.SuccessResp(c)
@ -52,7 +52,7 @@ func UpdateUser(c *gin.Context) {
common.ErrorResp(c, err, 400)
return
}
user, err := db.GetUserById(req.ID)
user, err := op.GetUserById(req.ID)
if err != nil {
common.ErrorResp(c, err, 500)
return
@ -67,7 +67,7 @@ func UpdateUser(c *gin.Context) {
if req.OtpSecret == "" {
req.OtpSecret = user.OtpSecret
}
if err := db.UpdateUser(&req); err != nil {
if err := op.UpdateUser(&req); err != nil {
common.ErrorResp(c, err, 500)
} else {
common.SuccessResp(c)
@ -81,7 +81,7 @@ func DeleteUser(c *gin.Context) {
common.ErrorResp(c, err, 400)
return
}
if err := db.DeleteUserById(uint(id)); err != nil {
if err := op.DeleteUserById(uint(id)); err != nil {
common.ErrorResp(c, err, 500)
return
}
@ -95,7 +95,7 @@ func GetUser(c *gin.Context) {
common.ErrorResp(c, err, 400)
return
}
user, err := db.GetUserById(uint(id))
user, err := op.GetUserById(uint(id))
if err != nil {
common.ErrorResp(c, err, 500, true)
return
@ -110,7 +110,7 @@ func Cancel2FAById(c *gin.Context) {
common.ErrorResp(c, err, 400)
return
}
if err := db.Cancel2FAById(uint(id)); err != nil {
if err := op.Cancel2FAById(uint(id)); err != nil {
common.ErrorResp(c, err, 500)
return
}

View File

@ -2,8 +2,8 @@ package middlewares
import (
"github.com/alist-org/alist/v3/internal/conf"
"github.com/alist-org/alist/v3/internal/db"
"github.com/alist-org/alist/v3/internal/model"
"github.com/alist-org/alist/v3/internal/op"
"github.com/alist-org/alist/v3/internal/setting"
"github.com/alist-org/alist/v3/server/common"
"github.com/gin-gonic/gin"
@ -15,7 +15,7 @@ import (
func Auth(c *gin.Context) {
token := c.GetHeader("Authorization")
if token == setting.GetStr(conf.Token) {
admin, err := db.GetAdmin()
admin, err := op.GetAdmin()
if err != nil {
common.ErrorResp(c, err, 500)
c.Abort()
@ -27,7 +27,7 @@ func Auth(c *gin.Context) {
return
}
if token == "" {
guest, err := db.GetGuest()
guest, err := op.GetGuest()
if err != nil {
common.ErrorResp(c, err, 500)
c.Abort()
@ -44,7 +44,7 @@ func Auth(c *gin.Context) {
c.Abort()
return
}
user, err := db.GetUserByName(userClaims.Username)
user, err := op.GetUserByName(userClaims.Username)
if err != nil {
common.ErrorResp(c, err, 401)
c.Abort()

View File

@ -4,8 +4,8 @@ import (
"context"
"net/http"
"github.com/alist-org/alist/v3/internal/db"
"github.com/alist-org/alist/v3/internal/model"
"github.com/alist-org/alist/v3/internal/op"
"github.com/alist-org/alist/v3/pkg/utils"
"github.com/alist-org/alist/v3/server/webdav"
"github.com/gin-gonic/gin"
@ -45,7 +45,7 @@ func ServeWebDAV(c *gin.Context) {
}
func WebDAVAuth(c *gin.Context) {
guest, _ := db.GetGuest()
guest, _ := op.GetGuest()
username, password, ok := c.Request.BasicAuth()
if !ok {
if c.Request.Method == "OPTIONS" {
@ -58,7 +58,7 @@ func WebDAVAuth(c *gin.Context) {
c.Abort()
return
}
user, err := db.GetUserByName(username)
user, err := op.GetUserByName(username)
if err != nil || user.ValidatePassword(password) != nil {
if c.Request.Method == "OPTIONS" {
c.Set("user", guest)