diff --git a/internal/db/session.go b/internal/db/session.go index ef35d3bd..8db9fa69 100644 --- a/internal/db/session.go +++ b/internal/db/session.go @@ -8,7 +8,7 @@ import ( func GetSession(userID uint, deviceKey string) (*model.Session, error) { s := model.Session{UserID: userID, DeviceKey: deviceKey} - if err := db.Where(&s).First(&s).Error; err != nil { + if err := db.Select("user_id, device_key, last_active, status, user_agent, ip").Where(&s).First(&s).Error; err != nil { return nil, errors.Wrap(err, "failed find session") } return &s, nil @@ -50,13 +50,13 @@ func UpdateSessionLastActive(userID uint, deviceKey string, lastActive int64) er func ListSessionsByUser(userID uint) ([]model.Session, error) { var sessions []model.Session - err := db.Where("user_id = ? AND status = ?", userID, model.SessionActive).Find(&sessions).Error + err := db.Select("user_id, device_key, last_active, status, user_agent, ip").Where("user_id = ? AND status = ?", userID, model.SessionActive).Find(&sessions).Error return sessions, errors.WithStack(err) } func ListSessions() ([]model.Session, error) { var sessions []model.Session - err := db.Where("status = ?", model.SessionActive).Find(&sessions).Error + err := db.Select("user_id, device_key, last_active, status, user_agent, ip").Where("status = ?", model.SessionActive).Find(&sessions).Error return sessions, errors.WithStack(err) } diff --git a/internal/device/session.go b/internal/device/session.go index d4f0172c..d407c858 100644 --- a/internal/device/session.go +++ b/internal/device/session.go @@ -8,17 +8,20 @@ import ( "github.com/alist-org/alist/v3/internal/errs" "github.com/alist-org/alist/v3/internal/model" "github.com/alist-org/alist/v3/internal/setting" + "github.com/alist-org/alist/v3/pkg/utils" "github.com/pkg/errors" "gorm.io/gorm" ) // Handle verifies device sessions for a user and upserts current session. -func Handle(userID uint, deviceKey string) error { +func Handle(userID uint, deviceKey, ua, ip string) error { ttl := setting.GetInt(conf.DeviceSessionTTL, 86400) if ttl > 0 { _ = db.DeleteSessionsBefore(time.Now().Unix() - int64(ttl)) } + ip = utils.MaskIP(ip) + now := time.Now().Unix() sess, err := db.GetSession(userID, deviceKey) if err == nil { @@ -27,6 +30,8 @@ func Handle(userID uint, deviceKey string) error { } sess.LastActive = now sess.Status = model.SessionActive + sess.UserAgent = ua + sess.IP = ip return db.UpsertSession(sess) } if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { @@ -52,7 +57,7 @@ func Handle(userID uint, deviceKey string) error { } } - s := &model.Session{UserID: userID, DeviceKey: deviceKey, LastActive: now, Status: model.SessionActive} + s := &model.Session{UserID: userID, DeviceKey: deviceKey, UserAgent: ua, IP: ip, LastActive: now, Status: model.SessionActive} return db.CreateSession(s) } diff --git a/internal/model/session.go b/internal/model/session.go index de89845a..3cb6d0da 100644 --- a/internal/model/session.go +++ b/internal/model/session.go @@ -4,6 +4,8 @@ package model type Session struct { UserID uint `json:"user_id" gorm:"index"` DeviceKey string `json:"device_key" gorm:"primaryKey;size:64"` + UserAgent string `json:"user_agent" gorm:"size:255"` + IP string `json:"ip" gorm:"size:64"` LastActive int64 `json:"last_active"` Status int `json:"status"` } diff --git a/pkg/utils/mask.go b/pkg/utils/mask.go new file mode 100644 index 00000000..1513ad40 --- /dev/null +++ b/pkg/utils/mask.go @@ -0,0 +1,30 @@ +package utils + +import "strings" + +// MaskIP anonymizes middle segments of an IP address. +func MaskIP(ip string) string { + if ip == "" { + return "" + } + if strings.Contains(ip, ":") { + parts := strings.Split(ip, ":") + if len(parts) > 2 { + for i := 1; i < len(parts)-1; i++ { + if parts[i] != "" { + parts[i] = "*" + } + } + return strings.Join(parts, ":") + } + return ip + } + parts := strings.Split(ip, ".") + if len(parts) == 4 { + for i := 1; i < len(parts)-1; i++ { + parts[i] = "*" + } + return strings.Join(parts, ".") + } + return ip +} diff --git a/server/handles/session.go b/server/handles/session.go index c3dd833e..886be66a 100644 --- a/server/handles/session.go +++ b/server/handles/session.go @@ -7,6 +7,15 @@ import ( "github.com/gin-gonic/gin" ) +type SessionResp struct { + SessionID string `json:"session_id"` + UserID uint `json:"user_id,omitempty"` + LastActive int64 `json:"last_active"` + Status int `json:"status"` + UA string `json:"ua"` + IP string `json:"ip"` +} + func ListMySessions(c *gin.Context) { user := c.MustGet("user").(*model.User) sessions, err := db.ListSessionsByUser(user.ID) @@ -14,7 +23,17 @@ func ListMySessions(c *gin.Context) { common.ErrorResp(c, err, 500) return } - common.SuccessResp(c, sessions) + resp := make([]SessionResp, len(sessions)) + for i, s := range sessions { + resp[i] = SessionResp{ + SessionID: s.DeviceKey, + LastActive: s.LastActive, + Status: s.Status, + UA: s.UserAgent, + IP: s.IP, + } + } + common.SuccessResp(c, resp) } type EvictSessionReq struct { @@ -45,7 +64,18 @@ func ListSessions(c *gin.Context) { common.ErrorResp(c, err, 500) return } - common.SuccessResp(c, sessions) + resp := make([]SessionResp, len(sessions)) + for i, s := range sessions { + resp[i] = SessionResp{ + SessionID: s.DeviceKey, + UserID: s.UserID, + LastActive: s.LastActive, + Status: s.Status, + UA: s.UserAgent, + IP: s.IP, + } + } + common.SuccessResp(c, resp) } func EvictSession(c *gin.Context) { diff --git a/server/middlewares/auth.go b/server/middlewares/auth.go index 7e93c950..72eaefe6 100644 --- a/server/middlewares/auth.go +++ b/server/middlewares/auth.go @@ -106,7 +106,7 @@ func handleSession(c *gin.Context, user *model.User) bool { clientID = c.Query("client_id") } key := utils.GetMD5EncodeStr(fmt.Sprintf("%d-%s-%s-%s", user.ID, c.Request.UserAgent(), c.ClientIP(), clientID)) - if err := device.Handle(user.ID, key); err != nil { + if err := device.Handle(user.ID, key, c.Request.UserAgent(), c.ClientIP()); err != nil { common.ErrorResp(c, err, 403) c.Abort() return false diff --git a/server/webdav.go b/server/webdav.go index 932c039c..e0980139 100644 --- a/server/webdav.go +++ b/server/webdav.go @@ -73,7 +73,7 @@ func WebDAVAuth(c *gin.Context) { return } key := utils.GetMD5EncodeStr(fmt.Sprintf("%d-%s", admin.ID, c.ClientIP())) - if err := device.Handle(admin.ID, key); err != nil { + if err := device.Handle(admin.ID, key, c.Request.UserAgent(), c.ClientIP()); err != nil { c.Status(http.StatusForbidden) c.Abort() return @@ -157,7 +157,7 @@ func WebDAVAuth(c *gin.Context) { return } key := utils.GetMD5EncodeStr(fmt.Sprintf("%d-%s", user.ID, c.ClientIP())) - if err := device.Handle(user.ID, key); err != nil { + if err := device.Handle(user.ID, key, c.Request.UserAgent(), c.ClientIP()); err != nil { c.Status(http.StatusForbidden) c.Abort() return