diff --git a/internal/db/session.go b/internal/db/session.go index e8dce441..35c778c3 100644 --- a/internal/db/session.go +++ b/internal/db/session.go @@ -38,10 +38,12 @@ func DeleteSessionsBefore(ts int64) error { return errors.WithStack(db.Where("last_active < ?", ts).Delete(&model.Session{}).Error) } -func GetOldestSession(userID uint) (*model.Session, error) { +// GetOldestActiveSession returns the oldest active session for the specified user. +func GetOldestActiveSession(userID uint) (*model.Session, error) { var s model.Session - if err := db.Where("user_id = ?", userID).Order("last_active ASC").First(&s).Error; err != nil { - return nil, errors.Wrap(err, "failed get oldest session") + if err := db.Where("user_id = ? AND status = ?", userID, model.SessionActive). + Order("last_active ASC").First(&s).Error; err != nil { + return nil, errors.Wrap(err, "failed get oldest active session") } return &s, nil } diff --git a/internal/device/session.go b/internal/device/session.go index 5d5a3996..1d9e7ea5 100644 --- a/internal/device/session.go +++ b/internal/device/session.go @@ -23,20 +23,68 @@ func Handle(userID uint, deviceKey, ua, ip string) error { ip = utils.MaskIP(ip) now := time.Now().Unix() + sess, err := db.GetSession(userID, deviceKey) + if err == nil { + if sess.Status == model.SessionInactive { + return errors.WithStack(errs.SessionInactive) + } + sess.Status = model.SessionActive + sess.LastActive = now + sess.UserAgent = ua + sess.IP = ip + return db.UpsertSession(sess) + } + if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { + return err + } + + max := setting.GetInt(conf.MaxDevices, 0) + if max > 0 { + count, err := db.CountActiveSessionsByUser(userID) + if err != nil { + return err + } + if count >= int64(max) { + policy := setting.GetStr(conf.DeviceEvictPolicy, "deny") + if policy == "evict_oldest" { + if oldest, err := db.GetOldestActiveSession(userID); err == nil { + if err := db.MarkInactive(oldest.DeviceKey); err != nil { + return err + } + } + } else { + return errors.WithStack(errs.TooManyDevices) + } + } + } + + s := &model.Session{UserID: userID, DeviceKey: deviceKey, UserAgent: ua, IP: ip, LastActive: now, Status: model.SessionActive} + return db.CreateSession(s) +} + +// EnsureActiveOnLogin is used only in login flow: +// - If session exists (even Inactive): reactivate and refresh fields. +// - If not exists: apply max-devices policy, then create Active session. +func EnsureActiveOnLogin(userID uint, deviceKey, ua, ip string) error { + ip = utils.MaskIP(ip) + now := time.Now().Unix() + sess, err := db.GetSession(userID, deviceKey) if err == nil { if sess.Status == model.SessionInactive { max := setting.GetInt(conf.MaxDevices, 0) if max > 0 { - count, cerr := db.CountActiveSessionsByUser(userID) - if cerr != nil { - return cerr + count, err := db.CountActiveSessionsByUser(userID) + if err != nil { + return err } if count >= int64(max) { policy := setting.GetStr(conf.DeviceEvictPolicy, "deny") if policy == "evict_oldest" { - if oldest, gerr := db.GetOldestSession(userID); gerr == nil { - _ = db.DeleteSession(userID, oldest.DeviceKey) + if oldest, gerr := db.GetOldestActiveSession(userID); gerr == nil { + if err := db.MarkInactive(oldest.DeviceKey); err != nil { + return err + } } } else { return errors.WithStack(errs.TooManyDevices) @@ -63,9 +111,10 @@ func Handle(userID uint, deviceKey, ua, ip string) error { if count >= int64(max) { policy := setting.GetStr(conf.DeviceEvictPolicy, "deny") if policy == "evict_oldest" { - oldest, err := db.GetOldestSession(userID) - if err == nil { - _ = db.DeleteSession(userID, oldest.DeviceKey) + if oldest, gerr := db.GetOldestActiveSession(userID); gerr == nil { + if err := db.MarkInactive(oldest.DeviceKey); err != nil { + return err + } } } else { return errors.WithStack(errs.TooManyDevices) @@ -73,8 +122,14 @@ func Handle(userID uint, deviceKey, ua, ip string) error { } } - s := &model.Session{UserID: userID, DeviceKey: deviceKey, UserAgent: ua, IP: ip, LastActive: now, Status: model.SessionActive} - return db.CreateSession(s) + return db.CreateSession(&model.Session{ + UserID: userID, + DeviceKey: deviceKey, + UserAgent: ua, + IP: ip, + LastActive: now, + Status: model.SessionActive, + }) } // Refresh updates last_active for the session. diff --git a/server/handles/auth.go b/server/handles/auth.go index 8c7d7d9f..dd7d202b 100644 --- a/server/handles/auth.go +++ b/server/handles/auth.go @@ -3,6 +3,8 @@ package handles import ( "bytes" "encoding/base64" + "errors" + "fmt" "image/png" "path" "strings" @@ -10,12 +12,14 @@ import ( "github.com/Xhofe/go-cache" "github.com/alist-org/alist/v3/internal/conf" + "github.com/alist-org/alist/v3/internal/device" + "github.com/alist-org/alist/v3/internal/errs" "github.com/alist-org/alist/v3/internal/model" "github.com/alist-org/alist/v3/internal/op" "github.com/alist-org/alist/v3/internal/session" "github.com/alist-org/alist/v3/internal/setting" + "github.com/alist-org/alist/v3/pkg/utils" "github.com/alist-org/alist/v3/server/common" - "github.com/alist-org/alist/v3/server/middlewares" "github.com/gin-gonic/gin" "github.com/pquerna/otp/totp" ) @@ -83,17 +87,29 @@ func loginHash(c *gin.Context, req *LoginReq) { return } } - // generate device session - if !middlewares.HandleSession(c, user) { + + clientID := c.GetHeader("Client-Id") + if clientID == "" { + clientID = c.Query("client_id") + } + key := utils.GetMD5EncodeStr(fmt.Sprintf("%d-%s", + user.ID, clientID)) + + if err := device.EnsureActiveOnLogin(user.ID, key, c.Request.UserAgent(), c.ClientIP()); err != nil { + if errors.Is(err, errs.TooManyDevices) { + common.ErrorResp(c, err, 403) + } else { + common.ErrorResp(c, err, 400, true) + } return } + // generate token token, err := common.GenerateToken(user) if err != nil { common.ErrorResp(c, err, 400, true) return } - key := c.GetString("device_key") common.SuccessResp(c, gin.H{"token": token, "device_key": key}) loginCache.Del(ip) } diff --git a/server/middlewares/auth.go b/server/middlewares/auth.go index 714c1154..204b4b72 100644 --- a/server/middlewares/auth.go +++ b/server/middlewares/auth.go @@ -2,10 +2,12 @@ package middlewares import ( "crypto/subtle" + "errors" "fmt" "github.com/alist-org/alist/v3/internal/conf" "github.com/alist-org/alist/v3/internal/device" + "github.com/alist-org/alist/v3/internal/errs" "github.com/alist-org/alist/v3/internal/model" "github.com/alist-org/alist/v3/internal/op" "github.com/alist-org/alist/v3/internal/setting" @@ -106,9 +108,15 @@ func HandleSession(c *gin.Context, user *model.User) bool { if clientID == "" { clientID = c.Query("client_id") } - key := utils.GetMD5EncodeStr(fmt.Sprintf("%d-%s-%s-%s", user.ID, c.Request.UserAgent(), c.ClientIP(), clientID)) + key := utils.GetMD5EncodeStr(fmt.Sprintf("%d-%s", user.ID, clientID)) if err := device.Handle(user.ID, key, c.Request.UserAgent(), c.ClientIP()); err != nil { - common.ErrorResp(c, err, 403) + token := c.GetHeader("Authorization") + if errors.Is(err, errs.SessionInactive) { + _ = common.InvalidateToken(token) + common.ErrorResp(c, err, 401) + } else { + common.ErrorResp(c, err, 403) + } c.Abort() return false }