diff --git a/internal/bootstrap/data/setting.go b/internal/bootstrap/data/setting.go index 17a63af2..bbb633e3 100644 --- a/internal/bootstrap/data/setting.go +++ b/internal/bootstrap/data/setting.go @@ -165,6 +165,9 @@ func InitialSettings() []model.SettingItem { {Key: conf.ForwardDirectLinkParams, Value: "false", Type: conf.TypeBool, Group: model.GLOBAL}, {Key: conf.IgnoreDirectLinkParams, Value: "sign,alist_ts", Type: conf.TypeString, Group: model.GLOBAL}, {Key: conf.WebauthnLoginEnabled, Value: "false", Type: conf.TypeBool, Group: model.GLOBAL, Flag: model.PUBLIC}, + {Key: conf.MaxDevices, Value: "0", Type: conf.TypeNumber, Group: model.GLOBAL}, + {Key: conf.DeviceEvictPolicy, Value: "deny", Type: conf.TypeSelect, Options: "deny,evict_oldest", Group: model.GLOBAL}, + {Key: conf.DeviceSessionTTL, Value: "86400", Type: conf.TypeNumber, Group: model.GLOBAL}, // single settings {Key: conf.Token, Value: token, Type: conf.TypeString, Group: model.SINGLE, Flag: model.PRIVATE}, diff --git a/internal/conf/const.go b/internal/conf/const.go index 0bf0cd67..1a558163 100644 --- a/internal/conf/const.go +++ b/internal/conf/const.go @@ -48,6 +48,9 @@ const ( ForwardDirectLinkParams = "forward_direct_link_params" IgnoreDirectLinkParams = "ignore_direct_link_params" WebauthnLoginEnabled = "webauthn_login_enabled" + MaxDevices = "max_devices" + DeviceEvictPolicy = "device_evict_policy" + DeviceSessionTTL = "device_session_ttl" // index SearchIndex = "search_index" diff --git a/internal/db/db.go b/internal/db/db.go index 0d8ab421..4577059d 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -12,7 +12,7 @@ var db *gorm.DB func Init(d *gorm.DB) { db = d - err := AutoMigrate(new(model.Storage), new(model.User), new(model.Meta), new(model.SettingItem), new(model.SearchNode), new(model.TaskItem), new(model.SSHPublicKey), new(model.Role), new(model.Label), new(model.LabelFileBinding), new(model.ObjFile)) + err := AutoMigrate(new(model.Storage), new(model.User), new(model.Meta), new(model.SettingItem), new(model.SearchNode), new(model.TaskItem), new(model.SSHPublicKey), new(model.Role), new(model.Label), new(model.LabelFileBinding), new(model.ObjFile), new(model.Session)) if err != nil { log.Fatalf("failed migrate database: %s", err.Error()) } diff --git a/internal/db/session.go b/internal/db/session.go new file mode 100644 index 00000000..8b0f7333 --- /dev/null +++ b/internal/db/session.go @@ -0,0 +1,49 @@ +package db + +import ( + "github.com/alist-org/alist/v3/internal/model" + "github.com/pkg/errors" + "gorm.io/gorm/clause" +) + +func GetSession(userID uint, deviceKey string) (*model.Session, error) { + s := model.Session{UserID: userID, DeviceKey: deviceKey} + if err := db.Where(&s).First(&s).Error; err != nil { + return nil, errors.Wrap(err, "failed find session") + } + return &s, nil +} + +func CreateSession(s *model.Session) error { + return errors.WithStack(db.Create(s).Error) +} + +func UpsertSession(s *model.Session) error { + return errors.WithStack(db.Clauses(clause.OnConflict{UpdateAll: true}).Create(s).Error) +} + +func DeleteSession(userID uint, deviceKey string) error { + return errors.WithStack(db.Where("user_id = ? AND device_key = ?", userID, deviceKey).Delete(&model.Session{}).Error) +} + +func CountSessionsByUser(userID uint) (int64, error) { + var count int64 + err := db.Model(&model.Session{}).Where("user_id = ?", userID).Count(&count).Error + return count, errors.WithStack(err) +} + +func DeleteSessionsBefore(ts int64) error { + return errors.WithStack(db.Where("last_active < ?", ts).Delete(&model.Session{}).Error) +} + +func GetOldestSession(userID uint) (*model.Session, error) { + var s model.Session + if err := db.Where("user_id = ?", userID).Order("last_active ASC").First(&s).Error; err != nil { + return nil, errors.Wrap(err, "failed get oldest session") + } + return &s, nil +} + +func UpdateSessionLastActive(userID uint, deviceKey string, lastActive int64) error { + return errors.WithStack(db.Model(&model.Session{}).Where("user_id = ? AND device_key = ?", userID, deviceKey).Update("last_active", lastActive).Error) +} diff --git a/internal/device/session.go b/internal/device/session.go new file mode 100644 index 00000000..f608e735 --- /dev/null +++ b/internal/device/session.go @@ -0,0 +1,59 @@ +package device + +import ( + "time" + + "github.com/alist-org/alist/v3/internal/conf" + "github.com/alist-org/alist/v3/internal/db" + "github.com/alist-org/alist/v3/internal/errs" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/internal/setting" + "github.com/pkg/errors" + "gorm.io/gorm" +) + +// Handle verifies device sessions for a user and upserts current session. +func Handle(userID uint, deviceKey string) error { + ttl := setting.GetInt(conf.DeviceSessionTTL, 86400) + if ttl > 0 { + _ = db.DeleteSessionsBefore(time.Now().Unix() - int64(ttl)) + } + + now := time.Now().Unix() + sess, err := db.GetSession(userID, deviceKey) + if err == nil { + sess.LastActive = now + sess.Status = model.SessionActive + return db.UpsertSession(sess) + } + if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { + return err + } + + max := setting.GetInt(conf.MaxDevices, 0) + if max > 0 { + count, err := db.CountSessionsByUser(userID) + if err != nil { + return err + } + if count >= int64(max) { + policy := setting.GetStr(conf.DeviceEvictPolicy, "deny") + if policy == "evict_oldest" { + oldest, err := db.GetOldestSession(userID) + if err == nil { + _ = db.DeleteSession(userID, oldest.DeviceKey) + } + } else { + return errors.WithStack(errs.TooManyDevices) + } + } + } + + s := &model.Session{UserID: userID, DeviceKey: deviceKey, LastActive: now, Status: model.SessionActive} + return db.CreateSession(s) +} + +// Refresh updates last_active for the session. +func Refresh(userID uint, deviceKey string) { + _ = db.UpdateSessionLastActive(userID, deviceKey, time.Now().Unix()) +} diff --git a/internal/errs/device.go b/internal/errs/device.go new file mode 100644 index 00000000..9d3bd744 --- /dev/null +++ b/internal/errs/device.go @@ -0,0 +1,7 @@ +package errs + +import "errors" + +var ( + TooManyDevices = errors.New("too many active devices") +) diff --git a/internal/model/session.go b/internal/model/session.go new file mode 100644 index 00000000..de89845a --- /dev/null +++ b/internal/model/session.go @@ -0,0 +1,14 @@ +package model + +// Session represents a device session of a user. +type Session struct { + UserID uint `json:"user_id" gorm:"index"` + DeviceKey string `json:"device_key" gorm:"primaryKey;size:64"` + LastActive int64 `json:"last_active"` + Status int `json:"status"` +} + +const ( + SessionActive = iota + SessionInactive +) diff --git a/server/middlewares/auth.go b/server/middlewares/auth.go index c0743c9c..7e93c950 100644 --- a/server/middlewares/auth.go +++ b/server/middlewares/auth.go @@ -5,9 +5,11 @@ import ( "fmt" "github.com/alist-org/alist/v3/internal/conf" + "github.com/alist-org/alist/v3/internal/device" "github.com/alist-org/alist/v3/internal/model" "github.com/alist-org/alist/v3/internal/op" "github.com/alist-org/alist/v3/internal/setting" + "github.com/alist-org/alist/v3/pkg/utils" "github.com/alist-org/alist/v3/server/common" "github.com/gin-gonic/gin" log "github.com/sirupsen/logrus" @@ -24,7 +26,9 @@ func Auth(c *gin.Context) { c.Abort() return } - c.Set("user", admin) + if !handleSession(c, admin) { + return + } log.Debugf("use admin token: %+v", admin) c.Next() return @@ -50,7 +54,9 @@ func Auth(c *gin.Context) { } guest.RolesDetail = roles } - c.Set("user", guest) + if !handleSession(c, guest) { + return + } log.Debugf("use empty token: %+v", guest) c.Next() return @@ -87,11 +93,29 @@ func Auth(c *gin.Context) { } user.RolesDetail = roles } - c.Set("user", user) + if !handleSession(c, user) { + return + } log.Debugf("use login token: %+v", user) c.Next() } +func handleSession(c *gin.Context, user *model.User) bool { + clientID := c.GetHeader("Client-Id") + if clientID == "" { + clientID = c.Query("client_id") + } + key := utils.GetMD5EncodeStr(fmt.Sprintf("%d-%s-%s-%s", user.ID, c.Request.UserAgent(), c.ClientIP(), clientID)) + if err := device.Handle(user.ID, key); err != nil { + common.ErrorResp(c, err, 403) + c.Abort() + return false + } + c.Set("device_key", key) + c.Set("user", user) + return true +} + func Authn(c *gin.Context) { token := c.GetHeader("Authorization") if subtle.ConstantTimeCompare([]byte(token), []byte(setting.GetStr(conf.Token))) == 1 { diff --git a/server/middlewares/session_refresh.go b/server/middlewares/session_refresh.go new file mode 100644 index 00000000..cedbc4ab --- /dev/null +++ b/server/middlewares/session_refresh.go @@ -0,0 +1,21 @@ +package middlewares + +import ( + "github.com/alist-org/alist/v3/internal/device" + "github.com/alist-org/alist/v3/internal/model" + "github.com/gin-gonic/gin" +) + +// SessionRefresh updates session's last_active after successful requests. +func SessionRefresh(c *gin.Context) { + c.Next() + if c.Writer.Status() >= 400 { + return + } + userVal, uok := c.Get("user") + keyVal, kok := c.Get("device_key") + if uok && kok { + user := userVal.(*model.User) + device.Refresh(user.ID, keyVal.(string)) + } +} diff --git a/server/router.go b/server/router.go index e8902f7a..bb239ed0 100644 --- a/server/router.go +++ b/server/router.go @@ -22,6 +22,7 @@ func Init(e *gin.Engine) { }) } Cors(e) + e.Use(middlewares.SessionRefresh) g := e.Group(conf.URL.Path) if conf.Conf.Scheme.HttpPort != -1 && conf.Conf.Scheme.HttpsPort != -1 && conf.Conf.Scheme.ForceHttps { e.Use(middlewares.ForceHttps) diff --git a/server/webdav.go b/server/webdav.go index 582c469d..932c039c 100644 --- a/server/webdav.go +++ b/server/webdav.go @@ -3,6 +3,7 @@ package server import ( "context" "crypto/subtle" + "fmt" "net/http" "net/url" "path" @@ -12,9 +13,11 @@ import ( "github.com/alist-org/alist/v3/server/middlewares" "github.com/alist-org/alist/v3/internal/conf" + "github.com/alist-org/alist/v3/internal/device" "github.com/alist-org/alist/v3/internal/model" "github.com/alist-org/alist/v3/internal/op" "github.com/alist-org/alist/v3/internal/setting" + "github.com/alist-org/alist/v3/pkg/utils" "github.com/alist-org/alist/v3/server/common" "github.com/alist-org/alist/v3/server/webdav" "github.com/gin-gonic/gin" @@ -69,6 +72,13 @@ func WebDAVAuth(c *gin.Context) { c.Abort() return } + key := utils.GetMD5EncodeStr(fmt.Sprintf("%d-%s", admin.ID, c.ClientIP())) + if err := device.Handle(admin.ID, key); err != nil { + c.Status(http.StatusForbidden) + c.Abort() + return + } + c.Set("device_key", key) c.Set("user", admin) c.Next() return @@ -146,6 +156,13 @@ func WebDAVAuth(c *gin.Context) { c.Abort() return } + key := utils.GetMD5EncodeStr(fmt.Sprintf("%d-%s", user.ID, c.ClientIP())) + if err := device.Handle(user.ID, key); err != nil { + c.Status(http.StatusForbidden) + c.Abort() + return + } + c.Set("device_key", key) c.Set("user", user) c.Next() }