From c64f899a636b66d12056c6ae3fcab6c21b0cb146 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=8D=83=E7=9F=B3?= Date: Mon, 25 Aug 2025 19:46:38 +0800 Subject: [PATCH] feat: implement session management (#9286) * feat(auth): Added device session management - Added the `handleSession` function to manage user device sessions and verify client identity - Updated `auth.go` to call `handleSession` for device handling when a user logs in - Added the `Session` model to database migrations - Added `device.go` and `session.go` files to handle device session logic - Updated `settings.go` to add device-related configuration items, such as the maximum number of devices, device eviction policy, and session TTL * feat(session): Adds session management features - Added `SessionInactive` error type in `device.go` - Added session-related APIs in `router.go` to support listing and evicting sessions - Added `ListSessionsByUser`, `ListSessions`, and `MarkInactive` methods in `session.go` - Returns an appropriate error when the session state is `SessionInactive` * feat(auth): Marks the device session as invalid. - Import the `session` package into the `auth` module to handle device session status. - Add a check in the login logic. If `device_key` is obtained, call `session.MarkInactive` to mark the device session as invalid. - Store the invalid status in the context variable `session_inactive` for subsequent middleware checks. - Add a check in the session refresh logic to abort the process if the current session has been marked invalid. * feat(auth, session): Added device information processing and session management changes - Updated device handling logic in `auth.go` to pass user agent and IP information - Adjusted database queries in `session.go` to optimize session query fields and add `user_agent` and `ip` fields - Modified the `Handle` method to add `ua` and `ip` parameters to store the user agent and IP address - Added the `SessionResp` structure to return a session response containing `user_agent` and `ip` - Updated the `/admin/user/create` and `/webdav` endpoints to pass the user agent and IP address to the device handler --- internal/bootstrap/data/setting.go | 3 + internal/conf/const.go | 3 + internal/db/db.go | 2 +- internal/db/session.go | 65 +++++++++++++++++++ internal/device/session.go | 67 +++++++++++++++++++ internal/errs/device.go | 8 +++ internal/model/session.go | 16 +++++ internal/session/session.go | 8 +++ pkg/utils/mask.go | 30 +++++++++ server/handles/auth.go | 8 +++ server/handles/session.go | 92 +++++++++++++++++++++++++++ server/middlewares/auth.go | 30 ++++++++- server/middlewares/session_refresh.go | 26 ++++++++ server/router.go | 7 ++ server/webdav.go | 17 +++++ 15 files changed, 378 insertions(+), 4 deletions(-) create mode 100644 internal/db/session.go create mode 100644 internal/device/session.go create mode 100644 internal/errs/device.go create mode 100644 internal/model/session.go create mode 100644 internal/session/session.go create mode 100644 pkg/utils/mask.go create mode 100644 server/handles/session.go create mode 100644 server/middlewares/session_refresh.go diff --git a/internal/bootstrap/data/setting.go b/internal/bootstrap/data/setting.go index 17a63af2..bbb633e3 100644 --- a/internal/bootstrap/data/setting.go +++ b/internal/bootstrap/data/setting.go @@ -165,6 +165,9 @@ func InitialSettings() []model.SettingItem { {Key: conf.ForwardDirectLinkParams, Value: "false", Type: conf.TypeBool, Group: model.GLOBAL}, {Key: conf.IgnoreDirectLinkParams, Value: "sign,alist_ts", Type: conf.TypeString, Group: model.GLOBAL}, {Key: conf.WebauthnLoginEnabled, Value: "false", Type: conf.TypeBool, Group: model.GLOBAL, Flag: model.PUBLIC}, + {Key: conf.MaxDevices, Value: "0", Type: conf.TypeNumber, Group: model.GLOBAL}, + {Key: conf.DeviceEvictPolicy, Value: "deny", Type: conf.TypeSelect, Options: "deny,evict_oldest", Group: model.GLOBAL}, + {Key: conf.DeviceSessionTTL, Value: "86400", Type: conf.TypeNumber, Group: model.GLOBAL}, // single settings {Key: conf.Token, Value: token, Type: conf.TypeString, Group: model.SINGLE, Flag: model.PRIVATE}, diff --git a/internal/conf/const.go b/internal/conf/const.go index 0bf0cd67..1a558163 100644 --- a/internal/conf/const.go +++ b/internal/conf/const.go @@ -48,6 +48,9 @@ const ( ForwardDirectLinkParams = "forward_direct_link_params" IgnoreDirectLinkParams = "ignore_direct_link_params" WebauthnLoginEnabled = "webauthn_login_enabled" + MaxDevices = "max_devices" + DeviceEvictPolicy = "device_evict_policy" + DeviceSessionTTL = "device_session_ttl" // index SearchIndex = "search_index" diff --git a/internal/db/db.go b/internal/db/db.go index 0d8ab421..4577059d 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -12,7 +12,7 @@ var db *gorm.DB func Init(d *gorm.DB) { db = d - err := AutoMigrate(new(model.Storage), new(model.User), new(model.Meta), new(model.SettingItem), new(model.SearchNode), new(model.TaskItem), new(model.SSHPublicKey), new(model.Role), new(model.Label), new(model.LabelFileBinding), new(model.ObjFile)) + err := AutoMigrate(new(model.Storage), new(model.User), new(model.Meta), new(model.SettingItem), new(model.SearchNode), new(model.TaskItem), new(model.SSHPublicKey), new(model.Role), new(model.Label), new(model.LabelFileBinding), new(model.ObjFile), new(model.Session)) if err != nil { log.Fatalf("failed migrate database: %s", err.Error()) } diff --git a/internal/db/session.go b/internal/db/session.go new file mode 100644 index 00000000..8db9fa69 --- /dev/null +++ b/internal/db/session.go @@ -0,0 +1,65 @@ +package db + +import ( + "github.com/alist-org/alist/v3/internal/model" + "github.com/pkg/errors" + "gorm.io/gorm/clause" +) + +func GetSession(userID uint, deviceKey string) (*model.Session, error) { + s := model.Session{UserID: userID, DeviceKey: deviceKey} + if err := db.Select("user_id, device_key, last_active, status, user_agent, ip").Where(&s).First(&s).Error; err != nil { + return nil, errors.Wrap(err, "failed find session") + } + return &s, nil +} + +func CreateSession(s *model.Session) error { + return errors.WithStack(db.Create(s).Error) +} + +func UpsertSession(s *model.Session) error { + return errors.WithStack(db.Clauses(clause.OnConflict{UpdateAll: true}).Create(s).Error) +} + +func DeleteSession(userID uint, deviceKey string) error { + return errors.WithStack(db.Where("user_id = ? AND device_key = ?", userID, deviceKey).Delete(&model.Session{}).Error) +} + +func CountSessionsByUser(userID uint) (int64, error) { + var count int64 + err := db.Model(&model.Session{}).Where("user_id = ?", userID).Count(&count).Error + return count, errors.WithStack(err) +} + +func DeleteSessionsBefore(ts int64) error { + return errors.WithStack(db.Where("last_active < ?", ts).Delete(&model.Session{}).Error) +} + +func GetOldestSession(userID uint) (*model.Session, error) { + var s model.Session + if err := db.Where("user_id = ?", userID).Order("last_active ASC").First(&s).Error; err != nil { + return nil, errors.Wrap(err, "failed get oldest session") + } + return &s, nil +} + +func UpdateSessionLastActive(userID uint, deviceKey string, lastActive int64) error { + return errors.WithStack(db.Model(&model.Session{}).Where("user_id = ? AND device_key = ?", userID, deviceKey).Update("last_active", lastActive).Error) +} + +func ListSessionsByUser(userID uint) ([]model.Session, error) { + var sessions []model.Session + err := db.Select("user_id, device_key, last_active, status, user_agent, ip").Where("user_id = ? AND status = ?", userID, model.SessionActive).Find(&sessions).Error + return sessions, errors.WithStack(err) +} + +func ListSessions() ([]model.Session, error) { + var sessions []model.Session + err := db.Select("user_id, device_key, last_active, status, user_agent, ip").Where("status = ?", model.SessionActive).Find(&sessions).Error + return sessions, errors.WithStack(err) +} + +func MarkInactive(sessionID string) error { + return errors.WithStack(db.Model(&model.Session{}).Where("device_key = ?", sessionID).Update("status", model.SessionInactive).Error) +} diff --git a/internal/device/session.go b/internal/device/session.go new file mode 100644 index 00000000..d407c858 --- /dev/null +++ b/internal/device/session.go @@ -0,0 +1,67 @@ +package device + +import ( + "time" + + "github.com/alist-org/alist/v3/internal/conf" + "github.com/alist-org/alist/v3/internal/db" + "github.com/alist-org/alist/v3/internal/errs" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/internal/setting" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/pkg/errors" + "gorm.io/gorm" +) + +// Handle verifies device sessions for a user and upserts current session. +func Handle(userID uint, deviceKey, ua, ip string) error { + ttl := setting.GetInt(conf.DeviceSessionTTL, 86400) + if ttl > 0 { + _ = db.DeleteSessionsBefore(time.Now().Unix() - int64(ttl)) + } + + ip = utils.MaskIP(ip) + + now := time.Now().Unix() + sess, err := db.GetSession(userID, deviceKey) + if err == nil { + if sess.Status == model.SessionInactive { + return errors.WithStack(errs.SessionInactive) + } + sess.LastActive = now + sess.Status = model.SessionActive + sess.UserAgent = ua + sess.IP = ip + return db.UpsertSession(sess) + } + if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { + return err + } + + max := setting.GetInt(conf.MaxDevices, 0) + if max > 0 { + count, err := db.CountSessionsByUser(userID) + if err != nil { + return err + } + if count >= int64(max) { + policy := setting.GetStr(conf.DeviceEvictPolicy, "deny") + if policy == "evict_oldest" { + oldest, err := db.GetOldestSession(userID) + if err == nil { + _ = db.DeleteSession(userID, oldest.DeviceKey) + } + } else { + return errors.WithStack(errs.TooManyDevices) + } + } + } + + s := &model.Session{UserID: userID, DeviceKey: deviceKey, UserAgent: ua, IP: ip, LastActive: now, Status: model.SessionActive} + return db.CreateSession(s) +} + +// Refresh updates last_active for the session. +func Refresh(userID uint, deviceKey string) { + _ = db.UpdateSessionLastActive(userID, deviceKey, time.Now().Unix()) +} diff --git a/internal/errs/device.go b/internal/errs/device.go new file mode 100644 index 00000000..3b79298a --- /dev/null +++ b/internal/errs/device.go @@ -0,0 +1,8 @@ +package errs + +import "errors" + +var ( + TooManyDevices = errors.New("too many active devices") + SessionInactive = errors.New("session inactive") +) diff --git a/internal/model/session.go b/internal/model/session.go new file mode 100644 index 00000000..3cb6d0da --- /dev/null +++ b/internal/model/session.go @@ -0,0 +1,16 @@ +package model + +// Session represents a device session of a user. +type Session struct { + UserID uint `json:"user_id" gorm:"index"` + DeviceKey string `json:"device_key" gorm:"primaryKey;size:64"` + UserAgent string `json:"user_agent" gorm:"size:255"` + IP string `json:"ip" gorm:"size:64"` + LastActive int64 `json:"last_active"` + Status int `json:"status"` +} + +const ( + SessionActive = iota + SessionInactive +) diff --git a/internal/session/session.go b/internal/session/session.go new file mode 100644 index 00000000..47d1b701 --- /dev/null +++ b/internal/session/session.go @@ -0,0 +1,8 @@ +package session + +import "github.com/alist-org/alist/v3/internal/db" + +// MarkInactive marks the session with the given ID as inactive. +func MarkInactive(sessionID string) error { + return db.MarkInactive(sessionID) +} diff --git a/pkg/utils/mask.go b/pkg/utils/mask.go new file mode 100644 index 00000000..1513ad40 --- /dev/null +++ b/pkg/utils/mask.go @@ -0,0 +1,30 @@ +package utils + +import "strings" + +// MaskIP anonymizes middle segments of an IP address. +func MaskIP(ip string) string { + if ip == "" { + return "" + } + if strings.Contains(ip, ":") { + parts := strings.Split(ip, ":") + if len(parts) > 2 { + for i := 1; i < len(parts)-1; i++ { + if parts[i] != "" { + parts[i] = "*" + } + } + return strings.Join(parts, ":") + } + return ip + } + parts := strings.Split(ip, ".") + if len(parts) == 4 { + for i := 1; i < len(parts)-1; i++ { + parts[i] = "*" + } + return strings.Join(parts, ".") + } + return ip +} diff --git a/server/handles/auth.go b/server/handles/auth.go index 26447ddd..30714f65 100644 --- a/server/handles/auth.go +++ b/server/handles/auth.go @@ -12,6 +12,7 @@ import ( "github.com/alist-org/alist/v3/internal/conf" "github.com/alist-org/alist/v3/internal/model" "github.com/alist-org/alist/v3/internal/op" + "github.com/alist-org/alist/v3/internal/session" "github.com/alist-org/alist/v3/internal/setting" "github.com/alist-org/alist/v3/server/common" "github.com/gin-gonic/gin" @@ -247,6 +248,13 @@ func Verify2FA(c *gin.Context) { } func LogOut(c *gin.Context) { + if keyVal, ok := c.Get("device_key"); ok { + if err := session.MarkInactive(keyVal.(string)); err != nil { + common.ErrorResp(c, err, 500) + return + } + c.Set("session_inactive", true) + } err := common.InvalidateToken(c.GetHeader("Authorization")) if err != nil { common.ErrorResp(c, err, 500) diff --git a/server/handles/session.go b/server/handles/session.go new file mode 100644 index 00000000..886be66a --- /dev/null +++ b/server/handles/session.go @@ -0,0 +1,92 @@ +package handles + +import ( + "github.com/alist-org/alist/v3/internal/db" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/server/common" + "github.com/gin-gonic/gin" +) + +type SessionResp struct { + SessionID string `json:"session_id"` + UserID uint `json:"user_id,omitempty"` + LastActive int64 `json:"last_active"` + Status int `json:"status"` + UA string `json:"ua"` + IP string `json:"ip"` +} + +func ListMySessions(c *gin.Context) { + user := c.MustGet("user").(*model.User) + sessions, err := db.ListSessionsByUser(user.ID) + if err != nil { + common.ErrorResp(c, err, 500) + return + } + resp := make([]SessionResp, len(sessions)) + for i, s := range sessions { + resp[i] = SessionResp{ + SessionID: s.DeviceKey, + LastActive: s.LastActive, + Status: s.Status, + UA: s.UserAgent, + IP: s.IP, + } + } + common.SuccessResp(c, resp) +} + +type EvictSessionReq struct { + SessionID string `json:"session_id"` +} + +func EvictMySession(c *gin.Context) { + var req EvictSessionReq + if err := c.ShouldBindJSON(&req); err != nil { + common.ErrorResp(c, err, 400) + return + } + user := c.MustGet("user").(*model.User) + if _, err := db.GetSession(user.ID, req.SessionID); err != nil { + common.ErrorResp(c, err, 400) + return + } + if err := db.MarkInactive(req.SessionID); err != nil { + common.ErrorResp(c, err, 500) + return + } + common.SuccessResp(c) +} + +func ListSessions(c *gin.Context) { + sessions, err := db.ListSessions() + if err != nil { + common.ErrorResp(c, err, 500) + return + } + resp := make([]SessionResp, len(sessions)) + for i, s := range sessions { + resp[i] = SessionResp{ + SessionID: s.DeviceKey, + UserID: s.UserID, + LastActive: s.LastActive, + Status: s.Status, + UA: s.UserAgent, + IP: s.IP, + } + } + common.SuccessResp(c, resp) +} + +func EvictSession(c *gin.Context) { + var req EvictSessionReq + if err := c.ShouldBindJSON(&req); err != nil { + common.ErrorResp(c, err, 400) + return + } + if err := db.MarkInactive(req.SessionID); err != nil { + common.ErrorResp(c, err, 500) + return + } + common.SuccessResp(c) +} diff --git a/server/middlewares/auth.go b/server/middlewares/auth.go index c0743c9c..72eaefe6 100644 --- a/server/middlewares/auth.go +++ b/server/middlewares/auth.go @@ -5,9 +5,11 @@ import ( "fmt" "github.com/alist-org/alist/v3/internal/conf" + "github.com/alist-org/alist/v3/internal/device" "github.com/alist-org/alist/v3/internal/model" "github.com/alist-org/alist/v3/internal/op" "github.com/alist-org/alist/v3/internal/setting" + "github.com/alist-org/alist/v3/pkg/utils" "github.com/alist-org/alist/v3/server/common" "github.com/gin-gonic/gin" log "github.com/sirupsen/logrus" @@ -24,7 +26,9 @@ func Auth(c *gin.Context) { c.Abort() return } - c.Set("user", admin) + if !handleSession(c, admin) { + return + } log.Debugf("use admin token: %+v", admin) c.Next() return @@ -50,7 +54,9 @@ func Auth(c *gin.Context) { } guest.RolesDetail = roles } - c.Set("user", guest) + if !handleSession(c, guest) { + return + } log.Debugf("use empty token: %+v", guest) c.Next() return @@ -87,11 +93,29 @@ func Auth(c *gin.Context) { } user.RolesDetail = roles } - c.Set("user", user) + if !handleSession(c, user) { + return + } log.Debugf("use login token: %+v", user) c.Next() } +func handleSession(c *gin.Context, user *model.User) bool { + clientID := c.GetHeader("Client-Id") + if clientID == "" { + clientID = c.Query("client_id") + } + key := utils.GetMD5EncodeStr(fmt.Sprintf("%d-%s-%s-%s", user.ID, c.Request.UserAgent(), c.ClientIP(), clientID)) + if err := device.Handle(user.ID, key, c.Request.UserAgent(), c.ClientIP()); err != nil { + common.ErrorResp(c, err, 403) + c.Abort() + return false + } + c.Set("device_key", key) + c.Set("user", user) + return true +} + func Authn(c *gin.Context) { token := c.GetHeader("Authorization") if subtle.ConstantTimeCompare([]byte(token), []byte(setting.GetStr(conf.Token))) == 1 { diff --git a/server/middlewares/session_refresh.go b/server/middlewares/session_refresh.go new file mode 100644 index 00000000..2073020c --- /dev/null +++ b/server/middlewares/session_refresh.go @@ -0,0 +1,26 @@ +package middlewares + +import ( + "github.com/alist-org/alist/v3/internal/device" + "github.com/alist-org/alist/v3/internal/model" + "github.com/gin-gonic/gin" +) + +// SessionRefresh updates session's last_active after successful requests. +func SessionRefresh(c *gin.Context) { + c.Next() + if c.Writer.Status() >= 400 { + return + } + if inactive, ok := c.Get("session_inactive"); ok { + if b, ok := inactive.(bool); ok && b { + return + } + } + userVal, uok := c.Get("user") + keyVal, kok := c.Get("device_key") + if uok && kok { + user := userVal.(*model.User) + device.Refresh(user.ID, keyVal.(string)) + } +} diff --git a/server/router.go b/server/router.go index e8902f7a..4d79c1fd 100644 --- a/server/router.go +++ b/server/router.go @@ -22,6 +22,7 @@ func Init(e *gin.Engine) { }) } Cors(e) + e.Use(middlewares.SessionRefresh) g := e.Group(conf.URL.Path) if conf.Conf.Scheme.HttpPort != -1 && conf.Conf.Scheme.HttpsPort != -1 && conf.Conf.Scheme.ForceHttps { e.Use(middlewares.ForceHttps) @@ -70,6 +71,8 @@ func Init(e *gin.Engine) { auth.POST("/auth/2fa/generate", handles.Generate2FA) auth.POST("/auth/2fa/verify", handles.Verify2FA) auth.GET("/auth/logout", handles.LogOut) + auth.GET("/me/sessions", handles.ListMySessions) + auth.POST("/me/sessions/evict", handles.EvictMySession) // auth api.GET("/auth/sso", handles.SSOLoginRedirect) @@ -184,6 +187,10 @@ func admin(g *gin.RouterGroup) { labelFileBinding.POST("/delete", handles.DelLabelByFileName) labelFileBinding.POST("/restore", handles.RestoreLabelFileBinding) + session := g.Group("/session") + session.GET("/list", handles.ListSessions) + session.POST("/evict", handles.EvictSession) + } func _fs(g *gin.RouterGroup) { diff --git a/server/webdav.go b/server/webdav.go index 582c469d..e0980139 100644 --- a/server/webdav.go +++ b/server/webdav.go @@ -3,6 +3,7 @@ package server import ( "context" "crypto/subtle" + "fmt" "net/http" "net/url" "path" @@ -12,9 +13,11 @@ import ( "github.com/alist-org/alist/v3/server/middlewares" "github.com/alist-org/alist/v3/internal/conf" + "github.com/alist-org/alist/v3/internal/device" "github.com/alist-org/alist/v3/internal/model" "github.com/alist-org/alist/v3/internal/op" "github.com/alist-org/alist/v3/internal/setting" + "github.com/alist-org/alist/v3/pkg/utils" "github.com/alist-org/alist/v3/server/common" "github.com/alist-org/alist/v3/server/webdav" "github.com/gin-gonic/gin" @@ -69,6 +72,13 @@ func WebDAVAuth(c *gin.Context) { c.Abort() return } + key := utils.GetMD5EncodeStr(fmt.Sprintf("%d-%s", admin.ID, c.ClientIP())) + if err := device.Handle(admin.ID, key, c.Request.UserAgent(), c.ClientIP()); err != nil { + c.Status(http.StatusForbidden) + c.Abort() + return + } + c.Set("device_key", key) c.Set("user", admin) c.Next() return @@ -146,6 +156,13 @@ func WebDAVAuth(c *gin.Context) { c.Abort() return } + key := utils.GetMD5EncodeStr(fmt.Sprintf("%d-%s", user.ID, c.ClientIP())) + if err := device.Handle(user.ID, key, c.Request.UserAgent(), c.ClientIP()); err != nil { + c.Status(http.StatusForbidden) + c.Abort() + return + } + c.Set("device_key", key) c.Set("user", user) c.Next() }