feat(sso): generate and verify OAuth state with go-cache (#7527)

pull/7551/head
Mmx 2024-11-21 22:38:04 +08:00 committed by GitHub
parent 12b429584e
commit 398c04386a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 21 additions and 23 deletions

View File

@ -1,10 +1,10 @@
package handles package handles
import ( import (
"encoding/base32"
"encoding/base64" "encoding/base64"
"errors" "errors"
"fmt" "fmt"
"github.com/Xhofe/go-cache"
"net/http" "net/http"
"net/url" "net/url"
"path" "path"
@ -21,19 +21,28 @@ import (
"github.com/coreos/go-oidc" "github.com/coreos/go-oidc"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/go-resty/resty/v2" "github.com/go-resty/resty/v2"
"github.com/pquerna/otp"
"github.com/pquerna/otp/totp"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"gorm.io/gorm" "gorm.io/gorm"
) )
var opts = totp.ValidateOpts{ const stateLength = 16
// state verify won't expire in 30 secs, which is quite enough for the callback const stateExpire = time.Minute * 5
Period: 30,
Skew: 1, var stateCache = cache.NewMemCache[string](cache.WithShards[string](stateLength))
// in some OIDC providers(such as Authelia), state parameter must be at least 8 characters
Digits: otp.DigitsEight, func _keyState(clientID, state string) string {
Algorithm: otp.AlgorithmSHA1, return fmt.Sprintf("%s_%s", clientID, state)
}
func generateState(clientID, ip string) string {
state := random.String(stateLength)
stateCache.Set(_keyState(clientID, state), ip, cache.WithEx[string](stateExpire))
return state
}
func verifyState(clientID, ip, state string) bool {
value, ok := stateCache.Get(_keyState(clientID, state))
return ok && value == ip
} }
func ssoRedirectUri(c *gin.Context, useCompatibility bool, method string) string { func ssoRedirectUri(c *gin.Context, useCompatibility bool, method string) string {
@ -91,12 +100,7 @@ func SSOLoginRedirect(c *gin.Context) {
common.ErrorStrResp(c, err.Error(), 400) common.ErrorStrResp(c, err.Error(), 400)
return return
} }
// generate state parameter state := generateState(clientId, c.ClientIP())
state, err := totp.GenerateCodeCustom(base32.StdEncoding.EncodeToString([]byte(oauth2Config.ClientSecret)), time.Now(), opts)
if err != nil {
common.ErrorStrResp(c, err.Error(), 400)
return
}
c.Redirect(http.StatusFound, oauth2Config.AuthCodeURL(state)) c.Redirect(http.StatusFound, oauth2Config.AuthCodeURL(state))
return return
default: default:
@ -192,13 +196,7 @@ func OIDCLoginCallback(c *gin.Context) {
common.ErrorResp(c, err, 400) common.ErrorResp(c, err, 400)
return return
} }
// add state verify process if !verifyState(clientId, c.ClientIP(), c.Query("state")) {
stateVerification, err := totp.ValidateCustom(c.Query("state"), base32.StdEncoding.EncodeToString([]byte(oauth2Config.ClientSecret)), time.Now(), opts)
if err != nil {
common.ErrorResp(c, err, 400)
return
}
if !stateVerification {
common.ErrorStrResp(c, "incorrect or expired state parameter", 400) common.ErrorStrResp(c, "incorrect or expired state parameter", 400)
return return
} }