diff --git a/internal/db/session.go b/internal/db/session.go index 8db9fa69..e8dce441 100644 --- a/internal/db/session.go +++ b/internal/db/session.go @@ -26,9 +26,11 @@ func DeleteSession(userID uint, deviceKey string) error { return errors.WithStack(db.Where("user_id = ? AND device_key = ?", userID, deviceKey).Delete(&model.Session{}).Error) } -func CountSessionsByUser(userID uint) (int64, error) { +func CountActiveSessionsByUser(userID uint) (int64, error) { var count int64 - err := db.Model(&model.Session{}).Where("user_id = ?", userID).Count(&count).Error + err := db.Model(&model.Session{}). + Where("user_id = ? AND status = ?", userID, model.SessionActive). + Count(&count).Error return count, errors.WithStack(err) } diff --git a/internal/device/session.go b/internal/device/session.go index 49bf74b6..5d5a3996 100644 --- a/internal/device/session.go +++ b/internal/device/session.go @@ -25,7 +25,25 @@ func Handle(userID uint, deviceKey, ua, ip string) error { now := time.Now().Unix() sess, err := db.GetSession(userID, deviceKey) if err == nil { - // reactivate existing session if it was inactive + if sess.Status == model.SessionInactive { + max := setting.GetInt(conf.MaxDevices, 0) + if max > 0 { + count, cerr := db.CountActiveSessionsByUser(userID) + if cerr != nil { + return cerr + } + if count >= int64(max) { + policy := setting.GetStr(conf.DeviceEvictPolicy, "deny") + if policy == "evict_oldest" { + if oldest, gerr := db.GetOldestSession(userID); gerr == nil { + _ = db.DeleteSession(userID, oldest.DeviceKey) + } + } else { + return errors.WithStack(errs.TooManyDevices) + } + } + } + } sess.Status = model.SessionActive sess.LastActive = now sess.UserAgent = ua @@ -38,7 +56,7 @@ func Handle(userID uint, deviceKey, ua, ip string) error { max := setting.GetInt(conf.MaxDevices, 0) if max > 0 { - count, err := db.CountSessionsByUser(userID) + count, err := db.CountActiveSessionsByUser(userID) if err != nil { return err } diff --git a/server/handles/auth.go b/server/handles/auth.go index 30714f65..8c7d7d9f 100644 --- a/server/handles/auth.go +++ b/server/handles/auth.go @@ -15,6 +15,7 @@ import ( "github.com/alist-org/alist/v3/internal/session" "github.com/alist-org/alist/v3/internal/setting" "github.com/alist-org/alist/v3/server/common" + "github.com/alist-org/alist/v3/server/middlewares" "github.com/gin-gonic/gin" "github.com/pquerna/otp/totp" ) @@ -82,13 +83,18 @@ func loginHash(c *gin.Context, req *LoginReq) { return } } + // generate device session + if !middlewares.HandleSession(c, user) { + return + } // generate token token, err := common.GenerateToken(user) if err != nil { common.ErrorResp(c, err, 400, true) return } - common.SuccessResp(c, gin.H{"token": token}) + key := c.GetString("device_key") + common.SuccessResp(c, gin.H{"token": token, "device_key": key}) loginCache.Del(ip) } diff --git a/server/middlewares/auth.go b/server/middlewares/auth.go index 72eaefe6..714c1154 100644 --- a/server/middlewares/auth.go +++ b/server/middlewares/auth.go @@ -26,7 +26,7 @@ func Auth(c *gin.Context) { c.Abort() return } - if !handleSession(c, admin) { + if !HandleSession(c, admin) { return } log.Debugf("use admin token: %+v", admin) @@ -54,7 +54,7 @@ func Auth(c *gin.Context) { } guest.RolesDetail = roles } - if !handleSession(c, guest) { + if !HandleSession(c, guest) { return } log.Debugf("use empty token: %+v", guest) @@ -93,14 +93,15 @@ func Auth(c *gin.Context) { } user.RolesDetail = roles } - if !handleSession(c, user) { + if !HandleSession(c, user) { return } log.Debugf("use login token: %+v", user) c.Next() } -func handleSession(c *gin.Context, user *model.User) bool { +// HandleSession verifies device sessions and stores context values. +func HandleSession(c *gin.Context, user *model.User) bool { clientID := c.GetHeader("Client-Id") if clientID == "" { clientID = c.Query("client_id")