mirror of https://github.com/Xhofe/alist
feat(auth): Added device session management
- Added the `handleSession` function to manage user device sessions and verify client identity - Updated `auth.go` to call `handleSession` for device handling when a user logs in - Added the `Session` model to database migrations - Added `device.go` and `session.go` files to handle device session logic - Updated `settings.go` to add device-related configuration items, such as the maximum number of devices, device eviction policy, and session TTLpull/9286/head
parent
a9fcd51bc4
commit
bdbaa85213
|
@ -165,6 +165,9 @@ func InitialSettings() []model.SettingItem {
|
||||||
{Key: conf.ForwardDirectLinkParams, Value: "false", Type: conf.TypeBool, Group: model.GLOBAL},
|
{Key: conf.ForwardDirectLinkParams, Value: "false", Type: conf.TypeBool, Group: model.GLOBAL},
|
||||||
{Key: conf.IgnoreDirectLinkParams, Value: "sign,alist_ts", Type: conf.TypeString, Group: model.GLOBAL},
|
{Key: conf.IgnoreDirectLinkParams, Value: "sign,alist_ts", Type: conf.TypeString, Group: model.GLOBAL},
|
||||||
{Key: conf.WebauthnLoginEnabled, Value: "false", Type: conf.TypeBool, Group: model.GLOBAL, Flag: model.PUBLIC},
|
{Key: conf.WebauthnLoginEnabled, Value: "false", Type: conf.TypeBool, Group: model.GLOBAL, Flag: model.PUBLIC},
|
||||||
|
{Key: conf.MaxDevices, Value: "0", Type: conf.TypeNumber, Group: model.GLOBAL},
|
||||||
|
{Key: conf.DeviceEvictPolicy, Value: "deny", Type: conf.TypeSelect, Options: "deny,evict_oldest", Group: model.GLOBAL},
|
||||||
|
{Key: conf.DeviceSessionTTL, Value: "86400", Type: conf.TypeNumber, Group: model.GLOBAL},
|
||||||
|
|
||||||
// single settings
|
// single settings
|
||||||
{Key: conf.Token, Value: token, Type: conf.TypeString, Group: model.SINGLE, Flag: model.PRIVATE},
|
{Key: conf.Token, Value: token, Type: conf.TypeString, Group: model.SINGLE, Flag: model.PRIVATE},
|
||||||
|
|
|
@ -48,6 +48,9 @@ const (
|
||||||
ForwardDirectLinkParams = "forward_direct_link_params"
|
ForwardDirectLinkParams = "forward_direct_link_params"
|
||||||
IgnoreDirectLinkParams = "ignore_direct_link_params"
|
IgnoreDirectLinkParams = "ignore_direct_link_params"
|
||||||
WebauthnLoginEnabled = "webauthn_login_enabled"
|
WebauthnLoginEnabled = "webauthn_login_enabled"
|
||||||
|
MaxDevices = "max_devices"
|
||||||
|
DeviceEvictPolicy = "device_evict_policy"
|
||||||
|
DeviceSessionTTL = "device_session_ttl"
|
||||||
|
|
||||||
// index
|
// index
|
||||||
SearchIndex = "search_index"
|
SearchIndex = "search_index"
|
||||||
|
|
|
@ -12,7 +12,7 @@ var db *gorm.DB
|
||||||
|
|
||||||
func Init(d *gorm.DB) {
|
func Init(d *gorm.DB) {
|
||||||
db = d
|
db = d
|
||||||
err := AutoMigrate(new(model.Storage), new(model.User), new(model.Meta), new(model.SettingItem), new(model.SearchNode), new(model.TaskItem), new(model.SSHPublicKey), new(model.Role), new(model.Label), new(model.LabelFileBinding), new(model.ObjFile))
|
err := AutoMigrate(new(model.Storage), new(model.User), new(model.Meta), new(model.SettingItem), new(model.SearchNode), new(model.TaskItem), new(model.SSHPublicKey), new(model.Role), new(model.Label), new(model.LabelFileBinding), new(model.ObjFile), new(model.Session))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("failed migrate database: %s", err.Error())
|
log.Fatalf("failed migrate database: %s", err.Error())
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,49 @@
|
||||||
|
package db
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/alist-org/alist/v3/internal/model"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"gorm.io/gorm/clause"
|
||||||
|
)
|
||||||
|
|
||||||
|
func GetSession(userID uint, deviceKey string) (*model.Session, error) {
|
||||||
|
s := model.Session{UserID: userID, DeviceKey: deviceKey}
|
||||||
|
if err := db.Where(&s).First(&s).Error; err != nil {
|
||||||
|
return nil, errors.Wrap(err, "failed find session")
|
||||||
|
}
|
||||||
|
return &s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func CreateSession(s *model.Session) error {
|
||||||
|
return errors.WithStack(db.Create(s).Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func UpsertSession(s *model.Session) error {
|
||||||
|
return errors.WithStack(db.Clauses(clause.OnConflict{UpdateAll: true}).Create(s).Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func DeleteSession(userID uint, deviceKey string) error {
|
||||||
|
return errors.WithStack(db.Where("user_id = ? AND device_key = ?", userID, deviceKey).Delete(&model.Session{}).Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func CountSessionsByUser(userID uint) (int64, error) {
|
||||||
|
var count int64
|
||||||
|
err := db.Model(&model.Session{}).Where("user_id = ?", userID).Count(&count).Error
|
||||||
|
return count, errors.WithStack(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func DeleteSessionsBefore(ts int64) error {
|
||||||
|
return errors.WithStack(db.Where("last_active < ?", ts).Delete(&model.Session{}).Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetOldestSession(userID uint) (*model.Session, error) {
|
||||||
|
var s model.Session
|
||||||
|
if err := db.Where("user_id = ?", userID).Order("last_active ASC").First(&s).Error; err != nil {
|
||||||
|
return nil, errors.Wrap(err, "failed get oldest session")
|
||||||
|
}
|
||||||
|
return &s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func UpdateSessionLastActive(userID uint, deviceKey string, lastActive int64) error {
|
||||||
|
return errors.WithStack(db.Model(&model.Session{}).Where("user_id = ? AND device_key = ?", userID, deviceKey).Update("last_active", lastActive).Error)
|
||||||
|
}
|
|
@ -0,0 +1,59 @@
|
||||||
|
package device
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/alist-org/alist/v3/internal/conf"
|
||||||
|
"github.com/alist-org/alist/v3/internal/db"
|
||||||
|
"github.com/alist-org/alist/v3/internal/errs"
|
||||||
|
"github.com/alist-org/alist/v3/internal/model"
|
||||||
|
"github.com/alist-org/alist/v3/internal/setting"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Handle verifies device sessions for a user and upserts current session.
|
||||||
|
func Handle(userID uint, deviceKey string) error {
|
||||||
|
ttl := setting.GetInt(conf.DeviceSessionTTL, 86400)
|
||||||
|
if ttl > 0 {
|
||||||
|
_ = db.DeleteSessionsBefore(time.Now().Unix() - int64(ttl))
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now().Unix()
|
||||||
|
sess, err := db.GetSession(userID, deviceKey)
|
||||||
|
if err == nil {
|
||||||
|
sess.LastActive = now
|
||||||
|
sess.Status = model.SessionActive
|
||||||
|
return db.UpsertSession(sess)
|
||||||
|
}
|
||||||
|
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
max := setting.GetInt(conf.MaxDevices, 0)
|
||||||
|
if max > 0 {
|
||||||
|
count, err := db.CountSessionsByUser(userID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if count >= int64(max) {
|
||||||
|
policy := setting.GetStr(conf.DeviceEvictPolicy, "deny")
|
||||||
|
if policy == "evict_oldest" {
|
||||||
|
oldest, err := db.GetOldestSession(userID)
|
||||||
|
if err == nil {
|
||||||
|
_ = db.DeleteSession(userID, oldest.DeviceKey)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return errors.WithStack(errs.TooManyDevices)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
s := &model.Session{UserID: userID, DeviceKey: deviceKey, LastActive: now, Status: model.SessionActive}
|
||||||
|
return db.CreateSession(s)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Refresh updates last_active for the session.
|
||||||
|
func Refresh(userID uint, deviceKey string) {
|
||||||
|
_ = db.UpdateSessionLastActive(userID, deviceKey, time.Now().Unix())
|
||||||
|
}
|
|
@ -0,0 +1,7 @@
|
||||||
|
package errs
|
||||||
|
|
||||||
|
import "errors"
|
||||||
|
|
||||||
|
var (
|
||||||
|
TooManyDevices = errors.New("too many active devices")
|
||||||
|
)
|
|
@ -0,0 +1,14 @@
|
||||||
|
package model
|
||||||
|
|
||||||
|
// Session represents a device session of a user.
|
||||||
|
type Session struct {
|
||||||
|
UserID uint `json:"user_id" gorm:"index"`
|
||||||
|
DeviceKey string `json:"device_key" gorm:"primaryKey;size:64"`
|
||||||
|
LastActive int64 `json:"last_active"`
|
||||||
|
Status int `json:"status"`
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
SessionActive = iota
|
||||||
|
SessionInactive
|
||||||
|
)
|
|
@ -5,9 +5,11 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/alist-org/alist/v3/internal/conf"
|
"github.com/alist-org/alist/v3/internal/conf"
|
||||||
|
"github.com/alist-org/alist/v3/internal/device"
|
||||||
"github.com/alist-org/alist/v3/internal/model"
|
"github.com/alist-org/alist/v3/internal/model"
|
||||||
"github.com/alist-org/alist/v3/internal/op"
|
"github.com/alist-org/alist/v3/internal/op"
|
||||||
"github.com/alist-org/alist/v3/internal/setting"
|
"github.com/alist-org/alist/v3/internal/setting"
|
||||||
|
"github.com/alist-org/alist/v3/pkg/utils"
|
||||||
"github.com/alist-org/alist/v3/server/common"
|
"github.com/alist-org/alist/v3/server/common"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
@ -24,7 +26,9 @@ func Auth(c *gin.Context) {
|
||||||
c.Abort()
|
c.Abort()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.Set("user", admin)
|
if !handleSession(c, admin) {
|
||||||
|
return
|
||||||
|
}
|
||||||
log.Debugf("use admin token: %+v", admin)
|
log.Debugf("use admin token: %+v", admin)
|
||||||
c.Next()
|
c.Next()
|
||||||
return
|
return
|
||||||
|
@ -50,7 +54,9 @@ func Auth(c *gin.Context) {
|
||||||
}
|
}
|
||||||
guest.RolesDetail = roles
|
guest.RolesDetail = roles
|
||||||
}
|
}
|
||||||
c.Set("user", guest)
|
if !handleSession(c, guest) {
|
||||||
|
return
|
||||||
|
}
|
||||||
log.Debugf("use empty token: %+v", guest)
|
log.Debugf("use empty token: %+v", guest)
|
||||||
c.Next()
|
c.Next()
|
||||||
return
|
return
|
||||||
|
@ -87,11 +93,29 @@ func Auth(c *gin.Context) {
|
||||||
}
|
}
|
||||||
user.RolesDetail = roles
|
user.RolesDetail = roles
|
||||||
}
|
}
|
||||||
c.Set("user", user)
|
if !handleSession(c, user) {
|
||||||
|
return
|
||||||
|
}
|
||||||
log.Debugf("use login token: %+v", user)
|
log.Debugf("use login token: %+v", user)
|
||||||
c.Next()
|
c.Next()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func handleSession(c *gin.Context, user *model.User) bool {
|
||||||
|
clientID := c.GetHeader("Client-Id")
|
||||||
|
if clientID == "" {
|
||||||
|
clientID = c.Query("client_id")
|
||||||
|
}
|
||||||
|
key := utils.GetMD5EncodeStr(fmt.Sprintf("%d-%s-%s-%s", user.ID, c.Request.UserAgent(), c.ClientIP(), clientID))
|
||||||
|
if err := device.Handle(user.ID, key); err != nil {
|
||||||
|
common.ErrorResp(c, err, 403)
|
||||||
|
c.Abort()
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
c.Set("device_key", key)
|
||||||
|
c.Set("user", user)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
func Authn(c *gin.Context) {
|
func Authn(c *gin.Context) {
|
||||||
token := c.GetHeader("Authorization")
|
token := c.GetHeader("Authorization")
|
||||||
if subtle.ConstantTimeCompare([]byte(token), []byte(setting.GetStr(conf.Token))) == 1 {
|
if subtle.ConstantTimeCompare([]byte(token), []byte(setting.GetStr(conf.Token))) == 1 {
|
||||||
|
|
|
@ -0,0 +1,21 @@
|
||||||
|
package middlewares
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/alist-org/alist/v3/internal/device"
|
||||||
|
"github.com/alist-org/alist/v3/internal/model"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SessionRefresh updates session's last_active after successful requests.
|
||||||
|
func SessionRefresh(c *gin.Context) {
|
||||||
|
c.Next()
|
||||||
|
if c.Writer.Status() >= 400 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
userVal, uok := c.Get("user")
|
||||||
|
keyVal, kok := c.Get("device_key")
|
||||||
|
if uok && kok {
|
||||||
|
user := userVal.(*model.User)
|
||||||
|
device.Refresh(user.ID, keyVal.(string))
|
||||||
|
}
|
||||||
|
}
|
|
@ -22,6 +22,7 @@ func Init(e *gin.Engine) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
Cors(e)
|
Cors(e)
|
||||||
|
e.Use(middlewares.SessionRefresh)
|
||||||
g := e.Group(conf.URL.Path)
|
g := e.Group(conf.URL.Path)
|
||||||
if conf.Conf.Scheme.HttpPort != -1 && conf.Conf.Scheme.HttpsPort != -1 && conf.Conf.Scheme.ForceHttps {
|
if conf.Conf.Scheme.HttpPort != -1 && conf.Conf.Scheme.HttpsPort != -1 && conf.Conf.Scheme.ForceHttps {
|
||||||
e.Use(middlewares.ForceHttps)
|
e.Use(middlewares.ForceHttps)
|
||||||
|
|
|
@ -3,6 +3,7 @@ package server
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/subtle"
|
"crypto/subtle"
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"path"
|
"path"
|
||||||
|
@ -12,9 +13,11 @@ import (
|
||||||
"github.com/alist-org/alist/v3/server/middlewares"
|
"github.com/alist-org/alist/v3/server/middlewares"
|
||||||
|
|
||||||
"github.com/alist-org/alist/v3/internal/conf"
|
"github.com/alist-org/alist/v3/internal/conf"
|
||||||
|
"github.com/alist-org/alist/v3/internal/device"
|
||||||
"github.com/alist-org/alist/v3/internal/model"
|
"github.com/alist-org/alist/v3/internal/model"
|
||||||
"github.com/alist-org/alist/v3/internal/op"
|
"github.com/alist-org/alist/v3/internal/op"
|
||||||
"github.com/alist-org/alist/v3/internal/setting"
|
"github.com/alist-org/alist/v3/internal/setting"
|
||||||
|
"github.com/alist-org/alist/v3/pkg/utils"
|
||||||
"github.com/alist-org/alist/v3/server/common"
|
"github.com/alist-org/alist/v3/server/common"
|
||||||
"github.com/alist-org/alist/v3/server/webdav"
|
"github.com/alist-org/alist/v3/server/webdav"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
@ -69,6 +72,13 @@ func WebDAVAuth(c *gin.Context) {
|
||||||
c.Abort()
|
c.Abort()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
key := utils.GetMD5EncodeStr(fmt.Sprintf("%d-%s", admin.ID, c.ClientIP()))
|
||||||
|
if err := device.Handle(admin.ID, key); err != nil {
|
||||||
|
c.Status(http.StatusForbidden)
|
||||||
|
c.Abort()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.Set("device_key", key)
|
||||||
c.Set("user", admin)
|
c.Set("user", admin)
|
||||||
c.Next()
|
c.Next()
|
||||||
return
|
return
|
||||||
|
@ -146,6 +156,13 @@ func WebDAVAuth(c *gin.Context) {
|
||||||
c.Abort()
|
c.Abort()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
key := utils.GetMD5EncodeStr(fmt.Sprintf("%d-%s", user.ID, c.ClientIP()))
|
||||||
|
if err := device.Handle(user.ID, key); err != nil {
|
||||||
|
c.Status(http.StatusForbidden)
|
||||||
|
c.Abort()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.Set("device_key", key)
|
||||||
c.Set("user", user)
|
c.Set("user", user)
|
||||||
c.Next()
|
c.Next()
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue