From 6f0cc7a3e85286bc13610021c2a03c618e89b6d6 Mon Sep 17 00:00:00 2001 From: okatu-loli Date: Fri, 29 Aug 2025 17:32:46 +0800 Subject: [PATCH] feat(auth): Enhanced device login session management - Upon login, obtain and verify `Client-Id` to ensure unique device sessions. - If there are too many device sessions, clean up old ones according to the configured policy or return an error. - If a device session is invalid, deregister the old token and return a 401 error. - Added `EnsureActiveOnLogin` function to handle the creation and refresh of device sessions during login. --- internal/device/session.go | 65 ++++++++++++++++++++++++++++---------- server/handles/auth.go | 24 +++++++++++--- server/middlewares/auth.go | 12 +++++-- 3 files changed, 78 insertions(+), 23 deletions(-) diff --git a/internal/device/session.go b/internal/device/session.go index 5d5a3996..eb5e8647 100644 --- a/internal/device/session.go +++ b/internal/device/session.go @@ -26,23 +26,7 @@ func Handle(userID uint, deviceKey, ua, ip string) error { sess, err := db.GetSession(userID, deviceKey) if err == nil { if sess.Status == model.SessionInactive { - max := setting.GetInt(conf.MaxDevices, 0) - if max > 0 { - count, cerr := db.CountActiveSessionsByUser(userID) - if cerr != nil { - return cerr - } - if count >= int64(max) { - policy := setting.GetStr(conf.DeviceEvictPolicy, "deny") - if policy == "evict_oldest" { - if oldest, gerr := db.GetOldestSession(userID); gerr == nil { - _ = db.DeleteSession(userID, oldest.DeviceKey) - } - } else { - return errors.WithStack(errs.TooManyDevices) - } - } - } + return errors.WithStack(errs.SessionInactive) } sess.Status = model.SessionActive sess.LastActive = now @@ -77,6 +61,53 @@ func Handle(userID uint, deviceKey, ua, ip string) error { return db.CreateSession(s) } +// EnsureActiveOnLogin is used only in login flow: +// - If session exists (even Inactive): reactivate and refresh fields. +// - If not exists: apply max-devices policy, then create Active session. +func EnsureActiveOnLogin(userID uint, deviceKey, ua, ip string) error { + ip = utils.MaskIP(ip) + now := time.Now().Unix() + + sess, err := db.GetSession(userID, deviceKey) + if err == nil { + sess.Status = model.SessionActive + sess.LastActive = now + sess.UserAgent = ua + sess.IP = ip + return db.UpsertSession(sess) + } + if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { + return err + } + + max := setting.GetInt(conf.MaxDevices, 0) + if max > 0 { + count, err := db.CountActiveSessionsByUser(userID) + if err != nil { + return err + } + if count >= int64(max) { + policy := setting.GetStr(conf.DeviceEvictPolicy, "deny") + if policy == "evict_oldest" { + if oldest, gerr := db.GetOldestSession(userID); gerr == nil { + _ = db.DeleteSession(userID, oldest.DeviceKey) + } + } else { + return errors.WithStack(errs.TooManyDevices) + } + } + } + + return db.CreateSession(&model.Session{ + UserID: userID, + DeviceKey: deviceKey, + UserAgent: ua, + IP: ip, + LastActive: now, + Status: model.SessionActive, + }) +} + // Refresh updates last_active for the session. func Refresh(userID uint, deviceKey string) { _ = db.UpdateSessionLastActive(userID, deviceKey, time.Now().Unix()) diff --git a/server/handles/auth.go b/server/handles/auth.go index 8c7d7d9f..dd7d202b 100644 --- a/server/handles/auth.go +++ b/server/handles/auth.go @@ -3,6 +3,8 @@ package handles import ( "bytes" "encoding/base64" + "errors" + "fmt" "image/png" "path" "strings" @@ -10,12 +12,14 @@ import ( "github.com/Xhofe/go-cache" "github.com/alist-org/alist/v3/internal/conf" + "github.com/alist-org/alist/v3/internal/device" + "github.com/alist-org/alist/v3/internal/errs" "github.com/alist-org/alist/v3/internal/model" "github.com/alist-org/alist/v3/internal/op" "github.com/alist-org/alist/v3/internal/session" "github.com/alist-org/alist/v3/internal/setting" + "github.com/alist-org/alist/v3/pkg/utils" "github.com/alist-org/alist/v3/server/common" - "github.com/alist-org/alist/v3/server/middlewares" "github.com/gin-gonic/gin" "github.com/pquerna/otp/totp" ) @@ -83,17 +87,29 @@ func loginHash(c *gin.Context, req *LoginReq) { return } } - // generate device session - if !middlewares.HandleSession(c, user) { + + clientID := c.GetHeader("Client-Id") + if clientID == "" { + clientID = c.Query("client_id") + } + key := utils.GetMD5EncodeStr(fmt.Sprintf("%d-%s", + user.ID, clientID)) + + if err := device.EnsureActiveOnLogin(user.ID, key, c.Request.UserAgent(), c.ClientIP()); err != nil { + if errors.Is(err, errs.TooManyDevices) { + common.ErrorResp(c, err, 403) + } else { + common.ErrorResp(c, err, 400, true) + } return } + // generate token token, err := common.GenerateToken(user) if err != nil { common.ErrorResp(c, err, 400, true) return } - key := c.GetString("device_key") common.SuccessResp(c, gin.H{"token": token, "device_key": key}) loginCache.Del(ip) } diff --git a/server/middlewares/auth.go b/server/middlewares/auth.go index 714c1154..204b4b72 100644 --- a/server/middlewares/auth.go +++ b/server/middlewares/auth.go @@ -2,10 +2,12 @@ package middlewares import ( "crypto/subtle" + "errors" "fmt" "github.com/alist-org/alist/v3/internal/conf" "github.com/alist-org/alist/v3/internal/device" + "github.com/alist-org/alist/v3/internal/errs" "github.com/alist-org/alist/v3/internal/model" "github.com/alist-org/alist/v3/internal/op" "github.com/alist-org/alist/v3/internal/setting" @@ -106,9 +108,15 @@ func HandleSession(c *gin.Context, user *model.User) bool { if clientID == "" { clientID = c.Query("client_id") } - key := utils.GetMD5EncodeStr(fmt.Sprintf("%d-%s-%s-%s", user.ID, c.Request.UserAgent(), c.ClientIP(), clientID)) + key := utils.GetMD5EncodeStr(fmt.Sprintf("%d-%s", user.ID, clientID)) if err := device.Handle(user.ID, key, c.Request.UserAgent(), c.ClientIP()); err != nil { - common.ErrorResp(c, err, 403) + token := c.GetHeader("Authorization") + if errors.Is(err, errs.SessionInactive) { + _ = common.InvalidateToken(token) + common.ErrorResp(c, err, 401) + } else { + common.ErrorResp(c, err, 403) + } c.Abort() return false }