diff --git a/bootstrap/db.go b/bootstrap/db.go new file mode 100644 index 00000000..ed82ef1d --- /dev/null +++ b/bootstrap/db.go @@ -0,0 +1,39 @@ +package bootstrap + +import ( + "github.com/alist-org/alist/v3/cmd/args" + "github.com/alist-org/alist/v3/internal/conf" + "github.com/alist-org/alist/v3/internal/store" + log "github.com/sirupsen/logrus" + "gorm.io/driver/sqlite" + "gorm.io/gorm" + "gorm.io/gorm/logger" + "gorm.io/gorm/schema" + stdlog "log" + "time" +) + +func InitDB() { + newLogger := logger.New( + stdlog.New(log.StandardLogger().Out, "\r\n", stdlog.LstdFlags), + logger.Config{ + SlowThreshold: time.Second, + LogLevel: logger.Silent, + IgnoreRecordNotFoundError: true, + Colorful: true, + }, + ) + gormConfig := &gorm.Config{ + NamingStrategy: schema.NamingStrategy{ + TablePrefix: conf.Conf.Database.TablePrefix, + }, + Logger: newLogger, + } + if args.Dev { + db, err := gorm.Open(sqlite.Open("file::memory:?cache=shared"), gormConfig) + if err != nil { + panic("failed to connect database") + } + store.Init(db) + } +} diff --git a/bootstrap/log.go b/bootstrap/log.go index b806ede7..d3ccf1d2 100644 --- a/bootstrap/log.go +++ b/bootstrap/log.go @@ -21,7 +21,7 @@ func init() { func Log() { log.SetOutput(logrus.StandardLogger().Out) - if args.Debug { + if args.Debug || args.Dev { logrus.SetLevel(logrus.DebugLevel) logrus.SetReportCaller(true) } diff --git a/cmd/alist.go b/cmd/alist.go index 2da40eb7..04c05f1b 100644 --- a/cmd/alist.go +++ b/cmd/alist.go @@ -20,6 +20,7 @@ func init() { flag.BoolVar(&args.Version, "version", false, "print version info") flag.BoolVar(&args.Password, "password", false, "print current password") flag.BoolVar(&args.NoPrefix, "no-prefix", false, "disable env prefix") + flag.BoolVar(&args.Dev, "dev", false, "start with dev mode") flag.Parse() } @@ -31,6 +32,7 @@ func Init() { } bootstrap.InitConfig() bootstrap.Log() + bootstrap.InitDB() } func main() { Init() diff --git a/cmd/args/config.go b/cmd/args/config.go index 70a99a29..1df93bab 100644 --- a/cmd/args/config.go +++ b/cmd/args/config.go @@ -6,4 +6,5 @@ var ( Version bool Password bool NoPrefix bool + Dev bool ) diff --git a/go.mod b/go.mod index fea49d5a..ea0f7282 100644 --- a/go.mod +++ b/go.mod @@ -21,6 +21,7 @@ require ( github.com/go-playground/universal-translator v0.18.0 // indirect github.com/go-playground/validator/v10 v10.11.0 // indirect github.com/goccy/go-json v0.9.7 // indirect + github.com/golang-jwt/jwt/v4 v4.4.2 // indirect github.com/google/uuid v1.3.0 // indirect github.com/gorilla/websocket v1.5.0 // indirect github.com/jinzhu/inflection v1.0.0 // indirect diff --git a/go.sum b/go.sum index 4676a86d..e3552298 100644 --- a/go.sum +++ b/go.sum @@ -25,6 +25,8 @@ github.com/go-playground/validator/v10 v10.11.0 h1:0W+xRM511GY47Yy3bZUbJVitCNg2B github.com/go-playground/validator/v10 v10.11.0/go.mod h1:i+3WkQ1FvaUjjxh1kSvIA4dMGDBiPU55YFDl0WbKdWU= github.com/goccy/go-json v0.9.7 h1:IcB+Aqpx/iMHu5Yooh7jEzJk1JZ7Pjtmys2ukPr7EeM= github.com/goccy/go-json v0.9.7/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= +github.com/golang-jwt/jwt/v4 v4.4.2 h1:rcc4lwaZgFMCZ5jxF9ABolDcIHdBytAFgqFPbSJQAYs= +github.com/golang-jwt/jwt/v4 v4.4.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= diff --git a/internal/conf/config.go b/internal/conf/config.go index 730896a9..4b0d8c52 100644 --- a/internal/conf/config.go +++ b/internal/conf/config.go @@ -1,5 +1,9 @@ package conf +import ( + "github.com/alist-org/alist/v3/pkg/utils/random" +) + type Database struct { Type string `json:"type" env:"DB_TYPE"` Host string `json:"host" env:"DB_HOST"` @@ -30,6 +34,7 @@ type Config struct { Force bool `json:"force"` Address string `json:"address" env:"ADDR"` Port int `json:"port" env:"PORT"` + JwtSecret string `json:"jwt_secret" env:"JWT_SECRET"` CaCheExpiration int `json:"cache_expiration" env:"CACHE_EXPIRATION"` Assets string `json:"assets" env:"ASSETS"` Database Database `json:"database"` @@ -40,10 +45,11 @@ type Config struct { func DefaultConfig() *Config { return &Config{ - Address: "0.0.0.0", - Port: 5244, - Assets: "https://npm.elemecdn.com/alist-web@$version/dist", - TempDir: "data/temp", + Address: "0.0.0.0", + Port: 5244, + JwtSecret: random.RandomStr(16), + Assets: "https://npm.elemecdn.com/alist-web@$version/dist", + TempDir: "data/temp", Database: Database{ Type: "sqlite3", Port: 0, diff --git a/internal/model/user.go b/internal/model/user.go index 69bd4d06..d8bcde33 100644 --- a/internal/model/user.go +++ b/internal/model/user.go @@ -1,5 +1,7 @@ package model +import "github.com/pkg/errors" + const ( GENERAL = iota GUEST // only one exists @@ -7,12 +9,12 @@ const ( ) type User struct { - ID uint `json:"id" gorm:"primaryKey"` // unique key - Name string `json:"name" gorm:"unique"` // username - Password string `json:"password"` // password - BasePath string `json:"base_path"` // base path - ReadOnly bool `json:"read_only"` // allow upload - Role int `json:"role"` // user's role + ID uint `json:"id" gorm:"primaryKey"` // unique key + Username string `json:"username" gorm:"unique"` // username + Password string `json:"password"` // password + BasePath string `json:"base_path"` // base path + ReadOnly bool `json:"read_only"` // allow upload + Role int `json:"role"` // user's role } func (u User) IsGuest() bool { @@ -22,3 +24,13 @@ func (u User) IsGuest() bool { func (u User) IsAdmin() bool { return u.Role == ADMIN } + +func (u User) ValidatePassword(password string) error { + if password == "" { + return errors.New("password is empty") + } + if u.Password != password { + return errors.New("password is incorrect") + } + return nil +} diff --git a/internal/server/common/auth.go b/internal/server/common/auth.go new file mode 100644 index 00000000..bdbb869c --- /dev/null +++ b/internal/server/common/auth.go @@ -0,0 +1,50 @@ +package common + +import ( + "github.com/golang-jwt/jwt/v4" + "github.com/pkg/errors" + "time" +) + +var SecretKey []byte + +type UserClaims struct { + Username string `json:"username"` + jwt.RegisteredClaims +} + +func GenerateToken(username string) (tokenString string, err error) { + claim := UserClaims{ + Username: username, + RegisteredClaims: jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(12 * time.Hour)), + IssuedAt: jwt.NewNumericDate(time.Now()), + NotBefore: jwt.NewNumericDate(time.Now()), + }} + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claim) + tokenString, err = token.SignedString(SecretKey) + return tokenString, err +} + +func ParseToken(tokenString string) (*UserClaims, error) { + token, err := jwt.ParseWithClaims(tokenString, &UserClaims{}, func(token *jwt.Token) (interface{}, error) { + return SecretKey, nil + }) + if err != nil { + if ve, ok := err.(*jwt.ValidationError); ok { + if ve.Errors&jwt.ValidationErrorMalformed != 0 { + return nil, errors.New("that's not even a token") + } else if ve.Errors&jwt.ValidationErrorExpired != 0 { + return nil, errors.New("token is expired") + } else if ve.Errors&jwt.ValidationErrorNotValidYet != 0 { + return nil, errors.New("token not active yet") + } else { + return nil, errors.New("couldn't handle this token") + } + } + } + if claims, ok := token.Claims.(*UserClaims); ok && token.Valid { + return claims, nil + } + return nil, errors.New("couldn't handle this token") +} diff --git a/internal/server/common/common.go b/internal/server/common/common.go new file mode 100644 index 00000000..475d3159 --- /dev/null +++ b/internal/server/common/common.go @@ -0,0 +1,48 @@ +package common + +import ( + "github.com/gin-gonic/gin" + log "github.com/sirupsen/logrus" +) + +type Resp struct { + Code int `json:"code"` + Message string `json:"message"` + Data interface{} `json:"data"` +} + +func ErrorResp(c *gin.Context, err error, code int) { + log.Error(err.Error()) + c.JSON(200, Resp{ + Code: code, + Message: err.Error(), + Data: nil, + }) + c.Abort() +} + +func ErrorStrResp(c *gin.Context, str string, code int) { + log.Error(str) + c.JSON(200, Resp{ + Code: code, + Message: str, + Data: nil, + }) + c.Abort() +} + +func SuccessResp(c *gin.Context, data ...interface{}) { + if len(data) == 0 { + c.JSON(200, Resp{ + Code: 200, + Message: "success", + Data: nil, + }) + return + } + c.JSON(200, Resp{ + Code: 200, + Message: "success", + Data: data[0], + }) +} diff --git a/internal/server/controllers/login.go b/internal/server/controllers/login.go new file mode 100644 index 00000000..fd2187f5 --- /dev/null +++ b/internal/server/controllers/login.go @@ -0,0 +1,56 @@ +package controllers + +import ( + "github.com/Xhofe/go-cache" + "github.com/alist-org/alist/v3/internal/server/common" + "github.com/alist-org/alist/v3/internal/store" + "github.com/gin-gonic/gin" + "time" +) + +var loginCache = cache.NewMemCache[int]() +var ( + defaultDuration = time.Minute * 5 + defaultTimes = 5 +) + +type LoginReq struct { + Username string `json:"username"` + Password string `json:"password"` +} + +func Login(c *gin.Context) { + // check count of login + ip := c.ClientIP() + count, ok := loginCache.Get(ip) + if ok && count > defaultTimes { + common.ErrorStrResp(c, "Too many unsuccessful sign-in attempts have been made using an incorrect password. Try again later.", 403) + loginCache.Expire(ip, defaultDuration) + return + } + // check username + var req LoginReq + if err := c.ShouldBind(&req); err != nil { + common.ErrorResp(c, err, 400) + return + } + user, err := store.GetUserByName(req.Username) + if err != nil { + common.ErrorResp(c, err, 400) + return + } + // validate password + if err := user.ValidatePassword(req.Password); err != nil { + common.ErrorResp(c, err, 400) + loginCache.Set(ip, count+1) + return + } + // generate token + token, err := common.GenerateToken(user.Username) + if err != nil { + common.ErrorResp(c, err, 400) + return + } + common.SuccessResp(c, gin.H{"token": token}) + loginCache.Del(ip) +} diff --git a/internal/server/controllers/user.go b/internal/server/controllers/user.go deleted file mode 100644 index 2d329367..00000000 --- a/internal/server/controllers/user.go +++ /dev/null @@ -1 +0,0 @@ -package controllers diff --git a/internal/server/middlewares/auth.go b/internal/server/middlewares/auth.go new file mode 100644 index 00000000..199db64d --- /dev/null +++ b/internal/server/middlewares/auth.go @@ -0,0 +1,25 @@ +package middlewares + +import ( + "github.com/alist-org/alist/v3/internal/server/common" + "github.com/alist-org/alist/v3/internal/store" + "github.com/gin-gonic/gin" +) + +func AuthAdmin(c *gin.Context) { + token := c.GetHeader("Authorization") + userClaims, err := common.ParseToken(token) + if err != nil { + common.ErrorResp(c, err, 401) + c.Abort() + return + } + user, err := store.GetUserByName(userClaims.Username) + if err != nil { + common.ErrorResp(c, err, 401) + c.Abort() + return + } + c.Set("user", user) + c.Next() +} diff --git a/internal/server/router.go b/internal/server/router.go index 8a0ec218..ecf58517 100644 --- a/internal/server/router.go +++ b/internal/server/router.go @@ -1,12 +1,19 @@ package server import ( + "github.com/alist-org/alist/v3/internal/conf" + "github.com/alist-org/alist/v3/internal/server/common" + "github.com/alist-org/alist/v3/internal/server/controllers" "github.com/gin-contrib/cors" "github.com/gin-gonic/gin" ) func Init(r *gin.Engine) { + common.SecretKey = []byte(conf.Conf.JwtSecret) Cors(r) + + api := r.Group("/api") + api.POST("/user/login", controllers.Login) } func Cors(r *gin.Engine) { diff --git a/internal/store/user.go b/internal/store/user.go index 01efc002..1219de8c 100644 --- a/internal/store/user.go +++ b/internal/store/user.go @@ -18,17 +18,17 @@ func ExistGuest() bool { return db.Take(&model.User{Role: model.GUEST}).Error != nil } -func GetUserByName(name string) (*model.User, error) { - user, ok := userCache.Get(name) +func GetUserByName(username string) (*model.User, error) { + user, ok := userCache.Get(username) if ok { return user, nil } - user, err, _ := userG.Do(name, func() (*model.User, error) { - user := model.User{Name: name} + user, err, _ := userG.Do(username, func() (*model.User, error) { + user := model.User{Username: username} if err := db.Where(user).First(&user).Error; err != nil { - return nil, errors.Wrapf(err, "failed select user") + return nil, errors.Wrapf(err, "failed find user") } - userCache.Set(name, &user) + userCache.Set(username, &user) return &user, nil }) return user, err @@ -51,7 +51,7 @@ func UpdateUser(u *model.User) error { if err != nil { return err } - userCache.Del(old.Name) + userCache.Del(old.Username) return errors.WithStack(db.Save(u).Error) } @@ -73,6 +73,6 @@ func DeleteUserById(id uint) error { if err != nil { return err } - userCache.Del(old.Name) + userCache.Del(old.Username) return errors.WithStack(db.Delete(&model.User{}, id).Error) } diff --git a/pkg/utils/random.go b/pkg/utils/random/random.go similarity index 96% rename from pkg/utils/random.go rename to pkg/utils/random/random.go index 5bd086c5..0935104b 100644 --- a/pkg/utils/random.go +++ b/pkg/utils/random/random.go @@ -1,4 +1,4 @@ -package utils +package random import ( "math/rand"