feat(auth, session): Added device information processing and session management changes

- Updated device handling logic in `auth.go` to pass user agent and IP information
- Adjusted database queries in `session.go` to optimize session query fields and add `user_agent` and `ip` fields
- Modified the `Handle` method to add `ua` and `ip` parameters to store the user agent and IP address
- Added the `SessionResp` structure to return a session response containing `user_agent` and `ip`
- Updated the `/admin/user/create` and `/webdav` endpoints to pass the user agent and IP address to the device handler
pull/9286/head
okatu-loli 2025-08-23 21:03:30 +08:00
parent 51613cf110
commit f19a9986e0
7 changed files with 77 additions and 10 deletions

View File

@ -8,7 +8,7 @@ import (
func GetSession(userID uint, deviceKey string) (*model.Session, error) {
s := model.Session{UserID: userID, DeviceKey: deviceKey}
if err := db.Where(&s).First(&s).Error; err != nil {
if err := db.Select("user_id, device_key, last_active, status, user_agent, ip").Where(&s).First(&s).Error; err != nil {
return nil, errors.Wrap(err, "failed find session")
}
return &s, nil
@ -50,13 +50,13 @@ func UpdateSessionLastActive(userID uint, deviceKey string, lastActive int64) er
func ListSessionsByUser(userID uint) ([]model.Session, error) {
var sessions []model.Session
err := db.Where("user_id = ? AND status = ?", userID, model.SessionActive).Find(&sessions).Error
err := db.Select("user_id, device_key, last_active, status, user_agent, ip").Where("user_id = ? AND status = ?", userID, model.SessionActive).Find(&sessions).Error
return sessions, errors.WithStack(err)
}
func ListSessions() ([]model.Session, error) {
var sessions []model.Session
err := db.Where("status = ?", model.SessionActive).Find(&sessions).Error
err := db.Select("user_id, device_key, last_active, status, user_agent, ip").Where("status = ?", model.SessionActive).Find(&sessions).Error
return sessions, errors.WithStack(err)
}

View File

@ -8,17 +8,20 @@ import (
"github.com/alist-org/alist/v3/internal/errs"
"github.com/alist-org/alist/v3/internal/model"
"github.com/alist-org/alist/v3/internal/setting"
"github.com/alist-org/alist/v3/pkg/utils"
"github.com/pkg/errors"
"gorm.io/gorm"
)
// Handle verifies device sessions for a user and upserts current session.
func Handle(userID uint, deviceKey string) error {
func Handle(userID uint, deviceKey, ua, ip string) error {
ttl := setting.GetInt(conf.DeviceSessionTTL, 86400)
if ttl > 0 {
_ = db.DeleteSessionsBefore(time.Now().Unix() - int64(ttl))
}
ip = utils.MaskIP(ip)
now := time.Now().Unix()
sess, err := db.GetSession(userID, deviceKey)
if err == nil {
@ -27,6 +30,8 @@ func Handle(userID uint, deviceKey string) error {
}
sess.LastActive = now
sess.Status = model.SessionActive
sess.UserAgent = ua
sess.IP = ip
return db.UpsertSession(sess)
}
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
@ -52,7 +57,7 @@ func Handle(userID uint, deviceKey string) error {
}
}
s := &model.Session{UserID: userID, DeviceKey: deviceKey, LastActive: now, Status: model.SessionActive}
s := &model.Session{UserID: userID, DeviceKey: deviceKey, UserAgent: ua, IP: ip, LastActive: now, Status: model.SessionActive}
return db.CreateSession(s)
}

View File

@ -4,6 +4,8 @@ package model
type Session struct {
UserID uint `json:"user_id" gorm:"index"`
DeviceKey string `json:"device_key" gorm:"primaryKey;size:64"`
UserAgent string `json:"user_agent" gorm:"size:255"`
IP string `json:"ip" gorm:"size:64"`
LastActive int64 `json:"last_active"`
Status int `json:"status"`
}

30
pkg/utils/mask.go Normal file
View File

@ -0,0 +1,30 @@
package utils
import "strings"
// MaskIP anonymizes middle segments of an IP address.
func MaskIP(ip string) string {
if ip == "" {
return ""
}
if strings.Contains(ip, ":") {
parts := strings.Split(ip, ":")
if len(parts) > 2 {
for i := 1; i < len(parts)-1; i++ {
if parts[i] != "" {
parts[i] = "*"
}
}
return strings.Join(parts, ":")
}
return ip
}
parts := strings.Split(ip, ".")
if len(parts) == 4 {
for i := 1; i < len(parts)-1; i++ {
parts[i] = "*"
}
return strings.Join(parts, ".")
}
return ip
}

View File

@ -7,6 +7,15 @@ import (
"github.com/gin-gonic/gin"
)
type SessionResp struct {
SessionID string `json:"session_id"`
UserID uint `json:"user_id,omitempty"`
LastActive int64 `json:"last_active"`
Status int `json:"status"`
UA string `json:"ua"`
IP string `json:"ip"`
}
func ListMySessions(c *gin.Context) {
user := c.MustGet("user").(*model.User)
sessions, err := db.ListSessionsByUser(user.ID)
@ -14,7 +23,17 @@ func ListMySessions(c *gin.Context) {
common.ErrorResp(c, err, 500)
return
}
common.SuccessResp(c, sessions)
resp := make([]SessionResp, len(sessions))
for i, s := range sessions {
resp[i] = SessionResp{
SessionID: s.DeviceKey,
LastActive: s.LastActive,
Status: s.Status,
UA: s.UserAgent,
IP: s.IP,
}
}
common.SuccessResp(c, resp)
}
type EvictSessionReq struct {
@ -45,7 +64,18 @@ func ListSessions(c *gin.Context) {
common.ErrorResp(c, err, 500)
return
}
common.SuccessResp(c, sessions)
resp := make([]SessionResp, len(sessions))
for i, s := range sessions {
resp[i] = SessionResp{
SessionID: s.DeviceKey,
UserID: s.UserID,
LastActive: s.LastActive,
Status: s.Status,
UA: s.UserAgent,
IP: s.IP,
}
}
common.SuccessResp(c, resp)
}
func EvictSession(c *gin.Context) {

View File

@ -106,7 +106,7 @@ func handleSession(c *gin.Context, user *model.User) bool {
clientID = c.Query("client_id")
}
key := utils.GetMD5EncodeStr(fmt.Sprintf("%d-%s-%s-%s", user.ID, c.Request.UserAgent(), c.ClientIP(), clientID))
if err := device.Handle(user.ID, key); err != nil {
if err := device.Handle(user.ID, key, c.Request.UserAgent(), c.ClientIP()); err != nil {
common.ErrorResp(c, err, 403)
c.Abort()
return false

View File

@ -73,7 +73,7 @@ func WebDAVAuth(c *gin.Context) {
return
}
key := utils.GetMD5EncodeStr(fmt.Sprintf("%d-%s", admin.ID, c.ClientIP()))
if err := device.Handle(admin.ID, key); err != nil {
if err := device.Handle(admin.ID, key, c.Request.UserAgent(), c.ClientIP()); err != nil {
c.Status(http.StatusForbidden)
c.Abort()
return
@ -157,7 +157,7 @@ func WebDAVAuth(c *gin.Context) {
return
}
key := utils.GetMD5EncodeStr(fmt.Sprintf("%d-%s", user.ID, c.ClientIP()))
if err := device.Handle(user.ID, key); err != nil {
if err := device.Handle(user.ID, key, c.Request.UserAgent(), c.ClientIP()); err != nil {
c.Status(http.StatusForbidden)
c.Abort()
return