diff --git a/bootstrap/data.go b/bootstrap/data.go new file mode 100644 index 00000000..589d12a2 --- /dev/null +++ b/bootstrap/data.go @@ -0,0 +1,49 @@ +package bootstrap + +import ( + "github.com/alist-org/alist/v3/internal/db" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/pkg/utils/random" + "github.com/pkg/errors" + log "github.com/sirupsen/logrus" + "gorm.io/gorm" +) + +func InitData() { + initUser() +} + +func initUser() { + admin, err := db.GetAdmin() + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + admin = &model.User{ + Username: "admin", + Password: random.RandomStr(8), + Role: model.ADMIN, + BasePath: "/", + } + if err := db.CreateUser(admin); err != nil { + panic(err) + } + } else { + panic(err) + } + } + guest, err := db.GetGuest() + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + guest = &model.User{ + Username: "guest", + Role: model.GUEST, + BasePath: "/", + } + if err := db.CreateUser(guest); err != nil { + panic(err) + } + } else { + panic(err) + } + } + log.Infof("admin password: %+v", admin.Password) +} diff --git a/cmd/alist.go b/cmd/alist.go index 04c05f1b..f7f6cc53 100644 --- a/cmd/alist.go +++ b/cmd/alist.go @@ -33,10 +33,11 @@ func Init() { bootstrap.InitConfig() bootstrap.Log() bootstrap.InitDB() + bootstrap.InitData() } func main() { Init() - if !args.Debug { + if !args.Debug && !args.Dev { gin.SetMode(gin.ReleaseMode) } r := gin.New() diff --git a/internal/db/user.go b/internal/db/user.go index b41d199d..ee2c23a4 100644 --- a/internal/db/user.go +++ b/internal/db/user.go @@ -2,6 +2,7 @@ package db import ( "github.com/Xhofe/go-cache" + "github.com/alist-org/alist/v3/internal/errs" "github.com/alist-org/alist/v3/internal/model" "github.com/alist-org/alist/v3/pkg/singleflight" "github.com/pkg/errors" @@ -10,15 +11,26 @@ import ( var userCache = cache.NewMemCache(cache.WithShards[*model.User](2)) var userG singleflight.Group[*model.User] -func ExistAdmin() bool { - return db.Take(&model.User{Role: model.ADMIN}).Error != nil +func GetAdmin() (*model.User, error) { + user := model.User{Role: model.ADMIN} + if err := db.Where(user).Take(&user).Error; err != nil { + return nil, err + } + return &user, nil } -func ExistGuest() bool { - return db.Take(&model.User{Role: model.GUEST}).Error != nil +func GetGuest() (*model.User, error) { + user := model.User{Role: model.GUEST} + if err := db.Where(user).Take(&user).Error; err != nil { + return nil, err + } + return &user, nil } func GetUserByName(username string) (*model.User, error) { + if username == "" { + return nil, errors.WithStack(errs.EmptyUsername) + } user, ok := userCache.Get(username) if ok { return user, nil diff --git a/internal/errs/user.go b/internal/errs/user.go new file mode 100644 index 00000000..5c577a76 --- /dev/null +++ b/internal/errs/user.go @@ -0,0 +1,9 @@ +package errs + +import "errors" + +var ( + EmptyUsername = errors.New("username is empty") + EmptyPassword = errors.New("password is empty") + WrongPassword = errors.New("password is incorrect") +) diff --git a/internal/model/user.go b/internal/model/user.go index d8bcde33..9122c296 100644 --- a/internal/model/user.go +++ b/internal/model/user.go @@ -1,6 +1,9 @@ package model -import "github.com/pkg/errors" +import ( + "github.com/alist-org/alist/v3/internal/errs" + "github.com/pkg/errors" +) const ( GENERAL = iota @@ -27,10 +30,10 @@ func (u User) IsAdmin() bool { func (u User) ValidatePassword(password string) error { if password == "" { - return errors.New("password is empty") + return errors.WithStack(errs.EmptyPassword) } if u.Password != password { - return errors.New("password is incorrect") + return errors.WithStack(errs.WrongPassword) } return nil } diff --git a/internal/server/common/common.go b/internal/server/common/common.go index 475d3159..06d13d00 100644 --- a/internal/server/common/common.go +++ b/internal/server/common/common.go @@ -11,8 +11,10 @@ type Resp struct { Data interface{} `json:"data"` } -func ErrorResp(c *gin.Context, err error, code int) { - log.Error(err.Error()) +func ErrorResp(c *gin.Context, err error, code int, noLog ...bool) { + if len(noLog) != 0 && noLog[0] { + log.Errorf("%+v", err) + } c.JSON(200, Resp{ Code: code, Message: err.Error(), diff --git a/internal/server/controllers/login.go b/internal/server/controllers/login.go index 599f8cc0..92285c59 100644 --- a/internal/server/controllers/login.go +++ b/internal/server/controllers/login.go @@ -36,12 +36,12 @@ func Login(c *gin.Context) { } user, err := db.GetUserByName(req.Username) if err != nil { - common.ErrorResp(c, err, 400) + common.ErrorResp(c, err, 400, true) return } // validate password if err := user.ValidatePassword(req.Password); err != nil { - common.ErrorResp(c, err, 400) + common.ErrorResp(c, err, 400, true) loginCache.Set(ip, count+1) return }