diff --git a/server/handles/ssologin.go b/server/handles/ssologin.go index 5b753442..76a7f5d0 100644 --- a/server/handles/ssologin.go +++ b/server/handles/ssologin.go @@ -1,11 +1,13 @@ package handles import ( + "encoding/base32" "errors" "fmt" "net/http" "net/url" "strings" + "time" "github.com/alist-org/alist/v3/internal/conf" "github.com/alist-org/alist/v3/internal/db" @@ -15,9 +17,20 @@ import ( "github.com/coreos/go-oidc" "github.com/gin-gonic/gin" "github.com/go-resty/resty/v2" + "github.com/pquerna/otp" + "github.com/pquerna/otp/totp" "golang.org/x/oauth2" ) +var opts = totp.ValidateOpts{ + // state verify won't expire in 30 secs, which is quite enough for the callback + Period: 30, + Skew: 1, + // in some OIDC providers(such as Authelia), state parameter must be at least 8 characters + Digits: otp.DigitsEight, + Algorithm: otp.AlgorithmSHA1, +} + func SSOLoginRedirect(c *gin.Context) { method := c.Query("method") enabled := setting.GetBool(conf.SSOLoginEnabled) @@ -62,7 +75,13 @@ func SSOLoginRedirect(c *gin.Context) { common.ErrorStrResp(c, err.Error(), 400) return } - c.Redirect(http.StatusFound, oauth2Config.AuthCodeURL("state")) + // generate state parameter + state,err := totp.GenerateCodeCustom(base32.StdEncoding.EncodeToString([]byte(oauth2Config.ClientSecret)), time.Now(), opts) + if err != nil { + common.ErrorStrResp(c, err.Error(), 400) + return + } + c.Redirect(http.StatusFound, oauth2Config.AuthCodeURL(state)) return default: common.ErrorStrResp(c, "invalid platform", 400) @@ -117,6 +136,17 @@ func OIDCLoginCallback(c *gin.Context) { common.ErrorResp(c, err, 400) return } + // add state verify process + stateVerification, err := totp.ValidateCustom(c.Query("state"), base32.StdEncoding.EncodeToString([]byte(oauth2Config.ClientSecret)), time.Now(), opts) + if err != nil { + common.ErrorResp(c, err, 400) + return + } + if !stateVerification { + common.ErrorStrResp(c, "incorrect or expired state parameter", 400) + return + } + oauth2Token, err := oauth2Config.Exchange(c, c.Query("code")) if err != nil { common.ErrorResp(c, err, 400)