mirror of https://github.com/Xhofe/alist
				
				
				
			feat: implement session management (#9286)
* feat(auth): Added device session management - Added the `handleSession` function to manage user device sessions and verify client identity - Updated `auth.go` to call `handleSession` for device handling when a user logs in - Added the `Session` model to database migrations - Added `device.go` and `session.go` files to handle device session logic - Updated `settings.go` to add device-related configuration items, such as the maximum number of devices, device eviction policy, and session TTL * feat(session): Adds session management features - Added `SessionInactive` error type in `device.go` - Added session-related APIs in `router.go` to support listing and evicting sessions - Added `ListSessionsByUser`, `ListSessions`, and `MarkInactive` methods in `session.go` - Returns an appropriate error when the session state is `SessionInactive` * feat(auth): Marks the device session as invalid. - Import the `session` package into the `auth` module to handle device session status. - Add a check in the login logic. If `device_key` is obtained, call `session.MarkInactive` to mark the device session as invalid. - Store the invalid status in the context variable `session_inactive` for subsequent middleware checks. - Add a check in the session refresh logic to abort the process if the current session has been marked invalid. * feat(auth, session): Added device information processing and session management changes - Updated device handling logic in `auth.go` to pass user agent and IP information - Adjusted database queries in `session.go` to optimize session query fields and add `user_agent` and `ip` fields - Modified the `Handle` method to add `ua` and `ip` parameters to store the user agent and IP address - Added the `SessionResp` structure to return a session response containing `user_agent` and `ip` - Updated the `/admin/user/create` and `/webdav` endpoints to pass the user agent and IP address to the device handlerpull/8511/merge
							parent
							
								
									3319f6ea6a
								
							
						
					
					
						commit
						c64f899a63
					
				| 
						 | 
					@ -165,6 +165,9 @@ func InitialSettings() []model.SettingItem {
 | 
				
			||||||
		{Key: conf.ForwardDirectLinkParams, Value: "false", Type: conf.TypeBool, Group: model.GLOBAL},
 | 
							{Key: conf.ForwardDirectLinkParams, Value: "false", Type: conf.TypeBool, Group: model.GLOBAL},
 | 
				
			||||||
		{Key: conf.IgnoreDirectLinkParams, Value: "sign,alist_ts", Type: conf.TypeString, Group: model.GLOBAL},
 | 
							{Key: conf.IgnoreDirectLinkParams, Value: "sign,alist_ts", Type: conf.TypeString, Group: model.GLOBAL},
 | 
				
			||||||
		{Key: conf.WebauthnLoginEnabled, Value: "false", Type: conf.TypeBool, Group: model.GLOBAL, Flag: model.PUBLIC},
 | 
							{Key: conf.WebauthnLoginEnabled, Value: "false", Type: conf.TypeBool, Group: model.GLOBAL, Flag: model.PUBLIC},
 | 
				
			||||||
 | 
							{Key: conf.MaxDevices, Value: "0", Type: conf.TypeNumber, Group: model.GLOBAL},
 | 
				
			||||||
 | 
							{Key: conf.DeviceEvictPolicy, Value: "deny", Type: conf.TypeSelect, Options: "deny,evict_oldest", Group: model.GLOBAL},
 | 
				
			||||||
 | 
							{Key: conf.DeviceSessionTTL, Value: "86400", Type: conf.TypeNumber, Group: model.GLOBAL},
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		// single settings
 | 
							// single settings
 | 
				
			||||||
		{Key: conf.Token, Value: token, Type: conf.TypeString, Group: model.SINGLE, Flag: model.PRIVATE},
 | 
							{Key: conf.Token, Value: token, Type: conf.TypeString, Group: model.SINGLE, Flag: model.PRIVATE},
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -48,6 +48,9 @@ const (
 | 
				
			||||||
	ForwardDirectLinkParams = "forward_direct_link_params"
 | 
						ForwardDirectLinkParams = "forward_direct_link_params"
 | 
				
			||||||
	IgnoreDirectLinkParams  = "ignore_direct_link_params"
 | 
						IgnoreDirectLinkParams  = "ignore_direct_link_params"
 | 
				
			||||||
	WebauthnLoginEnabled    = "webauthn_login_enabled"
 | 
						WebauthnLoginEnabled    = "webauthn_login_enabled"
 | 
				
			||||||
 | 
						MaxDevices              = "max_devices"
 | 
				
			||||||
 | 
						DeviceEvictPolicy       = "device_evict_policy"
 | 
				
			||||||
 | 
						DeviceSessionTTL        = "device_session_ttl"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// index
 | 
						// index
 | 
				
			||||||
	SearchIndex     = "search_index"
 | 
						SearchIndex     = "search_index"
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -12,7 +12,7 @@ var db *gorm.DB
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func Init(d *gorm.DB) {
 | 
					func Init(d *gorm.DB) {
 | 
				
			||||||
	db = d
 | 
						db = d
 | 
				
			||||||
	err := AutoMigrate(new(model.Storage), new(model.User), new(model.Meta), new(model.SettingItem), new(model.SearchNode), new(model.TaskItem), new(model.SSHPublicKey), new(model.Role), new(model.Label), new(model.LabelFileBinding), new(model.ObjFile))
 | 
						err := AutoMigrate(new(model.Storage), new(model.User), new(model.Meta), new(model.SettingItem), new(model.SearchNode), new(model.TaskItem), new(model.SSHPublicKey), new(model.Role), new(model.Label), new(model.LabelFileBinding), new(model.ObjFile), new(model.Session))
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		log.Fatalf("failed migrate database: %s", err.Error())
 | 
							log.Fatalf("failed migrate database: %s", err.Error())
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,65 @@
 | 
				
			||||||
 | 
					package db
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"github.com/alist-org/alist/v3/internal/model"
 | 
				
			||||||
 | 
						"github.com/pkg/errors"
 | 
				
			||||||
 | 
						"gorm.io/gorm/clause"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func GetSession(userID uint, deviceKey string) (*model.Session, error) {
 | 
				
			||||||
 | 
						s := model.Session{UserID: userID, DeviceKey: deviceKey}
 | 
				
			||||||
 | 
						if err := db.Select("user_id, device_key, last_active, status, user_agent, ip").Where(&s).First(&s).Error; err != nil {
 | 
				
			||||||
 | 
							return nil, errors.Wrap(err, "failed find session")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return &s, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func CreateSession(s *model.Session) error {
 | 
				
			||||||
 | 
						return errors.WithStack(db.Create(s).Error)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func UpsertSession(s *model.Session) error {
 | 
				
			||||||
 | 
						return errors.WithStack(db.Clauses(clause.OnConflict{UpdateAll: true}).Create(s).Error)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func DeleteSession(userID uint, deviceKey string) error {
 | 
				
			||||||
 | 
						return errors.WithStack(db.Where("user_id = ? AND device_key = ?", userID, deviceKey).Delete(&model.Session{}).Error)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func CountSessionsByUser(userID uint) (int64, error) {
 | 
				
			||||||
 | 
						var count int64
 | 
				
			||||||
 | 
						err := db.Model(&model.Session{}).Where("user_id = ?", userID).Count(&count).Error
 | 
				
			||||||
 | 
						return count, errors.WithStack(err)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func DeleteSessionsBefore(ts int64) error {
 | 
				
			||||||
 | 
						return errors.WithStack(db.Where("last_active < ?", ts).Delete(&model.Session{}).Error)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func GetOldestSession(userID uint) (*model.Session, error) {
 | 
				
			||||||
 | 
						var s model.Session
 | 
				
			||||||
 | 
						if err := db.Where("user_id = ?", userID).Order("last_active ASC").First(&s).Error; err != nil {
 | 
				
			||||||
 | 
							return nil, errors.Wrap(err, "failed get oldest session")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return &s, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func UpdateSessionLastActive(userID uint, deviceKey string, lastActive int64) error {
 | 
				
			||||||
 | 
						return errors.WithStack(db.Model(&model.Session{}).Where("user_id = ? AND device_key = ?", userID, deviceKey).Update("last_active", lastActive).Error)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func ListSessionsByUser(userID uint) ([]model.Session, error) {
 | 
				
			||||||
 | 
						var sessions []model.Session
 | 
				
			||||||
 | 
						err := db.Select("user_id, device_key, last_active, status, user_agent, ip").Where("user_id = ? AND status = ?", userID, model.SessionActive).Find(&sessions).Error
 | 
				
			||||||
 | 
						return sessions, errors.WithStack(err)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func ListSessions() ([]model.Session, error) {
 | 
				
			||||||
 | 
						var sessions []model.Session
 | 
				
			||||||
 | 
						err := db.Select("user_id, device_key, last_active, status, user_agent, ip").Where("status = ?", model.SessionActive).Find(&sessions).Error
 | 
				
			||||||
 | 
						return sessions, errors.WithStack(err)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func MarkInactive(sessionID string) error {
 | 
				
			||||||
 | 
						return errors.WithStack(db.Model(&model.Session{}).Where("device_key = ?", sessionID).Update("status", model.SessionInactive).Error)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,67 @@
 | 
				
			||||||
 | 
					package device
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"github.com/alist-org/alist/v3/internal/conf"
 | 
				
			||||||
 | 
						"github.com/alist-org/alist/v3/internal/db"
 | 
				
			||||||
 | 
						"github.com/alist-org/alist/v3/internal/errs"
 | 
				
			||||||
 | 
						"github.com/alist-org/alist/v3/internal/model"
 | 
				
			||||||
 | 
						"github.com/alist-org/alist/v3/internal/setting"
 | 
				
			||||||
 | 
						"github.com/alist-org/alist/v3/pkg/utils"
 | 
				
			||||||
 | 
						"github.com/pkg/errors"
 | 
				
			||||||
 | 
						"gorm.io/gorm"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Handle verifies device sessions for a user and upserts current session.
 | 
				
			||||||
 | 
					func Handle(userID uint, deviceKey, ua, ip string) error {
 | 
				
			||||||
 | 
						ttl := setting.GetInt(conf.DeviceSessionTTL, 86400)
 | 
				
			||||||
 | 
						if ttl > 0 {
 | 
				
			||||||
 | 
							_ = db.DeleteSessionsBefore(time.Now().Unix() - int64(ttl))
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						ip = utils.MaskIP(ip)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						now := time.Now().Unix()
 | 
				
			||||||
 | 
						sess, err := db.GetSession(userID, deviceKey)
 | 
				
			||||||
 | 
						if err == nil {
 | 
				
			||||||
 | 
							if sess.Status == model.SessionInactive {
 | 
				
			||||||
 | 
								return errors.WithStack(errs.SessionInactive)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							sess.LastActive = now
 | 
				
			||||||
 | 
							sess.Status = model.SessionActive
 | 
				
			||||||
 | 
							sess.UserAgent = ua
 | 
				
			||||||
 | 
							sess.IP = ip
 | 
				
			||||||
 | 
							return db.UpsertSession(sess)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						max := setting.GetInt(conf.MaxDevices, 0)
 | 
				
			||||||
 | 
						if max > 0 {
 | 
				
			||||||
 | 
							count, err := db.CountSessionsByUser(userID)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								return err
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							if count >= int64(max) {
 | 
				
			||||||
 | 
								policy := setting.GetStr(conf.DeviceEvictPolicy, "deny")
 | 
				
			||||||
 | 
								if policy == "evict_oldest" {
 | 
				
			||||||
 | 
									oldest, err := db.GetOldestSession(userID)
 | 
				
			||||||
 | 
									if err == nil {
 | 
				
			||||||
 | 
										_ = db.DeleteSession(userID, oldest.DeviceKey)
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
								} else {
 | 
				
			||||||
 | 
									return errors.WithStack(errs.TooManyDevices)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						s := &model.Session{UserID: userID, DeviceKey: deviceKey, UserAgent: ua, IP: ip, LastActive: now, Status: model.SessionActive}
 | 
				
			||||||
 | 
						return db.CreateSession(s)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Refresh updates last_active for the session.
 | 
				
			||||||
 | 
					func Refresh(userID uint, deviceKey string) {
 | 
				
			||||||
 | 
						_ = db.UpdateSessionLastActive(userID, deviceKey, time.Now().Unix())
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,8 @@
 | 
				
			||||||
 | 
					package errs
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import "errors"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					var (
 | 
				
			||||||
 | 
						TooManyDevices  = errors.New("too many active devices")
 | 
				
			||||||
 | 
						SessionInactive = errors.New("session inactive")
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,16 @@
 | 
				
			||||||
 | 
					package model
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Session represents a device session of a user.
 | 
				
			||||||
 | 
					type Session struct {
 | 
				
			||||||
 | 
						UserID     uint   `json:"user_id" gorm:"index"`
 | 
				
			||||||
 | 
						DeviceKey  string `json:"device_key" gorm:"primaryKey;size:64"`
 | 
				
			||||||
 | 
						UserAgent  string `json:"user_agent" gorm:"size:255"`
 | 
				
			||||||
 | 
						IP         string `json:"ip" gorm:"size:64"`
 | 
				
			||||||
 | 
						LastActive int64  `json:"last_active"`
 | 
				
			||||||
 | 
						Status     int    `json:"status"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					const (
 | 
				
			||||||
 | 
						SessionActive = iota
 | 
				
			||||||
 | 
						SessionInactive
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,8 @@
 | 
				
			||||||
 | 
					package session
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import "github.com/alist-org/alist/v3/internal/db"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// MarkInactive marks the session with the given ID as inactive.
 | 
				
			||||||
 | 
					func MarkInactive(sessionID string) error {
 | 
				
			||||||
 | 
						return db.MarkInactive(sessionID)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,30 @@
 | 
				
			||||||
 | 
					package utils
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import "strings"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// MaskIP anonymizes middle segments of an IP address.
 | 
				
			||||||
 | 
					func MaskIP(ip string) string {
 | 
				
			||||||
 | 
						if ip == "" {
 | 
				
			||||||
 | 
							return ""
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if strings.Contains(ip, ":") {
 | 
				
			||||||
 | 
							parts := strings.Split(ip, ":")
 | 
				
			||||||
 | 
							if len(parts) > 2 {
 | 
				
			||||||
 | 
								for i := 1; i < len(parts)-1; i++ {
 | 
				
			||||||
 | 
									if parts[i] != "" {
 | 
				
			||||||
 | 
										parts[i] = "*"
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								return strings.Join(parts, ":")
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							return ip
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						parts := strings.Split(ip, ".")
 | 
				
			||||||
 | 
						if len(parts) == 4 {
 | 
				
			||||||
 | 
							for i := 1; i < len(parts)-1; i++ {
 | 
				
			||||||
 | 
								parts[i] = "*"
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							return strings.Join(parts, ".")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return ip
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
| 
						 | 
					@ -12,6 +12,7 @@ import (
 | 
				
			||||||
	"github.com/alist-org/alist/v3/internal/conf"
 | 
						"github.com/alist-org/alist/v3/internal/conf"
 | 
				
			||||||
	"github.com/alist-org/alist/v3/internal/model"
 | 
						"github.com/alist-org/alist/v3/internal/model"
 | 
				
			||||||
	"github.com/alist-org/alist/v3/internal/op"
 | 
						"github.com/alist-org/alist/v3/internal/op"
 | 
				
			||||||
 | 
						"github.com/alist-org/alist/v3/internal/session"
 | 
				
			||||||
	"github.com/alist-org/alist/v3/internal/setting"
 | 
						"github.com/alist-org/alist/v3/internal/setting"
 | 
				
			||||||
	"github.com/alist-org/alist/v3/server/common"
 | 
						"github.com/alist-org/alist/v3/server/common"
 | 
				
			||||||
	"github.com/gin-gonic/gin"
 | 
						"github.com/gin-gonic/gin"
 | 
				
			||||||
| 
						 | 
					@ -247,6 +248,13 @@ func Verify2FA(c *gin.Context) {
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func LogOut(c *gin.Context) {
 | 
					func LogOut(c *gin.Context) {
 | 
				
			||||||
 | 
						if keyVal, ok := c.Get("device_key"); ok {
 | 
				
			||||||
 | 
							if err := session.MarkInactive(keyVal.(string)); err != nil {
 | 
				
			||||||
 | 
								common.ErrorResp(c, err, 500)
 | 
				
			||||||
 | 
								return
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							c.Set("session_inactive", true)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
	err := common.InvalidateToken(c.GetHeader("Authorization"))
 | 
						err := common.InvalidateToken(c.GetHeader("Authorization"))
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		common.ErrorResp(c, err, 500)
 | 
							common.ErrorResp(c, err, 500)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,92 @@
 | 
				
			||||||
 | 
					package handles
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"github.com/alist-org/alist/v3/internal/db"
 | 
				
			||||||
 | 
						"github.com/alist-org/alist/v3/internal/model"
 | 
				
			||||||
 | 
						"github.com/alist-org/alist/v3/server/common"
 | 
				
			||||||
 | 
						"github.com/gin-gonic/gin"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type SessionResp struct {
 | 
				
			||||||
 | 
						SessionID  string `json:"session_id"`
 | 
				
			||||||
 | 
						UserID     uint   `json:"user_id,omitempty"`
 | 
				
			||||||
 | 
						LastActive int64  `json:"last_active"`
 | 
				
			||||||
 | 
						Status     int    `json:"status"`
 | 
				
			||||||
 | 
						UA         string `json:"ua"`
 | 
				
			||||||
 | 
						IP         string `json:"ip"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func ListMySessions(c *gin.Context) {
 | 
				
			||||||
 | 
						user := c.MustGet("user").(*model.User)
 | 
				
			||||||
 | 
						sessions, err := db.ListSessionsByUser(user.ID)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							common.ErrorResp(c, err, 500)
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						resp := make([]SessionResp, len(sessions))
 | 
				
			||||||
 | 
						for i, s := range sessions {
 | 
				
			||||||
 | 
							resp[i] = SessionResp{
 | 
				
			||||||
 | 
								SessionID:  s.DeviceKey,
 | 
				
			||||||
 | 
								LastActive: s.LastActive,
 | 
				
			||||||
 | 
								Status:     s.Status,
 | 
				
			||||||
 | 
								UA:         s.UserAgent,
 | 
				
			||||||
 | 
								IP:         s.IP,
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						common.SuccessResp(c, resp)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type EvictSessionReq struct {
 | 
				
			||||||
 | 
						SessionID string `json:"session_id"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func EvictMySession(c *gin.Context) {
 | 
				
			||||||
 | 
						var req EvictSessionReq
 | 
				
			||||||
 | 
						if err := c.ShouldBindJSON(&req); err != nil {
 | 
				
			||||||
 | 
							common.ErrorResp(c, err, 400)
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						user := c.MustGet("user").(*model.User)
 | 
				
			||||||
 | 
						if _, err := db.GetSession(user.ID, req.SessionID); err != nil {
 | 
				
			||||||
 | 
							common.ErrorResp(c, err, 400)
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if err := db.MarkInactive(req.SessionID); err != nil {
 | 
				
			||||||
 | 
							common.ErrorResp(c, err, 500)
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						common.SuccessResp(c)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func ListSessions(c *gin.Context) {
 | 
				
			||||||
 | 
						sessions, err := db.ListSessions()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							common.ErrorResp(c, err, 500)
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						resp := make([]SessionResp, len(sessions))
 | 
				
			||||||
 | 
						for i, s := range sessions {
 | 
				
			||||||
 | 
							resp[i] = SessionResp{
 | 
				
			||||||
 | 
								SessionID:  s.DeviceKey,
 | 
				
			||||||
 | 
								UserID:     s.UserID,
 | 
				
			||||||
 | 
								LastActive: s.LastActive,
 | 
				
			||||||
 | 
								Status:     s.Status,
 | 
				
			||||||
 | 
								UA:         s.UserAgent,
 | 
				
			||||||
 | 
								IP:         s.IP,
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						common.SuccessResp(c, resp)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func EvictSession(c *gin.Context) {
 | 
				
			||||||
 | 
						var req EvictSessionReq
 | 
				
			||||||
 | 
						if err := c.ShouldBindJSON(&req); err != nil {
 | 
				
			||||||
 | 
							common.ErrorResp(c, err, 400)
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if err := db.MarkInactive(req.SessionID); err != nil {
 | 
				
			||||||
 | 
							common.ErrorResp(c, err, 500)
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						common.SuccessResp(c)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
| 
						 | 
					@ -5,9 +5,11 @@ import (
 | 
				
			||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"github.com/alist-org/alist/v3/internal/conf"
 | 
						"github.com/alist-org/alist/v3/internal/conf"
 | 
				
			||||||
 | 
						"github.com/alist-org/alist/v3/internal/device"
 | 
				
			||||||
	"github.com/alist-org/alist/v3/internal/model"
 | 
						"github.com/alist-org/alist/v3/internal/model"
 | 
				
			||||||
	"github.com/alist-org/alist/v3/internal/op"
 | 
						"github.com/alist-org/alist/v3/internal/op"
 | 
				
			||||||
	"github.com/alist-org/alist/v3/internal/setting"
 | 
						"github.com/alist-org/alist/v3/internal/setting"
 | 
				
			||||||
 | 
						"github.com/alist-org/alist/v3/pkg/utils"
 | 
				
			||||||
	"github.com/alist-org/alist/v3/server/common"
 | 
						"github.com/alist-org/alist/v3/server/common"
 | 
				
			||||||
	"github.com/gin-gonic/gin"
 | 
						"github.com/gin-gonic/gin"
 | 
				
			||||||
	log "github.com/sirupsen/logrus"
 | 
						log "github.com/sirupsen/logrus"
 | 
				
			||||||
| 
						 | 
					@ -24,7 +26,9 @@ func Auth(c *gin.Context) {
 | 
				
			||||||
			c.Abort()
 | 
								c.Abort()
 | 
				
			||||||
			return
 | 
								return
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		c.Set("user", admin)
 | 
							if !handleSession(c, admin) {
 | 
				
			||||||
 | 
								return
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
		log.Debugf("use admin token: %+v", admin)
 | 
							log.Debugf("use admin token: %+v", admin)
 | 
				
			||||||
		c.Next()
 | 
							c.Next()
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
| 
						 | 
					@ -50,7 +54,9 @@ func Auth(c *gin.Context) {
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
			guest.RolesDetail = roles
 | 
								guest.RolesDetail = roles
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		c.Set("user", guest)
 | 
							if !handleSession(c, guest) {
 | 
				
			||||||
 | 
								return
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
		log.Debugf("use empty token: %+v", guest)
 | 
							log.Debugf("use empty token: %+v", guest)
 | 
				
			||||||
		c.Next()
 | 
							c.Next()
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
| 
						 | 
					@ -87,11 +93,29 @@ func Auth(c *gin.Context) {
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		user.RolesDetail = roles
 | 
							user.RolesDetail = roles
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	c.Set("user", user)
 | 
						if !handleSession(c, user) {
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
	log.Debugf("use login token: %+v", user)
 | 
						log.Debugf("use login token: %+v", user)
 | 
				
			||||||
	c.Next()
 | 
						c.Next()
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func handleSession(c *gin.Context, user *model.User) bool {
 | 
				
			||||||
 | 
						clientID := c.GetHeader("Client-Id")
 | 
				
			||||||
 | 
						if clientID == "" {
 | 
				
			||||||
 | 
							clientID = c.Query("client_id")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						key := utils.GetMD5EncodeStr(fmt.Sprintf("%d-%s-%s-%s", user.ID, c.Request.UserAgent(), c.ClientIP(), clientID))
 | 
				
			||||||
 | 
						if err := device.Handle(user.ID, key, c.Request.UserAgent(), c.ClientIP()); err != nil {
 | 
				
			||||||
 | 
							common.ErrorResp(c, err, 403)
 | 
				
			||||||
 | 
							c.Abort()
 | 
				
			||||||
 | 
							return false
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						c.Set("device_key", key)
 | 
				
			||||||
 | 
						c.Set("user", user)
 | 
				
			||||||
 | 
						return true
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func Authn(c *gin.Context) {
 | 
					func Authn(c *gin.Context) {
 | 
				
			||||||
	token := c.GetHeader("Authorization")
 | 
						token := c.GetHeader("Authorization")
 | 
				
			||||||
	if subtle.ConstantTimeCompare([]byte(token), []byte(setting.GetStr(conf.Token))) == 1 {
 | 
						if subtle.ConstantTimeCompare([]byte(token), []byte(setting.GetStr(conf.Token))) == 1 {
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,26 @@
 | 
				
			||||||
 | 
					package middlewares
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"github.com/alist-org/alist/v3/internal/device"
 | 
				
			||||||
 | 
						"github.com/alist-org/alist/v3/internal/model"
 | 
				
			||||||
 | 
						"github.com/gin-gonic/gin"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// SessionRefresh updates session's last_active after successful requests.
 | 
				
			||||||
 | 
					func SessionRefresh(c *gin.Context) {
 | 
				
			||||||
 | 
						c.Next()
 | 
				
			||||||
 | 
						if c.Writer.Status() >= 400 {
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if inactive, ok := c.Get("session_inactive"); ok {
 | 
				
			||||||
 | 
							if b, ok := inactive.(bool); ok && b {
 | 
				
			||||||
 | 
								return
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						userVal, uok := c.Get("user")
 | 
				
			||||||
 | 
						keyVal, kok := c.Get("device_key")
 | 
				
			||||||
 | 
						if uok && kok {
 | 
				
			||||||
 | 
							user := userVal.(*model.User)
 | 
				
			||||||
 | 
							device.Refresh(user.ID, keyVal.(string))
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
| 
						 | 
					@ -22,6 +22,7 @@ func Init(e *gin.Engine) {
 | 
				
			||||||
		})
 | 
							})
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	Cors(e)
 | 
						Cors(e)
 | 
				
			||||||
 | 
						e.Use(middlewares.SessionRefresh)
 | 
				
			||||||
	g := e.Group(conf.URL.Path)
 | 
						g := e.Group(conf.URL.Path)
 | 
				
			||||||
	if conf.Conf.Scheme.HttpPort != -1 && conf.Conf.Scheme.HttpsPort != -1 && conf.Conf.Scheme.ForceHttps {
 | 
						if conf.Conf.Scheme.HttpPort != -1 && conf.Conf.Scheme.HttpsPort != -1 && conf.Conf.Scheme.ForceHttps {
 | 
				
			||||||
		e.Use(middlewares.ForceHttps)
 | 
							e.Use(middlewares.ForceHttps)
 | 
				
			||||||
| 
						 | 
					@ -70,6 +71,8 @@ func Init(e *gin.Engine) {
 | 
				
			||||||
	auth.POST("/auth/2fa/generate", handles.Generate2FA)
 | 
						auth.POST("/auth/2fa/generate", handles.Generate2FA)
 | 
				
			||||||
	auth.POST("/auth/2fa/verify", handles.Verify2FA)
 | 
						auth.POST("/auth/2fa/verify", handles.Verify2FA)
 | 
				
			||||||
	auth.GET("/auth/logout", handles.LogOut)
 | 
						auth.GET("/auth/logout", handles.LogOut)
 | 
				
			||||||
 | 
						auth.GET("/me/sessions", handles.ListMySessions)
 | 
				
			||||||
 | 
						auth.POST("/me/sessions/evict", handles.EvictMySession)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// auth
 | 
						// auth
 | 
				
			||||||
	api.GET("/auth/sso", handles.SSOLoginRedirect)
 | 
						api.GET("/auth/sso", handles.SSOLoginRedirect)
 | 
				
			||||||
| 
						 | 
					@ -184,6 +187,10 @@ func admin(g *gin.RouterGroup) {
 | 
				
			||||||
	labelFileBinding.POST("/delete", handles.DelLabelByFileName)
 | 
						labelFileBinding.POST("/delete", handles.DelLabelByFileName)
 | 
				
			||||||
	labelFileBinding.POST("/restore", handles.RestoreLabelFileBinding)
 | 
						labelFileBinding.POST("/restore", handles.RestoreLabelFileBinding)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						session := g.Group("/session")
 | 
				
			||||||
 | 
						session.GET("/list", handles.ListSessions)
 | 
				
			||||||
 | 
						session.POST("/evict", handles.EvictSession)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func _fs(g *gin.RouterGroup) {
 | 
					func _fs(g *gin.RouterGroup) {
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -3,6 +3,7 @@ package server
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"context"
 | 
						"context"
 | 
				
			||||||
	"crypto/subtle"
 | 
						"crypto/subtle"
 | 
				
			||||||
 | 
						"fmt"
 | 
				
			||||||
	"net/http"
 | 
						"net/http"
 | 
				
			||||||
	"net/url"
 | 
						"net/url"
 | 
				
			||||||
	"path"
 | 
						"path"
 | 
				
			||||||
| 
						 | 
					@ -12,9 +13,11 @@ import (
 | 
				
			||||||
	"github.com/alist-org/alist/v3/server/middlewares"
 | 
						"github.com/alist-org/alist/v3/server/middlewares"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"github.com/alist-org/alist/v3/internal/conf"
 | 
						"github.com/alist-org/alist/v3/internal/conf"
 | 
				
			||||||
 | 
						"github.com/alist-org/alist/v3/internal/device"
 | 
				
			||||||
	"github.com/alist-org/alist/v3/internal/model"
 | 
						"github.com/alist-org/alist/v3/internal/model"
 | 
				
			||||||
	"github.com/alist-org/alist/v3/internal/op"
 | 
						"github.com/alist-org/alist/v3/internal/op"
 | 
				
			||||||
	"github.com/alist-org/alist/v3/internal/setting"
 | 
						"github.com/alist-org/alist/v3/internal/setting"
 | 
				
			||||||
 | 
						"github.com/alist-org/alist/v3/pkg/utils"
 | 
				
			||||||
	"github.com/alist-org/alist/v3/server/common"
 | 
						"github.com/alist-org/alist/v3/server/common"
 | 
				
			||||||
	"github.com/alist-org/alist/v3/server/webdav"
 | 
						"github.com/alist-org/alist/v3/server/webdav"
 | 
				
			||||||
	"github.com/gin-gonic/gin"
 | 
						"github.com/gin-gonic/gin"
 | 
				
			||||||
| 
						 | 
					@ -69,6 +72,13 @@ func WebDAVAuth(c *gin.Context) {
 | 
				
			||||||
					c.Abort()
 | 
										c.Abort()
 | 
				
			||||||
					return
 | 
										return
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
 | 
									key := utils.GetMD5EncodeStr(fmt.Sprintf("%d-%s", admin.ID, c.ClientIP()))
 | 
				
			||||||
 | 
									if err := device.Handle(admin.ID, key, c.Request.UserAgent(), c.ClientIP()); err != nil {
 | 
				
			||||||
 | 
										c.Status(http.StatusForbidden)
 | 
				
			||||||
 | 
										c.Abort()
 | 
				
			||||||
 | 
										return
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
									c.Set("device_key", key)
 | 
				
			||||||
				c.Set("user", admin)
 | 
									c.Set("user", admin)
 | 
				
			||||||
				c.Next()
 | 
									c.Next()
 | 
				
			||||||
				return
 | 
									return
 | 
				
			||||||
| 
						 | 
					@ -146,6 +156,13 @@ func WebDAVAuth(c *gin.Context) {
 | 
				
			||||||
		c.Abort()
 | 
							c.Abort()
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
						key := utils.GetMD5EncodeStr(fmt.Sprintf("%d-%s", user.ID, c.ClientIP()))
 | 
				
			||||||
 | 
						if err := device.Handle(user.ID, key, c.Request.UserAgent(), c.ClientIP()); err != nil {
 | 
				
			||||||
 | 
							c.Status(http.StatusForbidden)
 | 
				
			||||||
 | 
							c.Abort()
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						c.Set("device_key", key)
 | 
				
			||||||
	c.Set("user", user)
 | 
						c.Set("user", user)
 | 
				
			||||||
	c.Next()
 | 
						c.Next()
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue