fix(sso): OIDC compatibility mode (#7524)

pull/7551/head
Mmx 2024-11-21 22:36:41 +08:00 committed by GitHub
parent 0ba754fd40
commit 150dcc2147
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 28 additions and 33 deletions

View File

@ -36,14 +36,21 @@ var opts = totp.ValidateOpts{
Algorithm: otp.AlgorithmSHA1, Algorithm: otp.AlgorithmSHA1,
} }
func ssoRedirectUri(c *gin.Context, useCompatibility bool, method string) string {
if useCompatibility {
return common.GetApiUrl(c.Request) + "/api/auth/" + method
} else {
return common.GetApiUrl(c.Request) + "/api/auth/sso_callback" + "?method=" + method
}
}
func SSOLoginRedirect(c *gin.Context) { func SSOLoginRedirect(c *gin.Context) {
method := c.Query("method") method := c.Query("method")
usecompatibility := setting.GetBool(conf.SSOCompatibilityMode) useCompatibility := setting.GetBool(conf.SSOCompatibilityMode)
enabled := setting.GetBool(conf.SSOLoginEnabled) enabled := setting.GetBool(conf.SSOLoginEnabled)
clientId := setting.GetStr(conf.SSOClientId) clientId := setting.GetStr(conf.SSOClientId)
platform := setting.GetStr(conf.SSOLoginPlatform) platform := setting.GetStr(conf.SSOLoginPlatform)
var r_url string var rUrl string
var redirect_uri string
if !enabled { if !enabled {
common.ErrorStrResp(c, "Single sign-on is not enabled", 403) common.ErrorStrResp(c, "Single sign-on is not enabled", 403)
return return
@ -53,37 +60,33 @@ func SSOLoginRedirect(c *gin.Context) {
common.ErrorStrResp(c, "no method provided", 400) common.ErrorStrResp(c, "no method provided", 400)
return return
} }
if usecompatibility { redirectUri := ssoRedirectUri(c, useCompatibility, method)
redirect_uri = common.GetApiUrl(c.Request) + "/api/auth/" + method
} else {
redirect_uri = common.GetApiUrl(c.Request) + "/api/auth/sso_callback" + "?method=" + method
}
urlValues.Add("response_type", "code") urlValues.Add("response_type", "code")
urlValues.Add("redirect_uri", redirect_uri) urlValues.Add("redirect_uri", redirectUri)
urlValues.Add("client_id", clientId) urlValues.Add("client_id", clientId)
switch platform { switch platform {
case "Github": case "Github":
r_url = "https://github.com/login/oauth/authorize?" rUrl = "https://github.com/login/oauth/authorize?"
urlValues.Add("scope", "read:user") urlValues.Add("scope", "read:user")
case "Microsoft": case "Microsoft":
r_url = "https://login.microsoftonline.com/common/oauth2/v2.0/authorize?" rUrl = "https://login.microsoftonline.com/common/oauth2/v2.0/authorize?"
urlValues.Add("scope", "user.read") urlValues.Add("scope", "user.read")
urlValues.Add("response_mode", "query") urlValues.Add("response_mode", "query")
case "Google": case "Google":
r_url = "https://accounts.google.com/o/oauth2/v2/auth?" rUrl = "https://accounts.google.com/o/oauth2/v2/auth?"
urlValues.Add("scope", "https://www.googleapis.com/auth/userinfo.profile") urlValues.Add("scope", "https://www.googleapis.com/auth/userinfo.profile")
case "Dingtalk": case "Dingtalk":
r_url = "https://login.dingtalk.com/oauth2/auth?" rUrl = "https://login.dingtalk.com/oauth2/auth?"
urlValues.Add("scope", "openid") urlValues.Add("scope", "openid")
urlValues.Add("prompt", "consent") urlValues.Add("prompt", "consent")
urlValues.Add("response_type", "code") urlValues.Add("response_type", "code")
case "Casdoor": case "Casdoor":
endpoint := strings.TrimSuffix(setting.GetStr(conf.SSOEndpointName), "/") endpoint := strings.TrimSuffix(setting.GetStr(conf.SSOEndpointName), "/")
r_url = endpoint + "/login/oauth/authorize?" rUrl = endpoint + "/login/oauth/authorize?"
urlValues.Add("scope", "profile") urlValues.Add("scope", "profile")
urlValues.Add("state", endpoint) urlValues.Add("state", endpoint)
case "OIDC": case "OIDC":
oauth2Config, err := GetOIDCClient(c) oauth2Config, err := GetOIDCClient(c, useCompatibility, redirectUri, method)
if err != nil { if err != nil {
common.ErrorStrResp(c, err.Error(), 400) common.ErrorStrResp(c, err.Error(), 400)
return return
@ -100,22 +103,14 @@ func SSOLoginRedirect(c *gin.Context) {
common.ErrorStrResp(c, "invalid platform", 400) common.ErrorStrResp(c, "invalid platform", 400)
return return
} }
c.Redirect(302, r_url+urlValues.Encode()) c.Redirect(302, rUrl+urlValues.Encode())
} }
var ssoClient = resty.New().SetRetryCount(3) var ssoClient = resty.New().SetRetryCount(3)
func GetOIDCClient(c *gin.Context) (*oauth2.Config, error) { func GetOIDCClient(c *gin.Context, useCompatibility bool, redirectUri, method string) (*oauth2.Config, error) {
var redirect_uri string if redirectUri == "" {
usecompatibility := setting.GetBool(conf.SSOCompatibilityMode) redirectUri = ssoRedirectUri(c, useCompatibility, method)
argument := c.Query("method")
if usecompatibility {
argument = path.Base(c.Request.URL.Path)
}
if usecompatibility {
redirect_uri = common.GetApiUrl(c.Request) + "/api/auth/" + argument
} else {
redirect_uri = common.GetApiUrl(c.Request) + "/api/auth/sso_callback" + "?method=" + argument
} }
endpoint := setting.GetStr(conf.SSOEndpointName) endpoint := setting.GetStr(conf.SSOEndpointName)
provider, err := oidc.NewProvider(c, endpoint) provider, err := oidc.NewProvider(c, endpoint)
@ -127,7 +122,7 @@ func GetOIDCClient(c *gin.Context) (*oauth2.Config, error) {
return &oauth2.Config{ return &oauth2.Config{
ClientID: clientId, ClientID: clientId,
ClientSecret: clientSecret, ClientSecret: clientSecret,
RedirectURL: redirect_uri, RedirectURL: redirectUri,
// Discovery returns the OAuth2 endpoints. // Discovery returns the OAuth2 endpoints.
Endpoint: provider.Endpoint(), Endpoint: provider.Endpoint(),
@ -181,9 +176,9 @@ func parseJWT(p string) ([]byte, error) {
func OIDCLoginCallback(c *gin.Context) { func OIDCLoginCallback(c *gin.Context) {
useCompatibility := setting.GetBool(conf.SSOCompatibilityMode) useCompatibility := setting.GetBool(conf.SSOCompatibilityMode)
argument := c.Query("method") method := c.Query("method")
if useCompatibility { if useCompatibility {
argument = path.Base(c.Request.URL.Path) method = path.Base(c.Request.URL.Path)
} }
clientId := setting.GetStr(conf.SSOClientId) clientId := setting.GetStr(conf.SSOClientId)
endpoint := setting.GetStr(conf.SSOEndpointName) endpoint := setting.GetStr(conf.SSOEndpointName)
@ -192,7 +187,7 @@ func OIDCLoginCallback(c *gin.Context) {
common.ErrorResp(c, err, 400) common.ErrorResp(c, err, 400)
return return
} }
oauth2Config, err := GetOIDCClient(c) oauth2Config, err := GetOIDCClient(c, useCompatibility, "", method)
if err != nil { if err != nil {
common.ErrorResp(c, err, 400) common.ErrorResp(c, err, 400)
return return
@ -236,7 +231,7 @@ func OIDCLoginCallback(c *gin.Context) {
common.ErrorStrResp(c, "cannot get username from OIDC provider", 400) common.ErrorStrResp(c, "cannot get username from OIDC provider", 400)
return return
} }
if argument == "get_sso_id" { if method == "get_sso_id" {
if useCompatibility { if useCompatibility {
c.Redirect(302, common.GetApiUrl(c.Request)+"/@manage?sso_id="+userID) c.Redirect(302, common.GetApiUrl(c.Request)+"/@manage?sso_id="+userID)
return return
@ -252,7 +247,7 @@ func OIDCLoginCallback(c *gin.Context) {
c.Data(200, "text/html; charset=utf-8", []byte(html)) c.Data(200, "text/html; charset=utf-8", []byte(html))
return return
} }
if argument == "sso_get_token" { if method == "sso_get_token" {
user, err := db.GetUserBySSOID(userID) user, err := db.GetUserBySSOID(userID)
if err != nil { if err != nil {
user, err = autoRegister(userID, userID, err) user, err = autoRegister(userID, userID, err)