From 2d5b8a9a61d18667a67c1cd453462239f0bfc618 Mon Sep 17 00:00:00 2001 From: okatu-loli Date: Tue, 9 Sep 2025 21:45:46 +0800 Subject: [PATCH] feat(auth): Improved device session management logic - Replaced the `userID` parameter with the `user` object to support operations with more user attributes. - Introduced `SessionTTL` and `MaxDevices` properties in the `Handle` and `EnsureActiveOnLogin` functions to support user-defined settings. - Adjusted the session creation and verification logic in `session.go` to support user-defined device count and session duration. - Added help documentation in `setting.go` to explain the configuration purposes of `MaxDevices` and `DeviceSessionTTL`. - Added optional `MaxDevices` and `SessionTTL` properties to the user entity in `user.go` and persisted these settings across user updates. - Modified the device handling logic in `webdav.go` to adapt to the new user object parameters. --- internal/bootstrap/data/setting.go | 4 +-- internal/device/session.go | 43 +++++++++++++++++++++--------- internal/model/user.go | 2 ++ server/handles/auth.go | 2 +- server/handles/user.go | 6 +++++ server/middlewares/auth.go | 2 +- server/webdav.go | 4 +-- 7 files changed, 45 insertions(+), 18 deletions(-) diff --git a/internal/bootstrap/data/setting.go b/internal/bootstrap/data/setting.go index bbb633e3..40baabb7 100644 --- a/internal/bootstrap/data/setting.go +++ b/internal/bootstrap/data/setting.go @@ -165,9 +165,9 @@ func InitialSettings() []model.SettingItem { {Key: conf.ForwardDirectLinkParams, Value: "false", Type: conf.TypeBool, Group: model.GLOBAL}, {Key: conf.IgnoreDirectLinkParams, Value: "sign,alist_ts", Type: conf.TypeString, Group: model.GLOBAL}, {Key: conf.WebauthnLoginEnabled, Value: "false", Type: conf.TypeBool, Group: model.GLOBAL, Flag: model.PUBLIC}, - {Key: conf.MaxDevices, Value: "0", Type: conf.TypeNumber, Group: model.GLOBAL}, + {Key: conf.MaxDevices, Value: "0", Type: conf.TypeNumber, Group: model.GLOBAL, Help: `max devices per user (0 for unlimited)`}, {Key: conf.DeviceEvictPolicy, Value: "deny", Type: conf.TypeSelect, Options: "deny,evict_oldest", Group: model.GLOBAL}, - {Key: conf.DeviceSessionTTL, Value: "86400", Type: conf.TypeNumber, Group: model.GLOBAL}, + {Key: conf.DeviceSessionTTL, Value: "86400", Type: conf.TypeNumber, Group: model.GLOBAL, Help: `session ttl in seconds (0 disables)`}, // single settings {Key: conf.Token, Value: token, Type: conf.TypeString, Group: model.SINGLE, Flag: model.PRIVATE}, diff --git a/internal/device/session.go b/internal/device/session.go index 1d9e7ea5..9ba60a10 100644 --- a/internal/device/session.go +++ b/internal/device/session.go @@ -14,8 +14,11 @@ import ( ) // Handle verifies device sessions for a user and upserts current session. -func Handle(userID uint, deviceKey, ua, ip string) error { +func Handle(user *model.User, deviceKey, ua, ip string) error { ttl := setting.GetInt(conf.DeviceSessionTTL, 86400) + if user.SessionTTL != nil { + ttl = *user.SessionTTL + } if ttl > 0 { _ = db.DeleteSessionsBefore(time.Now().Unix() - int64(ttl)) } @@ -23,7 +26,7 @@ func Handle(userID uint, deviceKey, ua, ip string) error { ip = utils.MaskIP(ip) now := time.Now().Unix() - sess, err := db.GetSession(userID, deviceKey) + sess, err := db.GetSession(user.ID, deviceKey) if err == nil { if sess.Status == model.SessionInactive { return errors.WithStack(errs.SessionInactive) @@ -39,15 +42,18 @@ func Handle(userID uint, deviceKey, ua, ip string) error { } max := setting.GetInt(conf.MaxDevices, 0) + if user.MaxDevices != nil { + max = *user.MaxDevices + } if max > 0 { - count, err := db.CountActiveSessionsByUser(userID) + count, err := db.CountActiveSessionsByUser(user.ID) if err != nil { return err } if count >= int64(max) { policy := setting.GetStr(conf.DeviceEvictPolicy, "deny") if policy == "evict_oldest" { - if oldest, err := db.GetOldestActiveSession(userID); err == nil { + if oldest, err := db.GetOldestActiveSession(user.ID); err == nil { if err := db.MarkInactive(oldest.DeviceKey); err != nil { return err } @@ -58,30 +64,40 @@ func Handle(userID uint, deviceKey, ua, ip string) error { } } - s := &model.Session{UserID: userID, DeviceKey: deviceKey, UserAgent: ua, IP: ip, LastActive: now, Status: model.SessionActive} + s := &model.Session{UserID: user.ID, DeviceKey: deviceKey, UserAgent: ua, IP: ip, LastActive: now, Status: model.SessionActive} return db.CreateSession(s) } // EnsureActiveOnLogin is used only in login flow: // - If session exists (even Inactive): reactivate and refresh fields. // - If not exists: apply max-devices policy, then create Active session. -func EnsureActiveOnLogin(userID uint, deviceKey, ua, ip string) error { +func EnsureActiveOnLogin(user *model.User, deviceKey, ua, ip string) error { + ttl := setting.GetInt(conf.DeviceSessionTTL, 86400) + if user.SessionTTL != nil { + ttl = *user.SessionTTL + } + if ttl > 0 { + _ = db.DeleteSessionsBefore(time.Now().Unix() - int64(ttl)) + } ip = utils.MaskIP(ip) now := time.Now().Unix() - sess, err := db.GetSession(userID, deviceKey) + sess, err := db.GetSession(user.ID, deviceKey) if err == nil { if sess.Status == model.SessionInactive { max := setting.GetInt(conf.MaxDevices, 0) + if user.MaxDevices != nil { + max = *user.MaxDevices + } if max > 0 { - count, err := db.CountActiveSessionsByUser(userID) + count, err := db.CountActiveSessionsByUser(user.ID) if err != nil { return err } if count >= int64(max) { policy := setting.GetStr(conf.DeviceEvictPolicy, "deny") if policy == "evict_oldest" { - if oldest, gerr := db.GetOldestActiveSession(userID); gerr == nil { + if oldest, gerr := db.GetOldestActiveSession(user.ID); gerr == nil { if err := db.MarkInactive(oldest.DeviceKey); err != nil { return err } @@ -103,15 +119,18 @@ func EnsureActiveOnLogin(userID uint, deviceKey, ua, ip string) error { } max := setting.GetInt(conf.MaxDevices, 0) + if user.MaxDevices != nil { + max = *user.MaxDevices + } if max > 0 { - count, err := db.CountActiveSessionsByUser(userID) + count, err := db.CountActiveSessionsByUser(user.ID) if err != nil { return err } if count >= int64(max) { policy := setting.GetStr(conf.DeviceEvictPolicy, "deny") if policy == "evict_oldest" { - if oldest, gerr := db.GetOldestActiveSession(userID); gerr == nil { + if oldest, gerr := db.GetOldestActiveSession(user.ID); gerr == nil { if err := db.MarkInactive(oldest.DeviceKey); err != nil { return err } @@ -123,7 +142,7 @@ func EnsureActiveOnLogin(userID uint, deviceKey, ua, ip string) error { } return db.CreateSession(&model.Session{ - UserID: userID, + UserID: user.ID, DeviceKey: deviceKey, UserAgent: ua, IP: ip, diff --git a/internal/model/user.go b/internal/model/user.go index 8ea1ef1a..8b608619 100644 --- a/internal/model/user.go +++ b/internal/model/user.go @@ -52,6 +52,8 @@ type User struct { Permission int32 `json:"permission"` OtpSecret string `json:"-"` SsoID string `json:"sso_id"` // unique by sso platform + MaxDevices *int `json:"max_devices"` + SessionTTL *int `json:"session_ttl"` Authn string `gorm:"type:text" json:"-"` } diff --git a/server/handles/auth.go b/server/handles/auth.go index dd7d202b..73b2e1b9 100644 --- a/server/handles/auth.go +++ b/server/handles/auth.go @@ -95,7 +95,7 @@ func loginHash(c *gin.Context, req *LoginReq) { key := utils.GetMD5EncodeStr(fmt.Sprintf("%d-%s", user.ID, clientID)) - if err := device.EnsureActiveOnLogin(user.ID, key, c.Request.UserAgent(), c.ClientIP()); err != nil { + if err := device.EnsureActiveOnLogin(user, key, c.Request.UserAgent(), c.ClientIP()); err != nil { if errors.Is(err, errs.TooManyDevices) { common.ErrorResp(c, err, 403) } else { diff --git a/server/handles/user.go b/server/handles/user.go index ac3a06e8..47cc6e99 100644 --- a/server/handles/user.go +++ b/server/handles/user.go @@ -87,6 +87,12 @@ func UpdateUser(c *gin.Context) { if req.OtpSecret == "" { req.OtpSecret = user.OtpSecret } + if req.MaxDevices == nil { + req.MaxDevices = user.MaxDevices + } + if req.SessionTTL == nil { + req.SessionTTL = user.SessionTTL + } if req.Disabled && user.IsAdmin() { count, err := op.CountEnabledAdminsExcluding(user.ID) if err != nil { diff --git a/server/middlewares/auth.go b/server/middlewares/auth.go index 204b4b72..da4f9a1a 100644 --- a/server/middlewares/auth.go +++ b/server/middlewares/auth.go @@ -109,7 +109,7 @@ func HandleSession(c *gin.Context, user *model.User) bool { clientID = c.Query("client_id") } key := utils.GetMD5EncodeStr(fmt.Sprintf("%d-%s", user.ID, clientID)) - if err := device.Handle(user.ID, key, c.Request.UserAgent(), c.ClientIP()); err != nil { + if err := device.Handle(user, key, c.Request.UserAgent(), c.ClientIP()); err != nil { token := c.GetHeader("Authorization") if errors.Is(err, errs.SessionInactive) { _ = common.InvalidateToken(token) diff --git a/server/webdav.go b/server/webdav.go index e0980139..5928811b 100644 --- a/server/webdav.go +++ b/server/webdav.go @@ -73,7 +73,7 @@ func WebDAVAuth(c *gin.Context) { return } key := utils.GetMD5EncodeStr(fmt.Sprintf("%d-%s", admin.ID, c.ClientIP())) - if err := device.Handle(admin.ID, key, c.Request.UserAgent(), c.ClientIP()); err != nil { + if err := device.Handle(admin, key, c.Request.UserAgent(), c.ClientIP()); err != nil { c.Status(http.StatusForbidden) c.Abort() return @@ -157,7 +157,7 @@ func WebDAVAuth(c *gin.Context) { return } key := utils.GetMD5EncodeStr(fmt.Sprintf("%d-%s", user.ID, c.ClientIP())) - if err := device.Handle(user.ID, key, c.Request.UserAgent(), c.ClientIP()); err != nil { + if err := device.Handle(user, key, c.Request.UserAgent(), c.ClientIP()); err != nil { c.Status(http.StatusForbidden) c.Abort() return