From 77d0c78bfd0e7040db459940312372e9a4813b05 Mon Sep 17 00:00:00 2001
From: KirCute_ECT <951206789@qq.com>
Date: Wed, 25 Dec 2024 21:15:06 +0800
Subject: [PATCH] feat(sftp-server): public key login (#7668)

---
 internal/db/db.go        |   2 +-
 internal/db/sshkey.go    |  57 ++++++++++++++++++
 internal/model/sshkey.go |  28 +++++++++
 internal/op/sshkey.go    |  48 +++++++++++++++
 server/handles/sshkey.go | 124 +++++++++++++++++++++++++++++++++++++++
 server/router.go         |   5 ++
 server/sftp.go           |  27 ++++++++-
 7 files changed, 289 insertions(+), 2 deletions(-)
 create mode 100644 internal/db/sshkey.go
 create mode 100644 internal/model/sshkey.go
 create mode 100644 internal/op/sshkey.go
 create mode 100644 server/handles/sshkey.go

diff --git a/internal/db/db.go b/internal/db/db.go
index 2df58d37..2cd18050 100644
--- a/internal/db/db.go
+++ b/internal/db/db.go
@@ -12,7 +12,7 @@ var db *gorm.DB
 
 func Init(d *gorm.DB) {
 	db = d
-	err := AutoMigrate(new(model.Storage), new(model.User), new(model.Meta), new(model.SettingItem), new(model.SearchNode), new(model.TaskItem))
+	err := AutoMigrate(new(model.Storage), new(model.User), new(model.Meta), new(model.SettingItem), new(model.SearchNode), new(model.TaskItem), new(model.SSHPublicKey))
 	if err != nil {
 		log.Fatalf("failed migrate database: %s", err.Error())
 	}
diff --git a/internal/db/sshkey.go b/internal/db/sshkey.go
new file mode 100644
index 00000000..f51dbfdc
--- /dev/null
+++ b/internal/db/sshkey.go
@@ -0,0 +1,57 @@
+package db
+
+import (
+	"github.com/alist-org/alist/v3/internal/model"
+	"github.com/pkg/errors"
+)
+
+func GetSSHPublicKeyByUserId(userId uint, pageIndex, pageSize int) (keys []model.SSHPublicKey, count int64, err error) {
+	keyDB := db.Model(&model.SSHPublicKey{})
+	query := model.SSHPublicKey{UserId: userId}
+	if err := keyDB.Where(query).Count(&count).Error; err != nil {
+		return nil, 0, errors.Wrapf(err, "failed get user's keys count")
+	}
+	if err := keyDB.Where(query).Order(columnName("id")).Offset((pageIndex - 1) * pageSize).Limit(pageSize).Find(&keys).Error; err != nil {
+		return nil, 0, errors.Wrapf(err, "failed get find user's keys")
+	}
+	return keys, count, nil
+}
+
+func GetSSHPublicKeyById(id uint) (*model.SSHPublicKey, error) {
+	var k model.SSHPublicKey
+	if err := db.First(&k, id).Error; err != nil {
+		return nil, errors.Wrapf(err, "failed get old key")
+	}
+	return &k, nil
+}
+
+func GetSSHPublicKeyByUserTitle(userId uint, title string) (*model.SSHPublicKey, error) {
+	key := model.SSHPublicKey{UserId: userId, Title: title}
+	if err := db.Where(key).First(&key).Error; err != nil {
+		return nil, errors.Wrapf(err, "failed find key with title of user")
+	}
+	return &key, nil
+}
+
+func CreateSSHPublicKey(k *model.SSHPublicKey) error {
+	return errors.WithStack(db.Create(k).Error)
+}
+
+func UpdateSSHPublicKey(k *model.SSHPublicKey) error {
+	return errors.WithStack(db.Save(k).Error)
+}
+
+func GetSSHPublicKeys(pageIndex, pageSize int) (keys []model.SSHPublicKey, count int64, err error) {
+	keyDB := db.Model(&model.SSHPublicKey{})
+	if err := keyDB.Count(&count).Error; err != nil {
+		return nil, 0, errors.Wrapf(err, "failed get keys count")
+	}
+	if err := keyDB.Order(columnName("id")).Offset((pageIndex - 1) * pageSize).Limit(pageSize).Find(&keys).Error; err != nil {
+		return nil, 0, errors.Wrapf(err, "failed get find keys")
+	}
+	return keys, count, nil
+}
+
+func DeleteSSHPublicKeyById(id uint) error {
+	return errors.WithStack(db.Delete(&model.SSHPublicKey{}, id).Error)
+}
diff --git a/internal/model/sshkey.go b/internal/model/sshkey.go
new file mode 100644
index 00000000..6e97c103
--- /dev/null
+++ b/internal/model/sshkey.go
@@ -0,0 +1,28 @@
+package model
+
+import (
+	"golang.org/x/crypto/ssh"
+	"time"
+)
+
+type SSHPublicKey struct {
+	ID           uint      `json:"id" gorm:"primaryKey"`
+	UserId       uint      `json:"-"`
+	Title        string    `json:"title"`
+	Fingerprint  string    `json:"fingerprint"`
+	KeyStr       string    `gorm:"type:text" json:"-"`
+	AddedTime    time.Time `json:"added_time"`
+	LastUsedTime time.Time `json:"last_used_time"`
+}
+
+func (k *SSHPublicKey) GetKey() (ssh.PublicKey, error) {
+	pubKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(k.KeyStr))
+	if err != nil {
+		return nil, err
+	}
+	return pubKey, nil
+}
+
+func (k *SSHPublicKey) UpdateLastUsedTime() {
+	k.LastUsedTime = time.Now()
+}
diff --git a/internal/op/sshkey.go b/internal/op/sshkey.go
new file mode 100644
index 00000000..6ed55658
--- /dev/null
+++ b/internal/op/sshkey.go
@@ -0,0 +1,48 @@
+package op
+
+import (
+	"github.com/alist-org/alist/v3/internal/db"
+	"github.com/alist-org/alist/v3/internal/model"
+	"github.com/pkg/errors"
+	"golang.org/x/crypto/ssh"
+	"time"
+)
+
+func CreateSSHPublicKey(k *model.SSHPublicKey) (error, bool) {
+	_, err := db.GetSSHPublicKeyByUserTitle(k.UserId, k.Title)
+	if err == nil {
+		return errors.New("key with the same title already exists"), true
+	}
+	pubKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(k.KeyStr))
+	if err != nil {
+		return err, false
+	}
+	k.KeyStr = string(pubKey.Marshal())
+	k.Fingerprint = ssh.FingerprintSHA256(pubKey)
+	k.AddedTime = time.Now()
+	k.LastUsedTime = k.AddedTime
+	return db.CreateSSHPublicKey(k), true
+}
+
+func GetSSHPublicKeyByUserId(userId uint, pageIndex, pageSize int) (keys []model.SSHPublicKey, count int64, err error) {
+	return db.GetSSHPublicKeyByUserId(userId, pageIndex, pageSize)
+}
+
+func GetSSHPublicKeyByIdAndUserId(id uint, userId uint) (*model.SSHPublicKey, error) {
+	key, err := db.GetSSHPublicKeyById(id)
+	if err != nil {
+		return nil, err
+	}
+	if key.UserId != userId {
+		return nil, errors.Wrapf(err, "failed get old key")
+	}
+	return key, nil
+}
+
+func UpdateSSHPublicKey(k *model.SSHPublicKey) error {
+	return db.UpdateSSHPublicKey(k)
+}
+
+func DeleteSSHPublicKeyById(keyId uint) error {
+	return db.DeleteSSHPublicKeyById(keyId)
+}
diff --git a/server/handles/sshkey.go b/server/handles/sshkey.go
new file mode 100644
index 00000000..c53b46f2
--- /dev/null
+++ b/server/handles/sshkey.go
@@ -0,0 +1,124 @@
+package handles
+
+import (
+	"github.com/alist-org/alist/v3/internal/model"
+	"github.com/alist-org/alist/v3/internal/op"
+	"github.com/alist-org/alist/v3/server/common"
+	"github.com/gin-gonic/gin"
+	"strconv"
+)
+
+type SSHKeyAddReq struct {
+	Title string `json:"title" binding:"required"`
+	Key   string `json:"key" binding:"required"`
+}
+
+func AddMyPublicKey(c *gin.Context) {
+	userObj, ok := c.Value("user").(*model.User)
+	if !ok || userObj.IsGuest() {
+		common.ErrorStrResp(c, "user invalid", 401)
+		return
+	}
+	var req SSHKeyAddReq
+	if err := c.ShouldBind(&req); err != nil {
+		common.ErrorStrResp(c, "request invalid", 400)
+		return
+	}
+	if req.Title == "" {
+		common.ErrorStrResp(c, "request invalid", 400)
+		return
+	}
+	key := &model.SSHPublicKey{
+		Title:  req.Title,
+		KeyStr: req.Key,
+		UserId: userObj.ID,
+	}
+	err, parsed := op.CreateSSHPublicKey(key)
+	if !parsed {
+		common.ErrorStrResp(c, "provided key invalid", 400)
+		return
+	} else if err != nil {
+		common.ErrorResp(c, err, 500, true)
+		return
+	}
+	common.SuccessResp(c)
+}
+
+func ListMyPublicKey(c *gin.Context) {
+	userObj, ok := c.Value("user").(*model.User)
+	if !ok || userObj.IsGuest() {
+		common.ErrorStrResp(c, "user invalid", 401)
+		return
+	}
+	list(c, userObj)
+}
+
+func DeleteMyPublicKey(c *gin.Context) {
+	userObj, ok := c.Value("user").(*model.User)
+	if !ok || userObj.IsGuest() {
+		common.ErrorStrResp(c, "user invalid", 401)
+		return
+	}
+	keyId, err := strconv.Atoi(c.Query("id"))
+	if err != nil {
+		common.ErrorStrResp(c, "id format invalid", 400)
+		return
+	}
+	key, err := op.GetSSHPublicKeyByIdAndUserId(uint(keyId), userObj.ID)
+	if err != nil {
+		common.ErrorStrResp(c, "failed to get public key", 404)
+		return
+	}
+	err = op.DeleteSSHPublicKeyById(key.ID)
+	if err != nil {
+		common.ErrorResp(c, err, 500, true)
+		return
+	}
+	common.SuccessResp(c)
+}
+
+func ListPublicKeys(c *gin.Context) {
+	userId, err := strconv.Atoi(c.Query("uid"))
+	if err != nil {
+		common.ErrorStrResp(c, "user id format invalid", 400)
+		return
+	}
+	userObj, err := op.GetUserById(uint(userId))
+	if err != nil {
+		common.ErrorStrResp(c, "user invalid", 404)
+		return
+	}
+	list(c, userObj)
+}
+
+func DeletePublicKey(c *gin.Context) {
+	keyId, err := strconv.Atoi(c.Query("id"))
+	if err != nil {
+		common.ErrorStrResp(c, "id format invalid", 400)
+		return
+	}
+	err = op.DeleteSSHPublicKeyById(uint(keyId))
+	if err != nil {
+		common.ErrorResp(c, err, 500, true)
+		return
+	}
+	common.SuccessResp(c)
+}
+
+func list(c *gin.Context, userObj *model.User) {
+	var req model.PageReq
+	if err := c.ShouldBind(&req); err != nil {
+		common.ErrorResp(c, err, 400)
+		return
+	}
+	req.Validate()
+	keys, total, err := op.GetSSHPublicKeyByUserId(userObj.ID, req.Page, req.PerPage)
+	if err != nil {
+		common.ErrorResp(c, err, 500, true)
+		return
+	}
+	common.SuccessResp(c, common.PageResp{
+		Content: keys,
+		Total:   total,
+	})
+}
diff --git a/server/router.go b/server/router.go
index fffa840e..9ff50365 100644
--- a/server/router.go
+++ b/server/router.go
@@ -52,6 +52,9 @@ func Init(e *gin.Engine) {
 	api.POST("/auth/login/ldap", handles.LoginLdap)
 	auth.GET("/me", handles.CurrentUser)
 	auth.POST("/me/update", handles.UpdateCurrent)
+	auth.GET("/me/sshkey/list", handles.ListMyPublicKey)
+	auth.POST("/me/sshkey/add", handles.AddMyPublicKey)
+	auth.POST("/me/sshkey/delete", handles.DeleteMyPublicKey)
 	auth.POST("/auth/2fa/generate", handles.Generate2FA)
 	auth.POST("/auth/2fa/verify", handles.Verify2FA)
 	auth.GET("/auth/logout", handles.LogOut)
@@ -102,6 +105,8 @@ func admin(g *gin.RouterGroup) {
 	user.POST("/cancel_2fa", handles.Cancel2FAById)
 	user.POST("/delete", handles.DeleteUser)
 	user.POST("/del_cache", handles.DelUserCache)
+	user.GET("/sshkey/list", handles.ListPublicKeys)
+	user.POST("/sshkey/delete", handles.DeletePublicKey)
 
 	storage := g.Group("/storage")
 	storage.GET("/list", handles.ListStorages)
diff --git a/server/sftp.go b/server/sftp.go
index 3b07d472..d44046a4 100644
--- a/server/sftp.go
+++ b/server/sftp.go
@@ -12,6 +12,7 @@ import (
 	"github.com/pkg/errors"
 	"golang.org/x/crypto/ssh"
 	"net/http"
+	"time"
 )
 
 type SftpDriver struct {
@@ -35,6 +36,7 @@ func (d *SftpDriver) GetConfig() *sftpd.Config {
 		NoClientAuth:         true,
 		NoClientAuthCallback: d.NoClientAuth,
 		PasswordCallback:     d.PasswordAuth,
+		PublicKeyCallback:    d.PublicKeyAuth,
 		AuthLogCallback:      d.AuthLogCallback,
 		BannerCallback:       d.GetBanner,
 	}
@@ -85,14 +87,37 @@ func (d *SftpDriver) PasswordAuth(conn ssh.ConnMetadata, password []byte) (*ssh.
 	if err != nil {
 		return nil, err
 	}
+	if userObj.Disabled || !userObj.CanFTPAccess() {
+		return nil, errors.New("user is not allowed to access via SFTP")
+	}
 	passHash := model.StaticHash(string(password))
 	if err = userObj.ValidatePwdStaticHash(passHash); err != nil {
 		return nil, err
 	}
+	return nil, nil
+}
+
+func (d *SftpDriver) PublicKeyAuth(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
+	userObj, err := op.GetUserByName(conn.User())
+	if err != nil {
+		return nil, err
+	}
 	if userObj.Disabled || !userObj.CanFTPAccess() {
 		return nil, errors.New("user is not allowed to access via SFTP")
 	}
-	return nil, nil
+	keys, _, err := op.GetSSHPublicKeyByUserId(userObj.ID, 1, -1)
+	if err != nil {
+		return nil, err
+	}
+	marshal := string(key.Marshal())
+	for _, sk := range keys {
+		if marshal == sk.KeyStr {
+			sk.LastUsedTime = time.Now()
+			_ = op.UpdateSSHPublicKey(&sk)
+			return nil, nil
+		}
+	}
+	return nil, errors.New("public key refused")
 }
 
 func (d *SftpDriver) AuthLogCallback(conn ssh.ConnMetadata, method string, err error) {