diff --git a/internal/db/session.go b/internal/db/session.go index 35c778c3..757809ba 100644 --- a/internal/db/session.go +++ b/internal/db/session.go @@ -67,3 +67,15 @@ func ListSessions() ([]model.Session, error) { func MarkInactive(sessionID string) error { return errors.WithStack(db.Model(&model.Session{}).Where("device_key = ?", sessionID).Update("status", model.SessionInactive).Error) } + +func DeleteInactiveSessions(userID *uint) error { + query := db.Where("status = ?", model.SessionInactive) + if userID != nil { + query = query.Where("user_id = ?", *userID) + } + return errors.WithStack(query.Delete(&model.Session{}).Error) +} + +func DeleteSessionByID(sessionID string) error { + return errors.WithStack(db.Where("device_key = ?", sessionID).Delete(&model.Session{}).Error) +} diff --git a/server/handles/session.go b/server/handles/session.go index 886be66a..6f905de4 100644 --- a/server/handles/session.go +++ b/server/handles/session.go @@ -40,6 +40,11 @@ type EvictSessionReq struct { SessionID string `json:"session_id"` } +type CleanSessionsReq struct { + UserID *uint `json:"user_id"` + SessionID string `json:"session_id"` +} + func EvictMySession(c *gin.Context) { var req EvictSessionReq if err := c.ShouldBindJSON(&req); err != nil { @@ -90,3 +95,32 @@ func EvictSession(c *gin.Context) { } common.SuccessResp(c) } + +func CleanSessions(c *gin.Context) { + var req CleanSessionsReq + if err := c.ShouldBindJSON(&req); err != nil { + common.ErrorResp(c, err, 400) + return + } + if req.SessionID != "" { + if err := db.DeleteSessionByID(req.SessionID); err != nil { + common.ErrorResp(c, err, 500) + return + } + common.SuccessResp(c) + return + } + if req.UserID != nil { + if err := db.DeleteInactiveSessions(req.UserID); err != nil { + common.ErrorResp(c, err, 500) + return + } + common.SuccessResp(c) + return + } + if err := db.DeleteInactiveSessions(nil); err != nil { + common.ErrorResp(c, err, 500) + return + } + common.SuccessResp(c) +} diff --git a/server/router.go b/server/router.go index 4d79c1fd..bbcff010 100644 --- a/server/router.go +++ b/server/router.go @@ -190,6 +190,7 @@ func admin(g *gin.RouterGroup) { session := g.Group("/session") session.GET("/list", handles.ListSessions) session.POST("/evict", handles.EvictSession) + session.POST("/clean", handles.CleanSessions) }