mirror of https://github.com/Xhofe/alist
feat: user jwt login
parent
306b90399c
commit
c5295f4d72
|
@ -0,0 +1,39 @@
|
||||||
|
package bootstrap
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/alist-org/alist/v3/cmd/args"
|
||||||
|
"github.com/alist-org/alist/v3/internal/conf"
|
||||||
|
"github.com/alist-org/alist/v3/internal/store"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"gorm.io/driver/sqlite"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
"gorm.io/gorm/logger"
|
||||||
|
"gorm.io/gorm/schema"
|
||||||
|
stdlog "log"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func InitDB() {
|
||||||
|
newLogger := logger.New(
|
||||||
|
stdlog.New(log.StandardLogger().Out, "\r\n", stdlog.LstdFlags),
|
||||||
|
logger.Config{
|
||||||
|
SlowThreshold: time.Second,
|
||||||
|
LogLevel: logger.Silent,
|
||||||
|
IgnoreRecordNotFoundError: true,
|
||||||
|
Colorful: true,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
gormConfig := &gorm.Config{
|
||||||
|
NamingStrategy: schema.NamingStrategy{
|
||||||
|
TablePrefix: conf.Conf.Database.TablePrefix,
|
||||||
|
},
|
||||||
|
Logger: newLogger,
|
||||||
|
}
|
||||||
|
if args.Dev {
|
||||||
|
db, err := gorm.Open(sqlite.Open("file::memory:?cache=shared"), gormConfig)
|
||||||
|
if err != nil {
|
||||||
|
panic("failed to connect database")
|
||||||
|
}
|
||||||
|
store.Init(db)
|
||||||
|
}
|
||||||
|
}
|
|
@ -21,7 +21,7 @@ func init() {
|
||||||
|
|
||||||
func Log() {
|
func Log() {
|
||||||
log.SetOutput(logrus.StandardLogger().Out)
|
log.SetOutput(logrus.StandardLogger().Out)
|
||||||
if args.Debug {
|
if args.Debug || args.Dev {
|
||||||
logrus.SetLevel(logrus.DebugLevel)
|
logrus.SetLevel(logrus.DebugLevel)
|
||||||
logrus.SetReportCaller(true)
|
logrus.SetReportCaller(true)
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,6 +20,7 @@ func init() {
|
||||||
flag.BoolVar(&args.Version, "version", false, "print version info")
|
flag.BoolVar(&args.Version, "version", false, "print version info")
|
||||||
flag.BoolVar(&args.Password, "password", false, "print current password")
|
flag.BoolVar(&args.Password, "password", false, "print current password")
|
||||||
flag.BoolVar(&args.NoPrefix, "no-prefix", false, "disable env prefix")
|
flag.BoolVar(&args.NoPrefix, "no-prefix", false, "disable env prefix")
|
||||||
|
flag.BoolVar(&args.Dev, "dev", false, "start with dev mode")
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -31,6 +32,7 @@ func Init() {
|
||||||
}
|
}
|
||||||
bootstrap.InitConfig()
|
bootstrap.InitConfig()
|
||||||
bootstrap.Log()
|
bootstrap.Log()
|
||||||
|
bootstrap.InitDB()
|
||||||
}
|
}
|
||||||
func main() {
|
func main() {
|
||||||
Init()
|
Init()
|
||||||
|
|
|
@ -6,4 +6,5 @@ var (
|
||||||
Version bool
|
Version bool
|
||||||
Password bool
|
Password bool
|
||||||
NoPrefix bool
|
NoPrefix bool
|
||||||
|
Dev bool
|
||||||
)
|
)
|
||||||
|
|
1
go.mod
1
go.mod
|
@ -21,6 +21,7 @@ require (
|
||||||
github.com/go-playground/universal-translator v0.18.0 // indirect
|
github.com/go-playground/universal-translator v0.18.0 // indirect
|
||||||
github.com/go-playground/validator/v10 v10.11.0 // indirect
|
github.com/go-playground/validator/v10 v10.11.0 // indirect
|
||||||
github.com/goccy/go-json v0.9.7 // indirect
|
github.com/goccy/go-json v0.9.7 // indirect
|
||||||
|
github.com/golang-jwt/jwt/v4 v4.4.2 // indirect
|
||||||
github.com/google/uuid v1.3.0 // indirect
|
github.com/google/uuid v1.3.0 // indirect
|
||||||
github.com/gorilla/websocket v1.5.0 // indirect
|
github.com/gorilla/websocket v1.5.0 // indirect
|
||||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||||
|
|
2
go.sum
2
go.sum
|
@ -25,6 +25,8 @@ github.com/go-playground/validator/v10 v10.11.0 h1:0W+xRM511GY47Yy3bZUbJVitCNg2B
|
||||||
github.com/go-playground/validator/v10 v10.11.0/go.mod h1:i+3WkQ1FvaUjjxh1kSvIA4dMGDBiPU55YFDl0WbKdWU=
|
github.com/go-playground/validator/v10 v10.11.0/go.mod h1:i+3WkQ1FvaUjjxh1kSvIA4dMGDBiPU55YFDl0WbKdWU=
|
||||||
github.com/goccy/go-json v0.9.7 h1:IcB+Aqpx/iMHu5Yooh7jEzJk1JZ7Pjtmys2ukPr7EeM=
|
github.com/goccy/go-json v0.9.7 h1:IcB+Aqpx/iMHu5Yooh7jEzJk1JZ7Pjtmys2ukPr7EeM=
|
||||||
github.com/goccy/go-json v0.9.7/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
|
github.com/goccy/go-json v0.9.7/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
|
||||||
|
github.com/golang-jwt/jwt/v4 v4.4.2 h1:rcc4lwaZgFMCZ5jxF9ABolDcIHdBytAFgqFPbSJQAYs=
|
||||||
|
github.com/golang-jwt/jwt/v4 v4.4.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
|
||||||
github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
|
github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
|
||||||
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
|
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
|
||||||
github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
|
github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
|
||||||
|
|
|
@ -1,5 +1,9 @@
|
||||||
package conf
|
package conf
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/alist-org/alist/v3/pkg/utils/random"
|
||||||
|
)
|
||||||
|
|
||||||
type Database struct {
|
type Database struct {
|
||||||
Type string `json:"type" env:"DB_TYPE"`
|
Type string `json:"type" env:"DB_TYPE"`
|
||||||
Host string `json:"host" env:"DB_HOST"`
|
Host string `json:"host" env:"DB_HOST"`
|
||||||
|
@ -30,6 +34,7 @@ type Config struct {
|
||||||
Force bool `json:"force"`
|
Force bool `json:"force"`
|
||||||
Address string `json:"address" env:"ADDR"`
|
Address string `json:"address" env:"ADDR"`
|
||||||
Port int `json:"port" env:"PORT"`
|
Port int `json:"port" env:"PORT"`
|
||||||
|
JwtSecret string `json:"jwt_secret" env:"JWT_SECRET"`
|
||||||
CaCheExpiration int `json:"cache_expiration" env:"CACHE_EXPIRATION"`
|
CaCheExpiration int `json:"cache_expiration" env:"CACHE_EXPIRATION"`
|
||||||
Assets string `json:"assets" env:"ASSETS"`
|
Assets string `json:"assets" env:"ASSETS"`
|
||||||
Database Database `json:"database"`
|
Database Database `json:"database"`
|
||||||
|
@ -40,10 +45,11 @@ type Config struct {
|
||||||
|
|
||||||
func DefaultConfig() *Config {
|
func DefaultConfig() *Config {
|
||||||
return &Config{
|
return &Config{
|
||||||
Address: "0.0.0.0",
|
Address: "0.0.0.0",
|
||||||
Port: 5244,
|
Port: 5244,
|
||||||
Assets: "https://npm.elemecdn.com/alist-web@$version/dist",
|
JwtSecret: random.RandomStr(16),
|
||||||
TempDir: "data/temp",
|
Assets: "https://npm.elemecdn.com/alist-web@$version/dist",
|
||||||
|
TempDir: "data/temp",
|
||||||
Database: Database{
|
Database: Database{
|
||||||
Type: "sqlite3",
|
Type: "sqlite3",
|
||||||
Port: 0,
|
Port: 0,
|
||||||
|
|
|
@ -1,5 +1,7 @@
|
||||||
package model
|
package model
|
||||||
|
|
||||||
|
import "github.com/pkg/errors"
|
||||||
|
|
||||||
const (
|
const (
|
||||||
GENERAL = iota
|
GENERAL = iota
|
||||||
GUEST // only one exists
|
GUEST // only one exists
|
||||||
|
@ -7,12 +9,12 @@ const (
|
||||||
)
|
)
|
||||||
|
|
||||||
type User struct {
|
type User struct {
|
||||||
ID uint `json:"id" gorm:"primaryKey"` // unique key
|
ID uint `json:"id" gorm:"primaryKey"` // unique key
|
||||||
Name string `json:"name" gorm:"unique"` // username
|
Username string `json:"username" gorm:"unique"` // username
|
||||||
Password string `json:"password"` // password
|
Password string `json:"password"` // password
|
||||||
BasePath string `json:"base_path"` // base path
|
BasePath string `json:"base_path"` // base path
|
||||||
ReadOnly bool `json:"read_only"` // allow upload
|
ReadOnly bool `json:"read_only"` // allow upload
|
||||||
Role int `json:"role"` // user's role
|
Role int `json:"role"` // user's role
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u User) IsGuest() bool {
|
func (u User) IsGuest() bool {
|
||||||
|
@ -22,3 +24,13 @@ func (u User) IsGuest() bool {
|
||||||
func (u User) IsAdmin() bool {
|
func (u User) IsAdmin() bool {
|
||||||
return u.Role == ADMIN
|
return u.Role == ADMIN
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (u User) ValidatePassword(password string) error {
|
||||||
|
if password == "" {
|
||||||
|
return errors.New("password is empty")
|
||||||
|
}
|
||||||
|
if u.Password != password {
|
||||||
|
return errors.New("password is incorrect")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
|
@ -0,0 +1,50 @@
|
||||||
|
package common
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/golang-jwt/jwt/v4"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
var SecretKey []byte
|
||||||
|
|
||||||
|
type UserClaims struct {
|
||||||
|
Username string `json:"username"`
|
||||||
|
jwt.RegisteredClaims
|
||||||
|
}
|
||||||
|
|
||||||
|
func GenerateToken(username string) (tokenString string, err error) {
|
||||||
|
claim := UserClaims{
|
||||||
|
Username: username,
|
||||||
|
RegisteredClaims: jwt.RegisteredClaims{
|
||||||
|
ExpiresAt: jwt.NewNumericDate(time.Now().Add(12 * time.Hour)),
|
||||||
|
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||||
|
NotBefore: jwt.NewNumericDate(time.Now()),
|
||||||
|
}}
|
||||||
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claim)
|
||||||
|
tokenString, err = token.SignedString(SecretKey)
|
||||||
|
return tokenString, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func ParseToken(tokenString string) (*UserClaims, error) {
|
||||||
|
token, err := jwt.ParseWithClaims(tokenString, &UserClaims{}, func(token *jwt.Token) (interface{}, error) {
|
||||||
|
return SecretKey, nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
if ve, ok := err.(*jwt.ValidationError); ok {
|
||||||
|
if ve.Errors&jwt.ValidationErrorMalformed != 0 {
|
||||||
|
return nil, errors.New("that's not even a token")
|
||||||
|
} else if ve.Errors&jwt.ValidationErrorExpired != 0 {
|
||||||
|
return nil, errors.New("token is expired")
|
||||||
|
} else if ve.Errors&jwt.ValidationErrorNotValidYet != 0 {
|
||||||
|
return nil, errors.New("token not active yet")
|
||||||
|
} else {
|
||||||
|
return nil, errors.New("couldn't handle this token")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if claims, ok := token.Claims.(*UserClaims); ok && token.Valid {
|
||||||
|
return claims, nil
|
||||||
|
}
|
||||||
|
return nil, errors.New("couldn't handle this token")
|
||||||
|
}
|
|
@ -0,0 +1,48 @@
|
||||||
|
package common
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Resp struct {
|
||||||
|
Code int `json:"code"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
Data interface{} `json:"data"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func ErrorResp(c *gin.Context, err error, code int) {
|
||||||
|
log.Error(err.Error())
|
||||||
|
c.JSON(200, Resp{
|
||||||
|
Code: code,
|
||||||
|
Message: err.Error(),
|
||||||
|
Data: nil,
|
||||||
|
})
|
||||||
|
c.Abort()
|
||||||
|
}
|
||||||
|
|
||||||
|
func ErrorStrResp(c *gin.Context, str string, code int) {
|
||||||
|
log.Error(str)
|
||||||
|
c.JSON(200, Resp{
|
||||||
|
Code: code,
|
||||||
|
Message: str,
|
||||||
|
Data: nil,
|
||||||
|
})
|
||||||
|
c.Abort()
|
||||||
|
}
|
||||||
|
|
||||||
|
func SuccessResp(c *gin.Context, data ...interface{}) {
|
||||||
|
if len(data) == 0 {
|
||||||
|
c.JSON(200, Resp{
|
||||||
|
Code: 200,
|
||||||
|
Message: "success",
|
||||||
|
Data: nil,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(200, Resp{
|
||||||
|
Code: 200,
|
||||||
|
Message: "success",
|
||||||
|
Data: data[0],
|
||||||
|
})
|
||||||
|
}
|
|
@ -0,0 +1,56 @@
|
||||||
|
package controllers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/Xhofe/go-cache"
|
||||||
|
"github.com/alist-org/alist/v3/internal/server/common"
|
||||||
|
"github.com/alist-org/alist/v3/internal/store"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
var loginCache = cache.NewMemCache[int]()
|
||||||
|
var (
|
||||||
|
defaultDuration = time.Minute * 5
|
||||||
|
defaultTimes = 5
|
||||||
|
)
|
||||||
|
|
||||||
|
type LoginReq struct {
|
||||||
|
Username string `json:"username"`
|
||||||
|
Password string `json:"password"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func Login(c *gin.Context) {
|
||||||
|
// check count of login
|
||||||
|
ip := c.ClientIP()
|
||||||
|
count, ok := loginCache.Get(ip)
|
||||||
|
if ok && count > defaultTimes {
|
||||||
|
common.ErrorStrResp(c, "Too many unsuccessful sign-in attempts have been made using an incorrect password. Try again later.", 403)
|
||||||
|
loginCache.Expire(ip, defaultDuration)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// check username
|
||||||
|
var req LoginReq
|
||||||
|
if err := c.ShouldBind(&req); err != nil {
|
||||||
|
common.ErrorResp(c, err, 400)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
user, err := store.GetUserByName(req.Username)
|
||||||
|
if err != nil {
|
||||||
|
common.ErrorResp(c, err, 400)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// validate password
|
||||||
|
if err := user.ValidatePassword(req.Password); err != nil {
|
||||||
|
common.ErrorResp(c, err, 400)
|
||||||
|
loginCache.Set(ip, count+1)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// generate token
|
||||||
|
token, err := common.GenerateToken(user.Username)
|
||||||
|
if err != nil {
|
||||||
|
common.ErrorResp(c, err, 400)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
common.SuccessResp(c, gin.H{"token": token})
|
||||||
|
loginCache.Del(ip)
|
||||||
|
}
|
|
@ -1 +0,0 @@
|
||||||
package controllers
|
|
|
@ -0,0 +1,25 @@
|
||||||
|
package middlewares
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/alist-org/alist/v3/internal/server/common"
|
||||||
|
"github.com/alist-org/alist/v3/internal/store"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
func AuthAdmin(c *gin.Context) {
|
||||||
|
token := c.GetHeader("Authorization")
|
||||||
|
userClaims, err := common.ParseToken(token)
|
||||||
|
if err != nil {
|
||||||
|
common.ErrorResp(c, err, 401)
|
||||||
|
c.Abort()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
user, err := store.GetUserByName(userClaims.Username)
|
||||||
|
if err != nil {
|
||||||
|
common.ErrorResp(c, err, 401)
|
||||||
|
c.Abort()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.Set("user", user)
|
||||||
|
c.Next()
|
||||||
|
}
|
|
@ -1,12 +1,19 @@
|
||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"github.com/alist-org/alist/v3/internal/conf"
|
||||||
|
"github.com/alist-org/alist/v3/internal/server/common"
|
||||||
|
"github.com/alist-org/alist/v3/internal/server/controllers"
|
||||||
"github.com/gin-contrib/cors"
|
"github.com/gin-contrib/cors"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Init(r *gin.Engine) {
|
func Init(r *gin.Engine) {
|
||||||
|
common.SecretKey = []byte(conf.Conf.JwtSecret)
|
||||||
Cors(r)
|
Cors(r)
|
||||||
|
|
||||||
|
api := r.Group("/api")
|
||||||
|
api.POST("/user/login", controllers.Login)
|
||||||
}
|
}
|
||||||
|
|
||||||
func Cors(r *gin.Engine) {
|
func Cors(r *gin.Engine) {
|
||||||
|
|
|
@ -18,17 +18,17 @@ func ExistGuest() bool {
|
||||||
return db.Take(&model.User{Role: model.GUEST}).Error != nil
|
return db.Take(&model.User{Role: model.GUEST}).Error != nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetUserByName(name string) (*model.User, error) {
|
func GetUserByName(username string) (*model.User, error) {
|
||||||
user, ok := userCache.Get(name)
|
user, ok := userCache.Get(username)
|
||||||
if ok {
|
if ok {
|
||||||
return user, nil
|
return user, nil
|
||||||
}
|
}
|
||||||
user, err, _ := userG.Do(name, func() (*model.User, error) {
|
user, err, _ := userG.Do(username, func() (*model.User, error) {
|
||||||
user := model.User{Name: name}
|
user := model.User{Username: username}
|
||||||
if err := db.Where(user).First(&user).Error; err != nil {
|
if err := db.Where(user).First(&user).Error; err != nil {
|
||||||
return nil, errors.Wrapf(err, "failed select user")
|
return nil, errors.Wrapf(err, "failed find user")
|
||||||
}
|
}
|
||||||
userCache.Set(name, &user)
|
userCache.Set(username, &user)
|
||||||
return &user, nil
|
return &user, nil
|
||||||
})
|
})
|
||||||
return user, err
|
return user, err
|
||||||
|
@ -51,7 +51,7 @@ func UpdateUser(u *model.User) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
userCache.Del(old.Name)
|
userCache.Del(old.Username)
|
||||||
return errors.WithStack(db.Save(u).Error)
|
return errors.WithStack(db.Save(u).Error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -73,6 +73,6 @@ func DeleteUserById(id uint) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
userCache.Del(old.Name)
|
userCache.Del(old.Username)
|
||||||
return errors.WithStack(db.Delete(&model.User{}, id).Error)
|
return errors.WithStack(db.Delete(&model.User{}, id).Error)
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
package utils
|
package random
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"math/rand"
|
"math/rand"
|
Loading…
Reference in New Issue