feat(auth): Improved device session management logic

- Replaced the `userID` parameter with the `user` object to support operations with more user attributes.
- Introduced `SessionTTL` and `MaxDevices` properties in the `Handle` and `EnsureActiveOnLogin` functions to support user-defined settings.
- Adjusted the session creation and verification logic in `session.go` to support user-defined device count and session duration.
- Added help documentation in `setting.go` to explain the configuration purposes of `MaxDevices` and `DeviceSessionTTL`.
- Added optional `MaxDevices` and `SessionTTL` properties to the user entity in `user.go` and persisted these settings across user updates.
- Modified the device handling logic in `webdav.go` to adapt to the new user object parameters.
pull/9315/head
okatu-loli 2025-09-09 21:45:46 +08:00
parent fcbc79cb24
commit 2d5b8a9a61
7 changed files with 45 additions and 18 deletions

View File

@ -165,9 +165,9 @@ func InitialSettings() []model.SettingItem {
{Key: conf.ForwardDirectLinkParams, Value: "false", Type: conf.TypeBool, Group: model.GLOBAL}, {Key: conf.ForwardDirectLinkParams, Value: "false", Type: conf.TypeBool, Group: model.GLOBAL},
{Key: conf.IgnoreDirectLinkParams, Value: "sign,alist_ts", Type: conf.TypeString, Group: model.GLOBAL}, {Key: conf.IgnoreDirectLinkParams, Value: "sign,alist_ts", Type: conf.TypeString, Group: model.GLOBAL},
{Key: conf.WebauthnLoginEnabled, Value: "false", Type: conf.TypeBool, Group: model.GLOBAL, Flag: model.PUBLIC}, {Key: conf.WebauthnLoginEnabled, Value: "false", Type: conf.TypeBool, Group: model.GLOBAL, Flag: model.PUBLIC},
{Key: conf.MaxDevices, Value: "0", Type: conf.TypeNumber, Group: model.GLOBAL}, {Key: conf.MaxDevices, Value: "0", Type: conf.TypeNumber, Group: model.GLOBAL, Help: `max devices per user (0 for unlimited)`},
{Key: conf.DeviceEvictPolicy, Value: "deny", Type: conf.TypeSelect, Options: "deny,evict_oldest", Group: model.GLOBAL}, {Key: conf.DeviceEvictPolicy, Value: "deny", Type: conf.TypeSelect, Options: "deny,evict_oldest", Group: model.GLOBAL},
{Key: conf.DeviceSessionTTL, Value: "86400", Type: conf.TypeNumber, Group: model.GLOBAL}, {Key: conf.DeviceSessionTTL, Value: "86400", Type: conf.TypeNumber, Group: model.GLOBAL, Help: `session ttl in seconds (0 disables)`},
// single settings // single settings
{Key: conf.Token, Value: token, Type: conf.TypeString, Group: model.SINGLE, Flag: model.PRIVATE}, {Key: conf.Token, Value: token, Type: conf.TypeString, Group: model.SINGLE, Flag: model.PRIVATE},

View File

@ -14,8 +14,11 @@ import (
) )
// Handle verifies device sessions for a user and upserts current session. // Handle verifies device sessions for a user and upserts current session.
func Handle(userID uint, deviceKey, ua, ip string) error { func Handle(user *model.User, deviceKey, ua, ip string) error {
ttl := setting.GetInt(conf.DeviceSessionTTL, 86400) ttl := setting.GetInt(conf.DeviceSessionTTL, 86400)
if user.SessionTTL != nil {
ttl = *user.SessionTTL
}
if ttl > 0 { if ttl > 0 {
_ = db.DeleteSessionsBefore(time.Now().Unix() - int64(ttl)) _ = db.DeleteSessionsBefore(time.Now().Unix() - int64(ttl))
} }
@ -23,7 +26,7 @@ func Handle(userID uint, deviceKey, ua, ip string) error {
ip = utils.MaskIP(ip) ip = utils.MaskIP(ip)
now := time.Now().Unix() now := time.Now().Unix()
sess, err := db.GetSession(userID, deviceKey) sess, err := db.GetSession(user.ID, deviceKey)
if err == nil { if err == nil {
if sess.Status == model.SessionInactive { if sess.Status == model.SessionInactive {
return errors.WithStack(errs.SessionInactive) return errors.WithStack(errs.SessionInactive)
@ -39,15 +42,18 @@ func Handle(userID uint, deviceKey, ua, ip string) error {
} }
max := setting.GetInt(conf.MaxDevices, 0) max := setting.GetInt(conf.MaxDevices, 0)
if user.MaxDevices != nil {
max = *user.MaxDevices
}
if max > 0 { if max > 0 {
count, err := db.CountActiveSessionsByUser(userID) count, err := db.CountActiveSessionsByUser(user.ID)
if err != nil { if err != nil {
return err return err
} }
if count >= int64(max) { if count >= int64(max) {
policy := setting.GetStr(conf.DeviceEvictPolicy, "deny") policy := setting.GetStr(conf.DeviceEvictPolicy, "deny")
if policy == "evict_oldest" { if policy == "evict_oldest" {
if oldest, err := db.GetOldestActiveSession(userID); err == nil { if oldest, err := db.GetOldestActiveSession(user.ID); err == nil {
if err := db.MarkInactive(oldest.DeviceKey); err != nil { if err := db.MarkInactive(oldest.DeviceKey); err != nil {
return err return err
} }
@ -58,30 +64,40 @@ func Handle(userID uint, deviceKey, ua, ip string) error {
} }
} }
s := &model.Session{UserID: userID, DeviceKey: deviceKey, UserAgent: ua, IP: ip, LastActive: now, Status: model.SessionActive} s := &model.Session{UserID: user.ID, DeviceKey: deviceKey, UserAgent: ua, IP: ip, LastActive: now, Status: model.SessionActive}
return db.CreateSession(s) return db.CreateSession(s)
} }
// EnsureActiveOnLogin is used only in login flow: // EnsureActiveOnLogin is used only in login flow:
// - If session exists (even Inactive): reactivate and refresh fields. // - If session exists (even Inactive): reactivate and refresh fields.
// - If not exists: apply max-devices policy, then create Active session. // - If not exists: apply max-devices policy, then create Active session.
func EnsureActiveOnLogin(userID uint, deviceKey, ua, ip string) error { func EnsureActiveOnLogin(user *model.User, deviceKey, ua, ip string) error {
ttl := setting.GetInt(conf.DeviceSessionTTL, 86400)
if user.SessionTTL != nil {
ttl = *user.SessionTTL
}
if ttl > 0 {
_ = db.DeleteSessionsBefore(time.Now().Unix() - int64(ttl))
}
ip = utils.MaskIP(ip) ip = utils.MaskIP(ip)
now := time.Now().Unix() now := time.Now().Unix()
sess, err := db.GetSession(userID, deviceKey) sess, err := db.GetSession(user.ID, deviceKey)
if err == nil { if err == nil {
if sess.Status == model.SessionInactive { if sess.Status == model.SessionInactive {
max := setting.GetInt(conf.MaxDevices, 0) max := setting.GetInt(conf.MaxDevices, 0)
if user.MaxDevices != nil {
max = *user.MaxDevices
}
if max > 0 { if max > 0 {
count, err := db.CountActiveSessionsByUser(userID) count, err := db.CountActiveSessionsByUser(user.ID)
if err != nil { if err != nil {
return err return err
} }
if count >= int64(max) { if count >= int64(max) {
policy := setting.GetStr(conf.DeviceEvictPolicy, "deny") policy := setting.GetStr(conf.DeviceEvictPolicy, "deny")
if policy == "evict_oldest" { if policy == "evict_oldest" {
if oldest, gerr := db.GetOldestActiveSession(userID); gerr == nil { if oldest, gerr := db.GetOldestActiveSession(user.ID); gerr == nil {
if err := db.MarkInactive(oldest.DeviceKey); err != nil { if err := db.MarkInactive(oldest.DeviceKey); err != nil {
return err return err
} }
@ -103,15 +119,18 @@ func EnsureActiveOnLogin(userID uint, deviceKey, ua, ip string) error {
} }
max := setting.GetInt(conf.MaxDevices, 0) max := setting.GetInt(conf.MaxDevices, 0)
if user.MaxDevices != nil {
max = *user.MaxDevices
}
if max > 0 { if max > 0 {
count, err := db.CountActiveSessionsByUser(userID) count, err := db.CountActiveSessionsByUser(user.ID)
if err != nil { if err != nil {
return err return err
} }
if count >= int64(max) { if count >= int64(max) {
policy := setting.GetStr(conf.DeviceEvictPolicy, "deny") policy := setting.GetStr(conf.DeviceEvictPolicy, "deny")
if policy == "evict_oldest" { if policy == "evict_oldest" {
if oldest, gerr := db.GetOldestActiveSession(userID); gerr == nil { if oldest, gerr := db.GetOldestActiveSession(user.ID); gerr == nil {
if err := db.MarkInactive(oldest.DeviceKey); err != nil { if err := db.MarkInactive(oldest.DeviceKey); err != nil {
return err return err
} }
@ -123,7 +142,7 @@ func EnsureActiveOnLogin(userID uint, deviceKey, ua, ip string) error {
} }
return db.CreateSession(&model.Session{ return db.CreateSession(&model.Session{
UserID: userID, UserID: user.ID,
DeviceKey: deviceKey, DeviceKey: deviceKey,
UserAgent: ua, UserAgent: ua,
IP: ip, IP: ip,

View File

@ -52,6 +52,8 @@ type User struct {
Permission int32 `json:"permission"` Permission int32 `json:"permission"`
OtpSecret string `json:"-"` OtpSecret string `json:"-"`
SsoID string `json:"sso_id"` // unique by sso platform SsoID string `json:"sso_id"` // unique by sso platform
MaxDevices *int `json:"max_devices"`
SessionTTL *int `json:"session_ttl"`
Authn string `gorm:"type:text" json:"-"` Authn string `gorm:"type:text" json:"-"`
} }

View File

@ -95,7 +95,7 @@ func loginHash(c *gin.Context, req *LoginReq) {
key := utils.GetMD5EncodeStr(fmt.Sprintf("%d-%s", key := utils.GetMD5EncodeStr(fmt.Sprintf("%d-%s",
user.ID, clientID)) user.ID, clientID))
if err := device.EnsureActiveOnLogin(user.ID, key, c.Request.UserAgent(), c.ClientIP()); err != nil { if err := device.EnsureActiveOnLogin(user, key, c.Request.UserAgent(), c.ClientIP()); err != nil {
if errors.Is(err, errs.TooManyDevices) { if errors.Is(err, errs.TooManyDevices) {
common.ErrorResp(c, err, 403) common.ErrorResp(c, err, 403)
} else { } else {

View File

@ -87,6 +87,12 @@ func UpdateUser(c *gin.Context) {
if req.OtpSecret == "" { if req.OtpSecret == "" {
req.OtpSecret = user.OtpSecret req.OtpSecret = user.OtpSecret
} }
if req.MaxDevices == nil {
req.MaxDevices = user.MaxDevices
}
if req.SessionTTL == nil {
req.SessionTTL = user.SessionTTL
}
if req.Disabled && user.IsAdmin() { if req.Disabled && user.IsAdmin() {
count, err := op.CountEnabledAdminsExcluding(user.ID) count, err := op.CountEnabledAdminsExcluding(user.ID)
if err != nil { if err != nil {

View File

@ -109,7 +109,7 @@ func HandleSession(c *gin.Context, user *model.User) bool {
clientID = c.Query("client_id") clientID = c.Query("client_id")
} }
key := utils.GetMD5EncodeStr(fmt.Sprintf("%d-%s", user.ID, clientID)) key := utils.GetMD5EncodeStr(fmt.Sprintf("%d-%s", user.ID, clientID))
if err := device.Handle(user.ID, key, c.Request.UserAgent(), c.ClientIP()); err != nil { if err := device.Handle(user, key, c.Request.UserAgent(), c.ClientIP()); err != nil {
token := c.GetHeader("Authorization") token := c.GetHeader("Authorization")
if errors.Is(err, errs.SessionInactive) { if errors.Is(err, errs.SessionInactive) {
_ = common.InvalidateToken(token) _ = common.InvalidateToken(token)

View File

@ -73,7 +73,7 @@ func WebDAVAuth(c *gin.Context) {
return return
} }
key := utils.GetMD5EncodeStr(fmt.Sprintf("%d-%s", admin.ID, c.ClientIP())) key := utils.GetMD5EncodeStr(fmt.Sprintf("%d-%s", admin.ID, c.ClientIP()))
if err := device.Handle(admin.ID, key, c.Request.UserAgent(), c.ClientIP()); err != nil { if err := device.Handle(admin, key, c.Request.UserAgent(), c.ClientIP()); err != nil {
c.Status(http.StatusForbidden) c.Status(http.StatusForbidden)
c.Abort() c.Abort()
return return
@ -157,7 +157,7 @@ func WebDAVAuth(c *gin.Context) {
return return
} }
key := utils.GetMD5EncodeStr(fmt.Sprintf("%d-%s", user.ID, c.ClientIP())) key := utils.GetMD5EncodeStr(fmt.Sprintf("%d-%s", user.ID, c.ClientIP()))
if err := device.Handle(user.ID, key, c.Request.UserAgent(), c.ClientIP()); err != nil { if err := device.Handle(user, key, c.Request.UserAgent(), c.ClientIP()); err != nil {
c.Status(http.StatusForbidden) c.Status(http.StatusForbidden)
c.Abort() c.Abort()
return return