mirror of https://github.com/Xhofe/alist
feat(sftp-server): public key login (#7668)
parent
db5c601cfe
commit
77d0c78bfd
|
@ -12,7 +12,7 @@ var db *gorm.DB
|
||||||
|
|
||||||
func Init(d *gorm.DB) {
|
func Init(d *gorm.DB) {
|
||||||
db = d
|
db = d
|
||||||
err := AutoMigrate(new(model.Storage), new(model.User), new(model.Meta), new(model.SettingItem), new(model.SearchNode), new(model.TaskItem))
|
err := AutoMigrate(new(model.Storage), new(model.User), new(model.Meta), new(model.SettingItem), new(model.SearchNode), new(model.TaskItem), new(model.SSHPublicKey))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("failed migrate database: %s", err.Error())
|
log.Fatalf("failed migrate database: %s", err.Error())
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,57 @@
|
||||||
|
package db
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/alist-org/alist/v3/internal/model"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
func GetSSHPublicKeyByUserId(userId uint, pageIndex, pageSize int) (keys []model.SSHPublicKey, count int64, err error) {
|
||||||
|
keyDB := db.Model(&model.SSHPublicKey{})
|
||||||
|
query := model.SSHPublicKey{UserId: userId}
|
||||||
|
if err := keyDB.Where(query).Count(&count).Error; err != nil {
|
||||||
|
return nil, 0, errors.Wrapf(err, "failed get user's keys count")
|
||||||
|
}
|
||||||
|
if err := keyDB.Where(query).Order(columnName("id")).Offset((pageIndex - 1) * pageSize).Limit(pageSize).Find(&keys).Error; err != nil {
|
||||||
|
return nil, 0, errors.Wrapf(err, "failed get find user's keys")
|
||||||
|
}
|
||||||
|
return keys, count, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetSSHPublicKeyById(id uint) (*model.SSHPublicKey, error) {
|
||||||
|
var k model.SSHPublicKey
|
||||||
|
if err := db.First(&k, id).Error; err != nil {
|
||||||
|
return nil, errors.Wrapf(err, "failed get old key")
|
||||||
|
}
|
||||||
|
return &k, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetSSHPublicKeyByUserTitle(userId uint, title string) (*model.SSHPublicKey, error) {
|
||||||
|
key := model.SSHPublicKey{UserId: userId, Title: title}
|
||||||
|
if err := db.Where(key).First(&key).Error; err != nil {
|
||||||
|
return nil, errors.Wrapf(err, "failed find key with title of user")
|
||||||
|
}
|
||||||
|
return &key, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func CreateSSHPublicKey(k *model.SSHPublicKey) error {
|
||||||
|
return errors.WithStack(db.Create(k).Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func UpdateSSHPublicKey(k *model.SSHPublicKey) error {
|
||||||
|
return errors.WithStack(db.Save(k).Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetSSHPublicKeys(pageIndex, pageSize int) (keys []model.SSHPublicKey, count int64, err error) {
|
||||||
|
keyDB := db.Model(&model.SSHPublicKey{})
|
||||||
|
if err := keyDB.Count(&count).Error; err != nil {
|
||||||
|
return nil, 0, errors.Wrapf(err, "failed get keys count")
|
||||||
|
}
|
||||||
|
if err := keyDB.Order(columnName("id")).Offset((pageIndex - 1) * pageSize).Limit(pageSize).Find(&keys).Error; err != nil {
|
||||||
|
return nil, 0, errors.Wrapf(err, "failed get find keys")
|
||||||
|
}
|
||||||
|
return keys, count, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func DeleteSSHPublicKeyById(id uint) error {
|
||||||
|
return errors.WithStack(db.Delete(&model.SSHPublicKey{}, id).Error)
|
||||||
|
}
|
|
@ -0,0 +1,28 @@
|
||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type SSHPublicKey struct {
|
||||||
|
ID uint `json:"id" gorm:"primaryKey"`
|
||||||
|
UserId uint `json:"-"`
|
||||||
|
Title string `json:"title"`
|
||||||
|
Fingerprint string `json:"fingerprint"`
|
||||||
|
KeyStr string `gorm:"type:text" json:"-"`
|
||||||
|
AddedTime time.Time `json:"added_time"`
|
||||||
|
LastUsedTime time.Time `json:"last_used_time"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (k *SSHPublicKey) GetKey() (ssh.PublicKey, error) {
|
||||||
|
pubKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(k.KeyStr))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return pubKey, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (k *SSHPublicKey) UpdateLastUsedTime() {
|
||||||
|
k.LastUsedTime = time.Now()
|
||||||
|
}
|
|
@ -0,0 +1,48 @@
|
||||||
|
package op
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/alist-org/alist/v3/internal/db"
|
||||||
|
"github.com/alist-org/alist/v3/internal/model"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func CreateSSHPublicKey(k *model.SSHPublicKey) (error, bool) {
|
||||||
|
_, err := db.GetSSHPublicKeyByUserTitle(k.UserId, k.Title)
|
||||||
|
if err == nil {
|
||||||
|
return errors.New("key with the same title already exists"), true
|
||||||
|
}
|
||||||
|
pubKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(k.KeyStr))
|
||||||
|
if err != nil {
|
||||||
|
return err, false
|
||||||
|
}
|
||||||
|
k.KeyStr = string(pubKey.Marshal())
|
||||||
|
k.Fingerprint = ssh.FingerprintSHA256(pubKey)
|
||||||
|
k.AddedTime = time.Now()
|
||||||
|
k.LastUsedTime = k.AddedTime
|
||||||
|
return db.CreateSSHPublicKey(k), true
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetSSHPublicKeyByUserId(userId uint, pageIndex, pageSize int) (keys []model.SSHPublicKey, count int64, err error) {
|
||||||
|
return db.GetSSHPublicKeyByUserId(userId, pageIndex, pageSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetSSHPublicKeyByIdAndUserId(id uint, userId uint) (*model.SSHPublicKey, error) {
|
||||||
|
key, err := db.GetSSHPublicKeyById(id)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if key.UserId != userId {
|
||||||
|
return nil, errors.Wrapf(err, "failed get old key")
|
||||||
|
}
|
||||||
|
return key, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func UpdateSSHPublicKey(k *model.SSHPublicKey) error {
|
||||||
|
return db.UpdateSSHPublicKey(k)
|
||||||
|
}
|
||||||
|
|
||||||
|
func DeleteSSHPublicKeyById(keyId uint) error {
|
||||||
|
return db.DeleteSSHPublicKeyById(keyId)
|
||||||
|
}
|
|
@ -0,0 +1,124 @@
|
||||||
|
package handles
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/alist-org/alist/v3/internal/model"
|
||||||
|
"github.com/alist-org/alist/v3/internal/op"
|
||||||
|
"github.com/alist-org/alist/v3/server/common"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"strconv"
|
||||||
|
)
|
||||||
|
|
||||||
|
type SSHKeyAddReq struct {
|
||||||
|
Title string `json:"title" binding:"required"`
|
||||||
|
Key string `json:"key" binding:"required"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func AddMyPublicKey(c *gin.Context) {
|
||||||
|
userObj, ok := c.Value("user").(*model.User)
|
||||||
|
if !ok || userObj.IsGuest() {
|
||||||
|
common.ErrorStrResp(c, "user invalid", 401)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var req SSHKeyAddReq
|
||||||
|
if err := c.ShouldBind(&req); err != nil {
|
||||||
|
common.ErrorStrResp(c, "request invalid", 400)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if req.Title == "" {
|
||||||
|
common.ErrorStrResp(c, "request invalid", 400)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
key := &model.SSHPublicKey{
|
||||||
|
Title: req.Title,
|
||||||
|
KeyStr: req.Key,
|
||||||
|
UserId: userObj.ID,
|
||||||
|
}
|
||||||
|
err, parsed := op.CreateSSHPublicKey(key)
|
||||||
|
if !parsed {
|
||||||
|
common.ErrorStrResp(c, "provided key invalid", 400)
|
||||||
|
return
|
||||||
|
} else if err != nil {
|
||||||
|
common.ErrorResp(c, err, 500, true)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
common.SuccessResp(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ListMyPublicKey(c *gin.Context) {
|
||||||
|
userObj, ok := c.Value("user").(*model.User)
|
||||||
|
if !ok || userObj.IsGuest() {
|
||||||
|
common.ErrorStrResp(c, "user invalid", 401)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
list(c, userObj)
|
||||||
|
}
|
||||||
|
|
||||||
|
func DeleteMyPublicKey(c *gin.Context) {
|
||||||
|
userObj, ok := c.Value("user").(*model.User)
|
||||||
|
if !ok || userObj.IsGuest() {
|
||||||
|
common.ErrorStrResp(c, "user invalid", 401)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
keyId, err := strconv.Atoi(c.Query("id"))
|
||||||
|
if err != nil {
|
||||||
|
common.ErrorStrResp(c, "id format invalid", 400)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
key, err := op.GetSSHPublicKeyByIdAndUserId(uint(keyId), userObj.ID)
|
||||||
|
if err != nil {
|
||||||
|
common.ErrorStrResp(c, "failed to get public key", 404)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = op.DeleteSSHPublicKeyById(key.ID)
|
||||||
|
if err != nil {
|
||||||
|
common.ErrorResp(c, err, 500, true)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
common.SuccessResp(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ListPublicKeys(c *gin.Context) {
|
||||||
|
userId, err := strconv.Atoi(c.Query("uid"))
|
||||||
|
if err != nil {
|
||||||
|
common.ErrorStrResp(c, "user id format invalid", 400)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
userObj, err := op.GetUserById(uint(userId))
|
||||||
|
if err != nil {
|
||||||
|
common.ErrorStrResp(c, "user invalid", 404)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
list(c, userObj)
|
||||||
|
}
|
||||||
|
|
||||||
|
func DeletePublicKey(c *gin.Context) {
|
||||||
|
keyId, err := strconv.Atoi(c.Query("id"))
|
||||||
|
if err != nil {
|
||||||
|
common.ErrorStrResp(c, "id format invalid", 400)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = op.DeleteSSHPublicKeyById(uint(keyId))
|
||||||
|
if err != nil {
|
||||||
|
common.ErrorResp(c, err, 500, true)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
common.SuccessResp(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
func list(c *gin.Context, userObj *model.User) {
|
||||||
|
var req model.PageReq
|
||||||
|
if err := c.ShouldBind(&req); err != nil {
|
||||||
|
common.ErrorResp(c, err, 400)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
req.Validate()
|
||||||
|
keys, total, err := op.GetSSHPublicKeyByUserId(userObj.ID, req.Page, req.PerPage)
|
||||||
|
if err != nil {
|
||||||
|
common.ErrorResp(c, err, 500, true)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
common.SuccessResp(c, common.PageResp{
|
||||||
|
Content: keys,
|
||||||
|
Total: total,
|
||||||
|
})
|
||||||
|
}
|
|
@ -52,6 +52,9 @@ func Init(e *gin.Engine) {
|
||||||
api.POST("/auth/login/ldap", handles.LoginLdap)
|
api.POST("/auth/login/ldap", handles.LoginLdap)
|
||||||
auth.GET("/me", handles.CurrentUser)
|
auth.GET("/me", handles.CurrentUser)
|
||||||
auth.POST("/me/update", handles.UpdateCurrent)
|
auth.POST("/me/update", handles.UpdateCurrent)
|
||||||
|
auth.GET("/me/sshkey/list", handles.ListMyPublicKey)
|
||||||
|
auth.POST("/me/sshkey/add", handles.AddMyPublicKey)
|
||||||
|
auth.POST("/me/sshkey/delete", handles.DeleteMyPublicKey)
|
||||||
auth.POST("/auth/2fa/generate", handles.Generate2FA)
|
auth.POST("/auth/2fa/generate", handles.Generate2FA)
|
||||||
auth.POST("/auth/2fa/verify", handles.Verify2FA)
|
auth.POST("/auth/2fa/verify", handles.Verify2FA)
|
||||||
auth.GET("/auth/logout", handles.LogOut)
|
auth.GET("/auth/logout", handles.LogOut)
|
||||||
|
@ -102,6 +105,8 @@ func admin(g *gin.RouterGroup) {
|
||||||
user.POST("/cancel_2fa", handles.Cancel2FAById)
|
user.POST("/cancel_2fa", handles.Cancel2FAById)
|
||||||
user.POST("/delete", handles.DeleteUser)
|
user.POST("/delete", handles.DeleteUser)
|
||||||
user.POST("/del_cache", handles.DelUserCache)
|
user.POST("/del_cache", handles.DelUserCache)
|
||||||
|
user.GET("/sshkey/list", handles.ListPublicKeys)
|
||||||
|
user.POST("/sshkey/delete", handles.DeletePublicKey)
|
||||||
|
|
||||||
storage := g.Group("/storage")
|
storage := g.Group("/storage")
|
||||||
storage.GET("/list", handles.ListStorages)
|
storage.GET("/list", handles.ListStorages)
|
||||||
|
|
|
@ -12,6 +12,7 @@ import (
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"golang.org/x/crypto/ssh"
|
"golang.org/x/crypto/ssh"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type SftpDriver struct {
|
type SftpDriver struct {
|
||||||
|
@ -35,6 +36,7 @@ func (d *SftpDriver) GetConfig() *sftpd.Config {
|
||||||
NoClientAuth: true,
|
NoClientAuth: true,
|
||||||
NoClientAuthCallback: d.NoClientAuth,
|
NoClientAuthCallback: d.NoClientAuth,
|
||||||
PasswordCallback: d.PasswordAuth,
|
PasswordCallback: d.PasswordAuth,
|
||||||
|
PublicKeyCallback: d.PublicKeyAuth,
|
||||||
AuthLogCallback: d.AuthLogCallback,
|
AuthLogCallback: d.AuthLogCallback,
|
||||||
BannerCallback: d.GetBanner,
|
BannerCallback: d.GetBanner,
|
||||||
}
|
}
|
||||||
|
@ -85,14 +87,37 @@ func (d *SftpDriver) PasswordAuth(conn ssh.ConnMetadata, password []byte) (*ssh.
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
if userObj.Disabled || !userObj.CanFTPAccess() {
|
||||||
|
return nil, errors.New("user is not allowed to access via SFTP")
|
||||||
|
}
|
||||||
passHash := model.StaticHash(string(password))
|
passHash := model.StaticHash(string(password))
|
||||||
if err = userObj.ValidatePwdStaticHash(passHash); err != nil {
|
if err = userObj.ValidatePwdStaticHash(passHash); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *SftpDriver) PublicKeyAuth(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
|
||||||
|
userObj, err := op.GetUserByName(conn.User())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
if userObj.Disabled || !userObj.CanFTPAccess() {
|
if userObj.Disabled || !userObj.CanFTPAccess() {
|
||||||
return nil, errors.New("user is not allowed to access via SFTP")
|
return nil, errors.New("user is not allowed to access via SFTP")
|
||||||
}
|
}
|
||||||
|
keys, _, err := op.GetSSHPublicKeyByUserId(userObj.ID, 1, -1)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
marshal := string(key.Marshal())
|
||||||
|
for _, sk := range keys {
|
||||||
|
if marshal == sk.KeyStr {
|
||||||
|
sk.LastUsedTime = time.Now()
|
||||||
|
_ = op.UpdateSSHPublicKey(&sk)
|
||||||
return nil, nil
|
return nil, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, errors.New("public key refused")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *SftpDriver) AuthLogCallback(conn ssh.ConnMetadata, method string, err error) {
|
func (d *SftpDriver) AuthLogCallback(conn ssh.ConnMetadata, method string, err error) {
|
||||||
|
|
Loading…
Reference in New Issue