mirror of https://github.com/Xhofe/alist
fix(sso): OIDC compatibility mode (#7524)
parent
0ba754fd40
commit
150dcc2147
|
@ -36,14 +36,21 @@ var opts = totp.ValidateOpts{
|
||||||
Algorithm: otp.AlgorithmSHA1,
|
Algorithm: otp.AlgorithmSHA1,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ssoRedirectUri(c *gin.Context, useCompatibility bool, method string) string {
|
||||||
|
if useCompatibility {
|
||||||
|
return common.GetApiUrl(c.Request) + "/api/auth/" + method
|
||||||
|
} else {
|
||||||
|
return common.GetApiUrl(c.Request) + "/api/auth/sso_callback" + "?method=" + method
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func SSOLoginRedirect(c *gin.Context) {
|
func SSOLoginRedirect(c *gin.Context) {
|
||||||
method := c.Query("method")
|
method := c.Query("method")
|
||||||
usecompatibility := setting.GetBool(conf.SSOCompatibilityMode)
|
useCompatibility := setting.GetBool(conf.SSOCompatibilityMode)
|
||||||
enabled := setting.GetBool(conf.SSOLoginEnabled)
|
enabled := setting.GetBool(conf.SSOLoginEnabled)
|
||||||
clientId := setting.GetStr(conf.SSOClientId)
|
clientId := setting.GetStr(conf.SSOClientId)
|
||||||
platform := setting.GetStr(conf.SSOLoginPlatform)
|
platform := setting.GetStr(conf.SSOLoginPlatform)
|
||||||
var r_url string
|
var rUrl string
|
||||||
var redirect_uri string
|
|
||||||
if !enabled {
|
if !enabled {
|
||||||
common.ErrorStrResp(c, "Single sign-on is not enabled", 403)
|
common.ErrorStrResp(c, "Single sign-on is not enabled", 403)
|
||||||
return
|
return
|
||||||
|
@ -53,37 +60,33 @@ func SSOLoginRedirect(c *gin.Context) {
|
||||||
common.ErrorStrResp(c, "no method provided", 400)
|
common.ErrorStrResp(c, "no method provided", 400)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if usecompatibility {
|
redirectUri := ssoRedirectUri(c, useCompatibility, method)
|
||||||
redirect_uri = common.GetApiUrl(c.Request) + "/api/auth/" + method
|
|
||||||
} else {
|
|
||||||
redirect_uri = common.GetApiUrl(c.Request) + "/api/auth/sso_callback" + "?method=" + method
|
|
||||||
}
|
|
||||||
urlValues.Add("response_type", "code")
|
urlValues.Add("response_type", "code")
|
||||||
urlValues.Add("redirect_uri", redirect_uri)
|
urlValues.Add("redirect_uri", redirectUri)
|
||||||
urlValues.Add("client_id", clientId)
|
urlValues.Add("client_id", clientId)
|
||||||
switch platform {
|
switch platform {
|
||||||
case "Github":
|
case "Github":
|
||||||
r_url = "https://github.com/login/oauth/authorize?"
|
rUrl = "https://github.com/login/oauth/authorize?"
|
||||||
urlValues.Add("scope", "read:user")
|
urlValues.Add("scope", "read:user")
|
||||||
case "Microsoft":
|
case "Microsoft":
|
||||||
r_url = "https://login.microsoftonline.com/common/oauth2/v2.0/authorize?"
|
rUrl = "https://login.microsoftonline.com/common/oauth2/v2.0/authorize?"
|
||||||
urlValues.Add("scope", "user.read")
|
urlValues.Add("scope", "user.read")
|
||||||
urlValues.Add("response_mode", "query")
|
urlValues.Add("response_mode", "query")
|
||||||
case "Google":
|
case "Google":
|
||||||
r_url = "https://accounts.google.com/o/oauth2/v2/auth?"
|
rUrl = "https://accounts.google.com/o/oauth2/v2/auth?"
|
||||||
urlValues.Add("scope", "https://www.googleapis.com/auth/userinfo.profile")
|
urlValues.Add("scope", "https://www.googleapis.com/auth/userinfo.profile")
|
||||||
case "Dingtalk":
|
case "Dingtalk":
|
||||||
r_url = "https://login.dingtalk.com/oauth2/auth?"
|
rUrl = "https://login.dingtalk.com/oauth2/auth?"
|
||||||
urlValues.Add("scope", "openid")
|
urlValues.Add("scope", "openid")
|
||||||
urlValues.Add("prompt", "consent")
|
urlValues.Add("prompt", "consent")
|
||||||
urlValues.Add("response_type", "code")
|
urlValues.Add("response_type", "code")
|
||||||
case "Casdoor":
|
case "Casdoor":
|
||||||
endpoint := strings.TrimSuffix(setting.GetStr(conf.SSOEndpointName), "/")
|
endpoint := strings.TrimSuffix(setting.GetStr(conf.SSOEndpointName), "/")
|
||||||
r_url = endpoint + "/login/oauth/authorize?"
|
rUrl = endpoint + "/login/oauth/authorize?"
|
||||||
urlValues.Add("scope", "profile")
|
urlValues.Add("scope", "profile")
|
||||||
urlValues.Add("state", endpoint)
|
urlValues.Add("state", endpoint)
|
||||||
case "OIDC":
|
case "OIDC":
|
||||||
oauth2Config, err := GetOIDCClient(c)
|
oauth2Config, err := GetOIDCClient(c, useCompatibility, redirectUri, method)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.ErrorStrResp(c, err.Error(), 400)
|
common.ErrorStrResp(c, err.Error(), 400)
|
||||||
return
|
return
|
||||||
|
@ -100,22 +103,14 @@ func SSOLoginRedirect(c *gin.Context) {
|
||||||
common.ErrorStrResp(c, "invalid platform", 400)
|
common.ErrorStrResp(c, "invalid platform", 400)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.Redirect(302, r_url+urlValues.Encode())
|
c.Redirect(302, rUrl+urlValues.Encode())
|
||||||
}
|
}
|
||||||
|
|
||||||
var ssoClient = resty.New().SetRetryCount(3)
|
var ssoClient = resty.New().SetRetryCount(3)
|
||||||
|
|
||||||
func GetOIDCClient(c *gin.Context) (*oauth2.Config, error) {
|
func GetOIDCClient(c *gin.Context, useCompatibility bool, redirectUri, method string) (*oauth2.Config, error) {
|
||||||
var redirect_uri string
|
if redirectUri == "" {
|
||||||
usecompatibility := setting.GetBool(conf.SSOCompatibilityMode)
|
redirectUri = ssoRedirectUri(c, useCompatibility, method)
|
||||||
argument := c.Query("method")
|
|
||||||
if usecompatibility {
|
|
||||||
argument = path.Base(c.Request.URL.Path)
|
|
||||||
}
|
|
||||||
if usecompatibility {
|
|
||||||
redirect_uri = common.GetApiUrl(c.Request) + "/api/auth/" + argument
|
|
||||||
} else {
|
|
||||||
redirect_uri = common.GetApiUrl(c.Request) + "/api/auth/sso_callback" + "?method=" + argument
|
|
||||||
}
|
}
|
||||||
endpoint := setting.GetStr(conf.SSOEndpointName)
|
endpoint := setting.GetStr(conf.SSOEndpointName)
|
||||||
provider, err := oidc.NewProvider(c, endpoint)
|
provider, err := oidc.NewProvider(c, endpoint)
|
||||||
|
@ -127,7 +122,7 @@ func GetOIDCClient(c *gin.Context) (*oauth2.Config, error) {
|
||||||
return &oauth2.Config{
|
return &oauth2.Config{
|
||||||
ClientID: clientId,
|
ClientID: clientId,
|
||||||
ClientSecret: clientSecret,
|
ClientSecret: clientSecret,
|
||||||
RedirectURL: redirect_uri,
|
RedirectURL: redirectUri,
|
||||||
|
|
||||||
// Discovery returns the OAuth2 endpoints.
|
// Discovery returns the OAuth2 endpoints.
|
||||||
Endpoint: provider.Endpoint(),
|
Endpoint: provider.Endpoint(),
|
||||||
|
@ -181,9 +176,9 @@ func parseJWT(p string) ([]byte, error) {
|
||||||
|
|
||||||
func OIDCLoginCallback(c *gin.Context) {
|
func OIDCLoginCallback(c *gin.Context) {
|
||||||
useCompatibility := setting.GetBool(conf.SSOCompatibilityMode)
|
useCompatibility := setting.GetBool(conf.SSOCompatibilityMode)
|
||||||
argument := c.Query("method")
|
method := c.Query("method")
|
||||||
if useCompatibility {
|
if useCompatibility {
|
||||||
argument = path.Base(c.Request.URL.Path)
|
method = path.Base(c.Request.URL.Path)
|
||||||
}
|
}
|
||||||
clientId := setting.GetStr(conf.SSOClientId)
|
clientId := setting.GetStr(conf.SSOClientId)
|
||||||
endpoint := setting.GetStr(conf.SSOEndpointName)
|
endpoint := setting.GetStr(conf.SSOEndpointName)
|
||||||
|
@ -192,7 +187,7 @@ func OIDCLoginCallback(c *gin.Context) {
|
||||||
common.ErrorResp(c, err, 400)
|
common.ErrorResp(c, err, 400)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
oauth2Config, err := GetOIDCClient(c)
|
oauth2Config, err := GetOIDCClient(c, useCompatibility, "", method)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.ErrorResp(c, err, 400)
|
common.ErrorResp(c, err, 400)
|
||||||
return
|
return
|
||||||
|
@ -236,7 +231,7 @@ func OIDCLoginCallback(c *gin.Context) {
|
||||||
common.ErrorStrResp(c, "cannot get username from OIDC provider", 400)
|
common.ErrorStrResp(c, "cannot get username from OIDC provider", 400)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if argument == "get_sso_id" {
|
if method == "get_sso_id" {
|
||||||
if useCompatibility {
|
if useCompatibility {
|
||||||
c.Redirect(302, common.GetApiUrl(c.Request)+"/@manage?sso_id="+userID)
|
c.Redirect(302, common.GetApiUrl(c.Request)+"/@manage?sso_id="+userID)
|
||||||
return
|
return
|
||||||
|
@ -252,7 +247,7 @@ func OIDCLoginCallback(c *gin.Context) {
|
||||||
c.Data(200, "text/html; charset=utf-8", []byte(html))
|
c.Data(200, "text/html; charset=utf-8", []byte(html))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if argument == "sso_get_token" {
|
if method == "sso_get_token" {
|
||||||
user, err := db.GetUserBySSOID(userID)
|
user, err := db.GetUserBySSOID(userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
user, err = autoRegister(userID, userID, err)
|
user, err = autoRegister(userID, userID, err)
|
||||||
|
|
Loading…
Reference in New Issue