feat: SSO auto register (close #4692 in #4795)

Co-authored-by: Andy Hsu <i@nn.ci>
pull/4814/head
WintBit 2023-07-20 16:30:30 +08:00 committed by GitHub
parent cace9db12f
commit de8f9e9eee
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 190 additions and 140 deletions

View File

@ -154,13 +154,16 @@ func InitialSettings() []model.SettingItem {
// SSO settings // SSO settings
{Key: conf.SSOLoginEnabled, Value: "false", Type: conf.TypeBool, Group: model.SSO, Flag: model.PUBLIC}, {Key: conf.SSOLoginEnabled, Value: "false", Type: conf.TypeBool, Group: model.SSO, Flag: model.PUBLIC},
{Key: conf.SSOLoginplatform, Type: conf.TypeSelect, Options: "Casdoor,Github,Microsoft,Google,Dingtalk,OIDC", Group: model.SSO, Flag: model.PUBLIC}, {Key: conf.SSOLoginPlatform, Type: conf.TypeSelect, Options: "Casdoor,Github,Microsoft,Google,Dingtalk,OIDC", Group: model.SSO, Flag: model.PUBLIC},
{Key: conf.SSOClientId, Value: "", Type: conf.TypeString, Group: model.SSO, Flag: model.PRIVATE}, {Key: conf.SSOClientId, Value: "", Type: conf.TypeString, Group: model.SSO, Flag: model.PRIVATE},
{Key: conf.SSOClientSecret, Value: "", Type: conf.TypeString, Group: model.SSO, Flag: model.PRIVATE}, {Key: conf.SSOClientSecret, Value: "", Type: conf.TypeString, Group: model.SSO, Flag: model.PRIVATE},
{Key: conf.SSOOrganizationName, Value: "", Type: conf.TypeString, Group: model.SSO, Flag: model.PRIVATE}, {Key: conf.SSOOrganizationName, Value: "", Type: conf.TypeString, Group: model.SSO, Flag: model.PRIVATE},
{Key: conf.SSOApplicationName, Value: "", Type: conf.TypeString, Group: model.SSO, Flag: model.PRIVATE}, {Key: conf.SSOApplicationName, Value: "", Type: conf.TypeString, Group: model.SSO, Flag: model.PRIVATE},
{Key: conf.SSOEndpointName, Value: "", Type: conf.TypeString, Group: model.SSO, Flag: model.PRIVATE}, {Key: conf.SSOEndpointName, Value: "", Type: conf.TypeString, Group: model.SSO, Flag: model.PRIVATE},
{Key: conf.SSOJwtPublicKey, Value: "", Type: conf.TypeString, Group: model.SSO, Flag: model.PRIVATE}, {Key: conf.SSOJwtPublicKey, Value: "", Type: conf.TypeString, Group: model.SSO, Flag: model.PRIVATE},
{Key: conf.SSOAutoRegister, Value: "false", Type: conf.TypeBool, Group: model.SSO, Flag: model.PRIVATE},
{Key: conf.SSODefaultDir, Value: "/", Type: conf.TypeString, Group: model.SSO, Flag: model.PRIVATE},
{Key: conf.SSODefaultPermission, Value: "0", Type: conf.TypeNumber, Group: model.SSO, Flag: model.PRIVATE},
// qbittorrent settings // qbittorrent settings
{Key: conf.QbittorrentUrl, Value: "http://admin:adminadmin@localhost:8080/", Type: conf.TypeString, Group: model.SINGLE, Flag: model.PRIVATE}, {Key: conf.QbittorrentUrl, Value: "http://admin:adminadmin@localhost:8080/", Type: conf.TypeString, Group: model.SINGLE, Flag: model.PRIVATE},

View File

@ -57,14 +57,17 @@ const (
IndexProgress = "index_progress" IndexProgress = "index_progress"
//SSO //SSO
SSOClientId = "sso_client_id" SSOClientId = "sso_client_id"
SSOClientSecret = "sso_client_secret" SSOClientSecret = "sso_client_secret"
SSOLoginEnabled = "sso_login_enabled" SSOLoginEnabled = "sso_login_enabled"
SSOLoginplatform = "sso_login_platform" SSOLoginPlatform = "sso_login_platform"
SSOOrganizationName = "sso_organization_name" SSOOrganizationName = "sso_organization_name"
SSOApplicationName = "sso_application_name" SSOApplicationName = "sso_application_name"
SSOEndpointName = "sso_endpoint_name" SSOEndpointName = "sso_endpoint_name"
SSOJwtPublicKey = "sso_jwt_public_key" SSOJwtPublicKey = "sso_jwt_public_key"
SSOAutoRegister = "sso_auto_register"
SSODefaultDir = "sso_default_dir"
SSODefaultPermission = "sso_default_permission"
// qbittorrent // qbittorrent
QbittorrentUrl = "qbittorrent_url" QbittorrentUrl = "qbittorrent_url"

View File

@ -33,7 +33,7 @@ type User struct {
// 10: can add qbittorrent tasks // 10: can add qbittorrent tasks
Permission int32 `json:"permission"` Permission int32 `json:"permission"`
OtpSecret string `json:"-"` OtpSecret string `json:"-"`
SsoID string `json:"sso_id"` SsoID string `json:"sso_id"` // unique by sso platform
} }
func (u User) IsGuest() bool { func (u User) IsGuest() bool {

View File

@ -11,8 +11,10 @@ import (
"github.com/alist-org/alist/v3/internal/conf" "github.com/alist-org/alist/v3/internal/conf"
"github.com/alist-org/alist/v3/internal/db" "github.com/alist-org/alist/v3/internal/db"
"github.com/alist-org/alist/v3/internal/model"
"github.com/alist-org/alist/v3/internal/setting" "github.com/alist-org/alist/v3/internal/setting"
"github.com/alist-org/alist/v3/pkg/utils" "github.com/alist-org/alist/v3/pkg/utils"
"github.com/alist-org/alist/v3/pkg/utils/random"
"github.com/alist-org/alist/v3/server/common" "github.com/alist-org/alist/v3/server/common"
"github.com/coreos/go-oidc" "github.com/coreos/go-oidc"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@ -20,14 +22,15 @@ import (
"github.com/pquerna/otp" "github.com/pquerna/otp"
"github.com/pquerna/otp/totp" "github.com/pquerna/otp/totp"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"gorm.io/gorm"
) )
var opts = totp.ValidateOpts{ var opts = totp.ValidateOpts{
// state verify won't expire in 30 secs, which is quite enough for the callback // state verify won't expire in 30 secs, which is quite enough for the callback
Period: 30, Period: 30,
Skew: 1, Skew: 1,
// in some OIDC providers(such as Authelia), state parameter must be at least 8 characters // in some OIDC providers(such as Authelia), state parameter must be at least 8 characters
Digits: otp.DigitsEight, Digits: otp.DigitsEight,
Algorithm: otp.AlgorithmSHA1, Algorithm: otp.AlgorithmSHA1,
} }
@ -35,7 +38,7 @@ func SSOLoginRedirect(c *gin.Context) {
method := c.Query("method") method := c.Query("method")
enabled := setting.GetBool(conf.SSOLoginEnabled) enabled := setting.GetBool(conf.SSOLoginEnabled)
clientId := setting.GetStr(conf.SSOClientId) clientId := setting.GetStr(conf.SSOClientId)
platform := setting.GetStr(conf.SSOLoginplatform) platform := setting.GetStr(conf.SSOLoginPlatform)
var r_url string var r_url string
var redirect_uri string var redirect_uri string
if enabled { if enabled {
@ -76,7 +79,7 @@ func SSOLoginRedirect(c *gin.Context) {
return return
} }
// generate state parameter // generate state parameter
state,err := totp.GenerateCodeCustom(base32.StdEncoding.EncodeToString([]byte(oauth2Config.ClientSecret)), time.Now(), opts) state, err := totp.GenerateCodeCustom(base32.StdEncoding.EncodeToString([]byte(oauth2Config.ClientSecret)), time.Now(), opts)
if err != nil { if err != nil {
common.ErrorStrResp(c, err.Error(), 400) common.ErrorStrResp(c, err.Error(), 400)
return return
@ -118,13 +121,39 @@ func GetOIDCClient(c *gin.Context) (*oauth2.Config, error) {
}, nil }, nil
} }
func autoRegister(username, userID string, err error) (*model.User, error) {
if !errors.Is(err, gorm.ErrRecordNotFound) || !setting.GetBool(conf.SSOAutoRegister) {
return nil, err
}
if username == "" {
return nil, errors.New("cannot get username from SSO provider")
}
user := &model.User{
ID: 0,
Username: username,
Password: random.String(16),
Permission: int32(setting.GetInt(conf.SSODefaultPermission, 0)),
BasePath: setting.GetStr(conf.SSODefaultDir),
Role: 0,
Disabled: false,
SsoID: userID,
}
if err = db.CreateUser(user); err != nil {
if strings.HasPrefix(err.Error(), "UNIQUE constraint failed") && strings.HasSuffix(err.Error(), "username") {
user.Username = user.Username + "_" + userID
if err = db.CreateUser(user); err != nil {
return nil, err
}
} else {
return nil, err
}
}
return user, nil
}
func OIDCLoginCallback(c *gin.Context) { func OIDCLoginCallback(c *gin.Context) {
argument := c.Query("method") argument := c.Query("method")
enabled := setting.GetBool(conf.SSOLoginEnabled)
clientId := setting.GetStr(conf.SSOClientId) clientId := setting.GetStr(conf.SSOClientId)
if !enabled {
common.ErrorResp(c, errors.New("invalid request"), 500)
}
endpoint := setting.GetStr(conf.SSOEndpointName) endpoint := setting.GetStr(conf.SSOEndpointName)
provider, err := oidc.NewProvider(c, endpoint) provider, err := oidc.NewProvider(c, endpoint)
if err != nil { if err != nil {
@ -170,7 +199,7 @@ func OIDCLoginCallback(c *gin.Context) {
} }
claims := UserInfo{} claims := UserInfo{}
if err := idToken.Claims(&claims); err != nil { if err := idToken.Claims(&claims); err != nil {
c.Error(err) common.ErrorResp(c, err, 400)
return return
} }
UserID := claims.Name UserID := claims.Name
@ -189,7 +218,10 @@ func OIDCLoginCallback(c *gin.Context) {
if argument == "sso_get_token" { if argument == "sso_get_token" {
user, err := db.GetUserBySSOID(UserID) user, err := db.GetUserBySSOID(UserID)
if err != nil { if err != nil {
common.ErrorResp(c, err, 400) user, err = autoRegister(UserID, UserID, err)
if err != nil {
common.ErrorResp(c, err, 400)
}
} }
token, err := common.GenerateToken(user.Username) token, err := common.GenerateToken(user.Username)
if err != nil { if err != nil {
@ -209,133 +241,145 @@ func OIDCLoginCallback(c *gin.Context) {
} }
func SSOLoginCallback(c *gin.Context) { func SSOLoginCallback(c *gin.Context) {
enabled := setting.GetBool(conf.SSOLoginEnabled)
if !enabled {
common.ErrorResp(c, errors.New("sso login is disabled"), 500)
}
argument := c.Query("method") argument := c.Query("method")
if argument == "get_sso_id" || argument == "sso_get_token" { if !utils.SliceContains([]string{"get_sso_id", "sso_get_token"}, argument) {
enabled := setting.GetBool(conf.SSOLoginEnabled) common.ErrorResp(c, errors.New("invalid request"), 500)
clientId := setting.GetStr(conf.SSOClientId) }
platform := setting.GetStr(conf.SSOLoginplatform) clientId := setting.GetStr(conf.SSOClientId)
clientSecret := setting.GetStr(conf.SSOClientSecret) platform := setting.GetStr(conf.SSOLoginPlatform)
var url1, url2, additionalbody, scope, authstring, idstring string clientSecret := setting.GetStr(conf.SSOClientSecret)
switch platform { var tokenUrl, userUrl, scope, authField, idField, usernameField string
case "Github": additionalForm := make(map[string]string)
url1 = "https://github.com/login/oauth/access_token" switch platform {
url2 = "https://api.github.com/user" case "Github":
additionalbody = "" tokenUrl = "https://github.com/login/oauth/access_token"
authstring = "code" userUrl = "https://api.github.com/user"
scope = "read:user" authField = "code"
idstring = "id" scope = "read:user"
case "Microsoft": idField = "id"
url1 = "https://login.microsoftonline.com/common/oauth2/v2.0/token" usernameField = "login"
url2 = "https://graph.microsoft.com/v1.0/me" case "Microsoft":
additionalbody = "&grant_type=authorization_code" tokenUrl = "https://login.microsoftonline.com/common/oauth2/v2.0/token"
scope = "user.read" userUrl = "https://graph.microsoft.com/v1.0/me"
authstring = "code" additionalForm["grant_type"] = "authorization_code"
idstring = "id" scope = "user.read"
case "Google": authField = "code"
url1 = "https://oauth2.googleapis.com/token" idField = "id"
url2 = "https://www.googleapis.com/oauth2/v1/userinfo" usernameField = "displayName"
additionalbody = "&grant_type=authorization_code" case "Google":
scope = "https://www.googleapis.com/auth/userinfo.profile" tokenUrl = "https://oauth2.googleapis.com/token"
authstring = "code" userUrl = "https://www.googleapis.com/oauth2/v1/userinfo"
idstring = "id" additionalForm["grant_type"] = "authorization_code"
case "Dingtalk": scope = "https://www.googleapis.com/auth/userinfo.profile"
url1 = "https://api.dingtalk.com/v1.0/oauth2/userAccessToken" authField = "code"
url2 = "https://api.dingtalk.com/v1.0/contact/users/me" idField = "id"
authstring = "authCode" usernameField = "name"
idstring = "unionId" case "Dingtalk":
case "Casdoor": tokenUrl = "https://api.dingtalk.com/v1.0/oauth2/userAccessToken"
endpoint := strings.TrimSuffix(setting.GetStr(conf.SSOEndpointName), "/") userUrl = "https://api.dingtalk.com/v1.0/contact/users/me"
url1 = endpoint + "/api/login/oauth/access_token" authField = "authCode"
url2 = endpoint + "/api/userinfo" idField = "unionId"
additionalbody = "&grant_type=authorization_code" usernameField = "nick"
scope = "profile" case "Casdoor":
authstring = "code" endpoint := strings.TrimSuffix(setting.GetStr(conf.SSOEndpointName), "/")
idstring = "preferred_username" tokenUrl = endpoint + "/api/login/oauth/access_token"
case "OIDC": userUrl = endpoint + "/api/userinfo"
OIDCLoginCallback(c) additionalForm["grant_type"] = "authorization_code"
return scope = "profile"
default: authField = "code"
common.ErrorStrResp(c, "invalid platform", 400) idField = "sub"
return usernameField = "preferred_username"
} case "OIDC":
if enabled { OIDCLoginCallback(c)
callbackCode := c.Query(authstring) return
if callbackCode == "" { default:
common.ErrorStrResp(c, "No code provided", 400) common.ErrorStrResp(c, "invalid platform", 400)
return return
} }
var resp *resty.Response callbackCode := c.Query(authField)
var err error if callbackCode == "" {
if platform == "Dingtalk" { common.ErrorStrResp(c, "No code provided", 400)
resp, err = ssoClient.R().SetHeader("content-type", "application/json").SetHeader("Accept", "application/json"). return
SetBody(map[string]string{ }
"clientId": clientId, var resp *resty.Response
"clientSecret": clientSecret, var err error
"code": callbackCode, if platform == "Dingtalk" {
"grantType": "authorization_code", resp, err = ssoClient.R().SetHeader("content-type", "application/json").SetHeader("Accept", "application/json").
}). SetBody(map[string]string{
Post(url1) "clientId": clientId,
} else { "clientSecret": clientSecret,
resp, err = ssoClient.R().SetHeader("content-type", "application/x-www-form-urlencoded").SetHeader("Accept", "application/json"). "code": callbackCode,
SetBody("client_id=" + clientId + "&client_secret=" + clientSecret + "&code=" + callbackCode + "&redirect_uri=" + common.GetApiUrl(c.Request) + "/api/auth/sso_callback?method=" + argument + "&scope=" + scope + additionalbody). "grantType": "authorization_code",
Post(url1) }).
} Post(tokenUrl)
if err != nil { } else {
common.ErrorResp(c, err, 400) resp, err = ssoClient.R().SetHeader("Accept", "application/json").
return SetFormData(map[string]string{
} "client_id": clientId,
if platform == "Dingtalk" { "client_secret": clientSecret,
accessToken := utils.Json.Get(resp.Body(), "accessToken").ToString() "code": callbackCode,
resp, err = ssoClient.R().SetHeader("x-acs-dingtalk-access-token", accessToken). "redirect_uri": common.GetApiUrl(c.Request) + "/api/auth/sso_callback?method=" + argument,
Get(url2) "scope": scope,
} else { }).SetFormData(additionalForm).Post(tokenUrl)
accessToken := utils.Json.Get(resp.Body(), "access_token").ToString() }
resp, err = ssoClient.R().SetHeader("Authorization", "Bearer "+accessToken). if err != nil {
Get(url2) common.ErrorResp(c, err, 400)
} return
if err != nil { }
common.ErrorResp(c, err, 400) if platform == "Dingtalk" {
return accessToken := utils.Json.Get(resp.Body(), "accessToken").ToString()
} resp, err = ssoClient.R().SetHeader("x-acs-dingtalk-access-token", accessToken).
UserID := utils.Json.Get(resp.Body(), idstring).ToString() Get(userUrl)
if UserID == "0" { } else {
common.ErrorResp(c, errors.New("error occured"), 400) accessToken := utils.Json.Get(resp.Body(), "access_token").ToString()
return resp, err = ssoClient.R().SetHeader("Authorization", "Bearer "+accessToken).
} Get(userUrl)
if argument == "get_sso_id" { }
html := fmt.Sprintf(`<!DOCTYPE html> if err != nil {
common.ErrorResp(c, err, 400)
return
}
userID := utils.Json.Get(resp.Body(), idField).ToString()
if utils.SliceContains([]string{"", "0"}, userID) {
common.ErrorResp(c, errors.New("error occured"), 400)
return
}
if argument == "get_sso_id" {
html := fmt.Sprintf(`<!DOCTYPE html>
<head></head> <head></head>
<body> <body>
<script> <script>
window.opener.postMessage({"sso_id": "%s"}, "*") window.opener.postMessage({"sso_id": "%s"}, "*")
window.close() window.close()
</script> </script>
</body>`, UserID) </body>`, userID)
c.Data(200, "text/html; charset=utf-8", []byte(html)) c.Data(200, "text/html; charset=utf-8", []byte(html))
return return
} }
if argument == "sso_get_token" { username := utils.Json.Get(resp.Body(), usernameField).ToString()
user, err := db.GetUserBySSOID(UserID) user, err := db.GetUserBySSOID(userID)
if err != nil { if err != nil {
common.ErrorResp(c, err, 400) user, err = autoRegister(username, userID, err)
} if err != nil {
token, err := common.GenerateToken(user.Username) common.ErrorResp(c, err, 400)
if err != nil { return
common.ErrorResp(c, err, 400)
}
html := fmt.Sprintf(`<!DOCTYPE html>
<head></head>
<body>
<script>
window.opener.postMessage({"token":"%s"}, "*")
window.close()
</script>
</body>`, token)
c.Data(200, "text/html; charset=utf-8", []byte(html))
return
}
} else {
common.ErrorResp(c, errors.New("invalid request"), 500)
} }
} }
token, err := common.GenerateToken(user.Username)
if err != nil {
common.ErrorResp(c, err, 400)
}
html := fmt.Sprintf(`<!DOCTYPE html>
<head></head>
<body>
<script>
window.opener.postMessage({"token":"%s"}, "*")
window.close()
</script>
</body>`, token)
c.Data(200, "text/html; charset=utf-8", []byte(html))
} }