mirror of https://github.com/Xhofe/alist
feat: user jwt login
parent
306b90399c
commit
c5295f4d72
|
@ -0,0 +1,39 @@
|
|||
package bootstrap
|
||||
|
||||
import (
|
||||
"github.com/alist-org/alist/v3/cmd/args"
|
||||
"github.com/alist-org/alist/v3/internal/conf"
|
||||
"github.com/alist-org/alist/v3/internal/store"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
"gorm.io/gorm/schema"
|
||||
stdlog "log"
|
||||
"time"
|
||||
)
|
||||
|
||||
func InitDB() {
|
||||
newLogger := logger.New(
|
||||
stdlog.New(log.StandardLogger().Out, "\r\n", stdlog.LstdFlags),
|
||||
logger.Config{
|
||||
SlowThreshold: time.Second,
|
||||
LogLevel: logger.Silent,
|
||||
IgnoreRecordNotFoundError: true,
|
||||
Colorful: true,
|
||||
},
|
||||
)
|
||||
gormConfig := &gorm.Config{
|
||||
NamingStrategy: schema.NamingStrategy{
|
||||
TablePrefix: conf.Conf.Database.TablePrefix,
|
||||
},
|
||||
Logger: newLogger,
|
||||
}
|
||||
if args.Dev {
|
||||
db, err := gorm.Open(sqlite.Open("file::memory:?cache=shared"), gormConfig)
|
||||
if err != nil {
|
||||
panic("failed to connect database")
|
||||
}
|
||||
store.Init(db)
|
||||
}
|
||||
}
|
|
@ -21,7 +21,7 @@ func init() {
|
|||
|
||||
func Log() {
|
||||
log.SetOutput(logrus.StandardLogger().Out)
|
||||
if args.Debug {
|
||||
if args.Debug || args.Dev {
|
||||
logrus.SetLevel(logrus.DebugLevel)
|
||||
logrus.SetReportCaller(true)
|
||||
}
|
||||
|
|
|
@ -20,6 +20,7 @@ func init() {
|
|||
flag.BoolVar(&args.Version, "version", false, "print version info")
|
||||
flag.BoolVar(&args.Password, "password", false, "print current password")
|
||||
flag.BoolVar(&args.NoPrefix, "no-prefix", false, "disable env prefix")
|
||||
flag.BoolVar(&args.Dev, "dev", false, "start with dev mode")
|
||||
flag.Parse()
|
||||
}
|
||||
|
||||
|
@ -31,6 +32,7 @@ func Init() {
|
|||
}
|
||||
bootstrap.InitConfig()
|
||||
bootstrap.Log()
|
||||
bootstrap.InitDB()
|
||||
}
|
||||
func main() {
|
||||
Init()
|
||||
|
|
|
@ -6,4 +6,5 @@ var (
|
|||
Version bool
|
||||
Password bool
|
||||
NoPrefix bool
|
||||
Dev bool
|
||||
)
|
||||
|
|
1
go.mod
1
go.mod
|
@ -21,6 +21,7 @@ require (
|
|||
github.com/go-playground/universal-translator v0.18.0 // indirect
|
||||
github.com/go-playground/validator/v10 v10.11.0 // indirect
|
||||
github.com/goccy/go-json v0.9.7 // indirect
|
||||
github.com/golang-jwt/jwt/v4 v4.4.2 // indirect
|
||||
github.com/google/uuid v1.3.0 // indirect
|
||||
github.com/gorilla/websocket v1.5.0 // indirect
|
||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||
|
|
2
go.sum
2
go.sum
|
@ -25,6 +25,8 @@ github.com/go-playground/validator/v10 v10.11.0 h1:0W+xRM511GY47Yy3bZUbJVitCNg2B
|
|||
github.com/go-playground/validator/v10 v10.11.0/go.mod h1:i+3WkQ1FvaUjjxh1kSvIA4dMGDBiPU55YFDl0WbKdWU=
|
||||
github.com/goccy/go-json v0.9.7 h1:IcB+Aqpx/iMHu5Yooh7jEzJk1JZ7Pjtmys2ukPr7EeM=
|
||||
github.com/goccy/go-json v0.9.7/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
|
||||
github.com/golang-jwt/jwt/v4 v4.4.2 h1:rcc4lwaZgFMCZ5jxF9ABolDcIHdBytAFgqFPbSJQAYs=
|
||||
github.com/golang-jwt/jwt/v4 v4.4.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
|
||||
github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
|
||||
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
|
||||
github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
|
||||
|
|
|
@ -1,5 +1,9 @@
|
|||
package conf
|
||||
|
||||
import (
|
||||
"github.com/alist-org/alist/v3/pkg/utils/random"
|
||||
)
|
||||
|
||||
type Database struct {
|
||||
Type string `json:"type" env:"DB_TYPE"`
|
||||
Host string `json:"host" env:"DB_HOST"`
|
||||
|
@ -30,6 +34,7 @@ type Config struct {
|
|||
Force bool `json:"force"`
|
||||
Address string `json:"address" env:"ADDR"`
|
||||
Port int `json:"port" env:"PORT"`
|
||||
JwtSecret string `json:"jwt_secret" env:"JWT_SECRET"`
|
||||
CaCheExpiration int `json:"cache_expiration" env:"CACHE_EXPIRATION"`
|
||||
Assets string `json:"assets" env:"ASSETS"`
|
||||
Database Database `json:"database"`
|
||||
|
@ -40,10 +45,11 @@ type Config struct {
|
|||
|
||||
func DefaultConfig() *Config {
|
||||
return &Config{
|
||||
Address: "0.0.0.0",
|
||||
Port: 5244,
|
||||
Assets: "https://npm.elemecdn.com/alist-web@$version/dist",
|
||||
TempDir: "data/temp",
|
||||
Address: "0.0.0.0",
|
||||
Port: 5244,
|
||||
JwtSecret: random.RandomStr(16),
|
||||
Assets: "https://npm.elemecdn.com/alist-web@$version/dist",
|
||||
TempDir: "data/temp",
|
||||
Database: Database{
|
||||
Type: "sqlite3",
|
||||
Port: 0,
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
package model
|
||||
|
||||
import "github.com/pkg/errors"
|
||||
|
||||
const (
|
||||
GENERAL = iota
|
||||
GUEST // only one exists
|
||||
|
@ -7,12 +9,12 @@ const (
|
|||
)
|
||||
|
||||
type User struct {
|
||||
ID uint `json:"id" gorm:"primaryKey"` // unique key
|
||||
Name string `json:"name" gorm:"unique"` // username
|
||||
Password string `json:"password"` // password
|
||||
BasePath string `json:"base_path"` // base path
|
||||
ReadOnly bool `json:"read_only"` // allow upload
|
||||
Role int `json:"role"` // user's role
|
||||
ID uint `json:"id" gorm:"primaryKey"` // unique key
|
||||
Username string `json:"username" gorm:"unique"` // username
|
||||
Password string `json:"password"` // password
|
||||
BasePath string `json:"base_path"` // base path
|
||||
ReadOnly bool `json:"read_only"` // allow upload
|
||||
Role int `json:"role"` // user's role
|
||||
}
|
||||
|
||||
func (u User) IsGuest() bool {
|
||||
|
@ -22,3 +24,13 @@ func (u User) IsGuest() bool {
|
|||
func (u User) IsAdmin() bool {
|
||||
return u.Role == ADMIN
|
||||
}
|
||||
|
||||
func (u User) ValidatePassword(password string) error {
|
||||
if password == "" {
|
||||
return errors.New("password is empty")
|
||||
}
|
||||
if u.Password != password {
|
||||
return errors.New("password is incorrect")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -0,0 +1,50 @@
|
|||
package common
|
||||
|
||||
import (
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"github.com/pkg/errors"
|
||||
"time"
|
||||
)
|
||||
|
||||
var SecretKey []byte
|
||||
|
||||
type UserClaims struct {
|
||||
Username string `json:"username"`
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
func GenerateToken(username string) (tokenString string, err error) {
|
||||
claim := UserClaims{
|
||||
Username: username,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(12 * time.Hour)),
|
||||
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||
NotBefore: jwt.NewNumericDate(time.Now()),
|
||||
}}
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claim)
|
||||
tokenString, err = token.SignedString(SecretKey)
|
||||
return tokenString, err
|
||||
}
|
||||
|
||||
func ParseToken(tokenString string) (*UserClaims, error) {
|
||||
token, err := jwt.ParseWithClaims(tokenString, &UserClaims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
return SecretKey, nil
|
||||
})
|
||||
if err != nil {
|
||||
if ve, ok := err.(*jwt.ValidationError); ok {
|
||||
if ve.Errors&jwt.ValidationErrorMalformed != 0 {
|
||||
return nil, errors.New("that's not even a token")
|
||||
} else if ve.Errors&jwt.ValidationErrorExpired != 0 {
|
||||
return nil, errors.New("token is expired")
|
||||
} else if ve.Errors&jwt.ValidationErrorNotValidYet != 0 {
|
||||
return nil, errors.New("token not active yet")
|
||||
} else {
|
||||
return nil, errors.New("couldn't handle this token")
|
||||
}
|
||||
}
|
||||
}
|
||||
if claims, ok := token.Claims.(*UserClaims); ok && token.Valid {
|
||||
return claims, nil
|
||||
}
|
||||
return nil, errors.New("couldn't handle this token")
|
||||
}
|
|
@ -0,0 +1,48 @@
|
|||
package common
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type Resp struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Data interface{} `json:"data"`
|
||||
}
|
||||
|
||||
func ErrorResp(c *gin.Context, err error, code int) {
|
||||
log.Error(err.Error())
|
||||
c.JSON(200, Resp{
|
||||
Code: code,
|
||||
Message: err.Error(),
|
||||
Data: nil,
|
||||
})
|
||||
c.Abort()
|
||||
}
|
||||
|
||||
func ErrorStrResp(c *gin.Context, str string, code int) {
|
||||
log.Error(str)
|
||||
c.JSON(200, Resp{
|
||||
Code: code,
|
||||
Message: str,
|
||||
Data: nil,
|
||||
})
|
||||
c.Abort()
|
||||
}
|
||||
|
||||
func SuccessResp(c *gin.Context, data ...interface{}) {
|
||||
if len(data) == 0 {
|
||||
c.JSON(200, Resp{
|
||||
Code: 200,
|
||||
Message: "success",
|
||||
Data: nil,
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(200, Resp{
|
||||
Code: 200,
|
||||
Message: "success",
|
||||
Data: data[0],
|
||||
})
|
||||
}
|
|
@ -0,0 +1,56 @@
|
|||
package controllers
|
||||
|
||||
import (
|
||||
"github.com/Xhofe/go-cache"
|
||||
"github.com/alist-org/alist/v3/internal/server/common"
|
||||
"github.com/alist-org/alist/v3/internal/store"
|
||||
"github.com/gin-gonic/gin"
|
||||
"time"
|
||||
)
|
||||
|
||||
var loginCache = cache.NewMemCache[int]()
|
||||
var (
|
||||
defaultDuration = time.Minute * 5
|
||||
defaultTimes = 5
|
||||
)
|
||||
|
||||
type LoginReq struct {
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
func Login(c *gin.Context) {
|
||||
// check count of login
|
||||
ip := c.ClientIP()
|
||||
count, ok := loginCache.Get(ip)
|
||||
if ok && count > defaultTimes {
|
||||
common.ErrorStrResp(c, "Too many unsuccessful sign-in attempts have been made using an incorrect password. Try again later.", 403)
|
||||
loginCache.Expire(ip, defaultDuration)
|
||||
return
|
||||
}
|
||||
// check username
|
||||
var req LoginReq
|
||||
if err := c.ShouldBind(&req); err != nil {
|
||||
common.ErrorResp(c, err, 400)
|
||||
return
|
||||
}
|
||||
user, err := store.GetUserByName(req.Username)
|
||||
if err != nil {
|
||||
common.ErrorResp(c, err, 400)
|
||||
return
|
||||
}
|
||||
// validate password
|
||||
if err := user.ValidatePassword(req.Password); err != nil {
|
||||
common.ErrorResp(c, err, 400)
|
||||
loginCache.Set(ip, count+1)
|
||||
return
|
||||
}
|
||||
// generate token
|
||||
token, err := common.GenerateToken(user.Username)
|
||||
if err != nil {
|
||||
common.ErrorResp(c, err, 400)
|
||||
return
|
||||
}
|
||||
common.SuccessResp(c, gin.H{"token": token})
|
||||
loginCache.Del(ip)
|
||||
}
|
|
@ -1 +0,0 @@
|
|||
package controllers
|
|
@ -0,0 +1,25 @@
|
|||
package middlewares
|
||||
|
||||
import (
|
||||
"github.com/alist-org/alist/v3/internal/server/common"
|
||||
"github.com/alist-org/alist/v3/internal/store"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func AuthAdmin(c *gin.Context) {
|
||||
token := c.GetHeader("Authorization")
|
||||
userClaims, err := common.ParseToken(token)
|
||||
if err != nil {
|
||||
common.ErrorResp(c, err, 401)
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
user, err := store.GetUserByName(userClaims.Username)
|
||||
if err != nil {
|
||||
common.ErrorResp(c, err, 401)
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
c.Set("user", user)
|
||||
c.Next()
|
||||
}
|
|
@ -1,12 +1,19 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"github.com/alist-org/alist/v3/internal/conf"
|
||||
"github.com/alist-org/alist/v3/internal/server/common"
|
||||
"github.com/alist-org/alist/v3/internal/server/controllers"
|
||||
"github.com/gin-contrib/cors"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func Init(r *gin.Engine) {
|
||||
common.SecretKey = []byte(conf.Conf.JwtSecret)
|
||||
Cors(r)
|
||||
|
||||
api := r.Group("/api")
|
||||
api.POST("/user/login", controllers.Login)
|
||||
}
|
||||
|
||||
func Cors(r *gin.Engine) {
|
||||
|
|
|
@ -18,17 +18,17 @@ func ExistGuest() bool {
|
|||
return db.Take(&model.User{Role: model.GUEST}).Error != nil
|
||||
}
|
||||
|
||||
func GetUserByName(name string) (*model.User, error) {
|
||||
user, ok := userCache.Get(name)
|
||||
func GetUserByName(username string) (*model.User, error) {
|
||||
user, ok := userCache.Get(username)
|
||||
if ok {
|
||||
return user, nil
|
||||
}
|
||||
user, err, _ := userG.Do(name, func() (*model.User, error) {
|
||||
user := model.User{Name: name}
|
||||
user, err, _ := userG.Do(username, func() (*model.User, error) {
|
||||
user := model.User{Username: username}
|
||||
if err := db.Where(user).First(&user).Error; err != nil {
|
||||
return nil, errors.Wrapf(err, "failed select user")
|
||||
return nil, errors.Wrapf(err, "failed find user")
|
||||
}
|
||||
userCache.Set(name, &user)
|
||||
userCache.Set(username, &user)
|
||||
return &user, nil
|
||||
})
|
||||
return user, err
|
||||
|
@ -51,7 +51,7 @@ func UpdateUser(u *model.User) error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
userCache.Del(old.Name)
|
||||
userCache.Del(old.Username)
|
||||
return errors.WithStack(db.Save(u).Error)
|
||||
}
|
||||
|
||||
|
@ -73,6 +73,6 @@ func DeleteUserById(id uint) error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
userCache.Del(old.Name)
|
||||
userCache.Del(old.Username)
|
||||
return errors.WithStack(db.Delete(&model.User{}, id).Error)
|
||||
}
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
package utils
|
||||
package random
|
||||
|
||||
import (
|
||||
"math/rand"
|
Loading…
Reference in New Issue