From 77d0c78bfd0e7040db459940312372e9a4813b05 Mon Sep 17 00:00:00 2001 From: KirCute_ECT <951206789@qq.com> Date: Wed, 25 Dec 2024 21:15:06 +0800 Subject: [PATCH] feat(sftp-server): public key login (#7668) --- internal/db/db.go | 2 +- internal/db/sshkey.go | 57 ++++++++++++++++++ internal/model/sshkey.go | 28 +++++++++ internal/op/sshkey.go | 48 +++++++++++++++ server/handles/sshkey.go | 124 +++++++++++++++++++++++++++++++++++++++ server/router.go | 5 ++ server/sftp.go | 27 ++++++++- 7 files changed, 289 insertions(+), 2 deletions(-) create mode 100644 internal/db/sshkey.go create mode 100644 internal/model/sshkey.go create mode 100644 internal/op/sshkey.go create mode 100644 server/handles/sshkey.go diff --git a/internal/db/db.go b/internal/db/db.go index 2df58d37..2cd18050 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -12,7 +12,7 @@ var db *gorm.DB func Init(d *gorm.DB) { db = d - err := AutoMigrate(new(model.Storage), new(model.User), new(model.Meta), new(model.SettingItem), new(model.SearchNode), new(model.TaskItem)) + err := AutoMigrate(new(model.Storage), new(model.User), new(model.Meta), new(model.SettingItem), new(model.SearchNode), new(model.TaskItem), new(model.SSHPublicKey)) if err != nil { log.Fatalf("failed migrate database: %s", err.Error()) } diff --git a/internal/db/sshkey.go b/internal/db/sshkey.go new file mode 100644 index 00000000..f51dbfdc --- /dev/null +++ b/internal/db/sshkey.go @@ -0,0 +1,57 @@ +package db + +import ( + "github.com/alist-org/alist/v3/internal/model" + "github.com/pkg/errors" +) + +func GetSSHPublicKeyByUserId(userId uint, pageIndex, pageSize int) (keys []model.SSHPublicKey, count int64, err error) { + keyDB := db.Model(&model.SSHPublicKey{}) + query := model.SSHPublicKey{UserId: userId} + if err := keyDB.Where(query).Count(&count).Error; err != nil { + return nil, 0, errors.Wrapf(err, "failed get user's keys count") + } + if err := keyDB.Where(query).Order(columnName("id")).Offset((pageIndex - 1) * pageSize).Limit(pageSize).Find(&keys).Error; err != nil { + return nil, 0, errors.Wrapf(err, "failed get find user's keys") + } + return keys, count, nil +} + +func GetSSHPublicKeyById(id uint) (*model.SSHPublicKey, error) { + var k model.SSHPublicKey + if err := db.First(&k, id).Error; err != nil { + return nil, errors.Wrapf(err, "failed get old key") + } + return &k, nil +} + +func GetSSHPublicKeyByUserTitle(userId uint, title string) (*model.SSHPublicKey, error) { + key := model.SSHPublicKey{UserId: userId, Title: title} + if err := db.Where(key).First(&key).Error; err != nil { + return nil, errors.Wrapf(err, "failed find key with title of user") + } + return &key, nil +} + +func CreateSSHPublicKey(k *model.SSHPublicKey) error { + return errors.WithStack(db.Create(k).Error) +} + +func UpdateSSHPublicKey(k *model.SSHPublicKey) error { + return errors.WithStack(db.Save(k).Error) +} + +func GetSSHPublicKeys(pageIndex, pageSize int) (keys []model.SSHPublicKey, count int64, err error) { + keyDB := db.Model(&model.SSHPublicKey{}) + if err := keyDB.Count(&count).Error; err != nil { + return nil, 0, errors.Wrapf(err, "failed get keys count") + } + if err := keyDB.Order(columnName("id")).Offset((pageIndex - 1) * pageSize).Limit(pageSize).Find(&keys).Error; err != nil { + return nil, 0, errors.Wrapf(err, "failed get find keys") + } + return keys, count, nil +} + +func DeleteSSHPublicKeyById(id uint) error { + return errors.WithStack(db.Delete(&model.SSHPublicKey{}, id).Error) +} diff --git a/internal/model/sshkey.go b/internal/model/sshkey.go new file mode 100644 index 00000000..6e97c103 --- /dev/null +++ b/internal/model/sshkey.go @@ -0,0 +1,28 @@ +package model + +import ( + "golang.org/x/crypto/ssh" + "time" +) + +type SSHPublicKey struct { + ID uint `json:"id" gorm:"primaryKey"` + UserId uint `json:"-"` + Title string `json:"title"` + Fingerprint string `json:"fingerprint"` + KeyStr string `gorm:"type:text" json:"-"` + AddedTime time.Time `json:"added_time"` + LastUsedTime time.Time `json:"last_used_time"` +} + +func (k *SSHPublicKey) GetKey() (ssh.PublicKey, error) { + pubKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(k.KeyStr)) + if err != nil { + return nil, err + } + return pubKey, nil +} + +func (k *SSHPublicKey) UpdateLastUsedTime() { + k.LastUsedTime = time.Now() +} diff --git a/internal/op/sshkey.go b/internal/op/sshkey.go new file mode 100644 index 00000000..6ed55658 --- /dev/null +++ b/internal/op/sshkey.go @@ -0,0 +1,48 @@ +package op + +import ( + "github.com/alist-org/alist/v3/internal/db" + "github.com/alist-org/alist/v3/internal/model" + "github.com/pkg/errors" + "golang.org/x/crypto/ssh" + "time" +) + +func CreateSSHPublicKey(k *model.SSHPublicKey) (error, bool) { + _, err := db.GetSSHPublicKeyByUserTitle(k.UserId, k.Title) + if err == nil { + return errors.New("key with the same title already exists"), true + } + pubKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(k.KeyStr)) + if err != nil { + return err, false + } + k.KeyStr = string(pubKey.Marshal()) + k.Fingerprint = ssh.FingerprintSHA256(pubKey) + k.AddedTime = time.Now() + k.LastUsedTime = k.AddedTime + return db.CreateSSHPublicKey(k), true +} + +func GetSSHPublicKeyByUserId(userId uint, pageIndex, pageSize int) (keys []model.SSHPublicKey, count int64, err error) { + return db.GetSSHPublicKeyByUserId(userId, pageIndex, pageSize) +} + +func GetSSHPublicKeyByIdAndUserId(id uint, userId uint) (*model.SSHPublicKey, error) { + key, err := db.GetSSHPublicKeyById(id) + if err != nil { + return nil, err + } + if key.UserId != userId { + return nil, errors.Wrapf(err, "failed get old key") + } + return key, nil +} + +func UpdateSSHPublicKey(k *model.SSHPublicKey) error { + return db.UpdateSSHPublicKey(k) +} + +func DeleteSSHPublicKeyById(keyId uint) error { + return db.DeleteSSHPublicKeyById(keyId) +} diff --git a/server/handles/sshkey.go b/server/handles/sshkey.go new file mode 100644 index 00000000..c53b46f2 --- /dev/null +++ b/server/handles/sshkey.go @@ -0,0 +1,124 @@ +package handles + +import ( + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/internal/op" + "github.com/alist-org/alist/v3/server/common" + "github.com/gin-gonic/gin" + "strconv" +) + +type SSHKeyAddReq struct { + Title string `json:"title" binding:"required"` + Key string `json:"key" binding:"required"` +} + +func AddMyPublicKey(c *gin.Context) { + userObj, ok := c.Value("user").(*model.User) + if !ok || userObj.IsGuest() { + common.ErrorStrResp(c, "user invalid", 401) + return + } + var req SSHKeyAddReq + if err := c.ShouldBind(&req); err != nil { + common.ErrorStrResp(c, "request invalid", 400) + return + } + if req.Title == "" { + common.ErrorStrResp(c, "request invalid", 400) + return + } + key := &model.SSHPublicKey{ + Title: req.Title, + KeyStr: req.Key, + UserId: userObj.ID, + } + err, parsed := op.CreateSSHPublicKey(key) + if !parsed { + common.ErrorStrResp(c, "provided key invalid", 400) + return + } else if err != nil { + common.ErrorResp(c, err, 500, true) + return + } + common.SuccessResp(c) +} + +func ListMyPublicKey(c *gin.Context) { + userObj, ok := c.Value("user").(*model.User) + if !ok || userObj.IsGuest() { + common.ErrorStrResp(c, "user invalid", 401) + return + } + list(c, userObj) +} + +func DeleteMyPublicKey(c *gin.Context) { + userObj, ok := c.Value("user").(*model.User) + if !ok || userObj.IsGuest() { + common.ErrorStrResp(c, "user invalid", 401) + return + } + keyId, err := strconv.Atoi(c.Query("id")) + if err != nil { + common.ErrorStrResp(c, "id format invalid", 400) + return + } + key, err := op.GetSSHPublicKeyByIdAndUserId(uint(keyId), userObj.ID) + if err != nil { + common.ErrorStrResp(c, "failed to get public key", 404) + return + } + err = op.DeleteSSHPublicKeyById(key.ID) + if err != nil { + common.ErrorResp(c, err, 500, true) + return + } + common.SuccessResp(c) +} + +func ListPublicKeys(c *gin.Context) { + userId, err := strconv.Atoi(c.Query("uid")) + if err != nil { + common.ErrorStrResp(c, "user id format invalid", 400) + return + } + userObj, err := op.GetUserById(uint(userId)) + if err != nil { + common.ErrorStrResp(c, "user invalid", 404) + return + } + list(c, userObj) +} + +func DeletePublicKey(c *gin.Context) { + keyId, err := strconv.Atoi(c.Query("id")) + if err != nil { + common.ErrorStrResp(c, "id format invalid", 400) + return + } + err = op.DeleteSSHPublicKeyById(uint(keyId)) + if err != nil { + common.ErrorResp(c, err, 500, true) + return + } + common.SuccessResp(c) +} + +func list(c *gin.Context, userObj *model.User) { + var req model.PageReq + if err := c.ShouldBind(&req); err != nil { + common.ErrorResp(c, err, 400) + return + } + req.Validate() + keys, total, err := op.GetSSHPublicKeyByUserId(userObj.ID, req.Page, req.PerPage) + if err != nil { + common.ErrorResp(c, err, 500, true) + return + } + common.SuccessResp(c, common.PageResp{ + Content: keys, + Total: total, + }) +} diff --git a/server/router.go b/server/router.go index fffa840e..9ff50365 100644 --- a/server/router.go +++ b/server/router.go @@ -52,6 +52,9 @@ func Init(e *gin.Engine) { api.POST("/auth/login/ldap", handles.LoginLdap) auth.GET("/me", handles.CurrentUser) auth.POST("/me/update", handles.UpdateCurrent) + auth.GET("/me/sshkey/list", handles.ListMyPublicKey) + auth.POST("/me/sshkey/add", handles.AddMyPublicKey) + auth.POST("/me/sshkey/delete", handles.DeleteMyPublicKey) auth.POST("/auth/2fa/generate", handles.Generate2FA) auth.POST("/auth/2fa/verify", handles.Verify2FA) auth.GET("/auth/logout", handles.LogOut) @@ -102,6 +105,8 @@ func admin(g *gin.RouterGroup) { user.POST("/cancel_2fa", handles.Cancel2FAById) user.POST("/delete", handles.DeleteUser) user.POST("/del_cache", handles.DelUserCache) + user.GET("/sshkey/list", handles.ListPublicKeys) + user.POST("/sshkey/delete", handles.DeletePublicKey) storage := g.Group("/storage") storage.GET("/list", handles.ListStorages) diff --git a/server/sftp.go b/server/sftp.go index 3b07d472..d44046a4 100644 --- a/server/sftp.go +++ b/server/sftp.go @@ -12,6 +12,7 @@ import ( "github.com/pkg/errors" "golang.org/x/crypto/ssh" "net/http" + "time" ) type SftpDriver struct { @@ -35,6 +36,7 @@ func (d *SftpDriver) GetConfig() *sftpd.Config { NoClientAuth: true, NoClientAuthCallback: d.NoClientAuth, PasswordCallback: d.PasswordAuth, + PublicKeyCallback: d.PublicKeyAuth, AuthLogCallback: d.AuthLogCallback, BannerCallback: d.GetBanner, } @@ -85,14 +87,37 @@ func (d *SftpDriver) PasswordAuth(conn ssh.ConnMetadata, password []byte) (*ssh. if err != nil { return nil, err } + if userObj.Disabled || !userObj.CanFTPAccess() { + return nil, errors.New("user is not allowed to access via SFTP") + } passHash := model.StaticHash(string(password)) if err = userObj.ValidatePwdStaticHash(passHash); err != nil { return nil, err } + return nil, nil +} + +func (d *SftpDriver) PublicKeyAuth(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { + userObj, err := op.GetUserByName(conn.User()) + if err != nil { + return nil, err + } if userObj.Disabled || !userObj.CanFTPAccess() { return nil, errors.New("user is not allowed to access via SFTP") } - return nil, nil + keys, _, err := op.GetSSHPublicKeyByUserId(userObj.ID, 1, -1) + if err != nil { + return nil, err + } + marshal := string(key.Marshal()) + for _, sk := range keys { + if marshal == sk.KeyStr { + sk.LastUsedTime = time.Now() + _ = op.UpdateSSHPublicKey(&sk) + return nil, nil + } + } + return nil, errors.New("public key refused") } func (d *SftpDriver) AuthLogCallback(conn ssh.ConnMetadata, method string, err error) {