mirror of https://github.com/Xhofe/alist
feat(sso): generate and verify OAuth state with go-cache (#7527)
parent
12b429584e
commit
398c04386a
|
@ -1,10 +1,10 @@
|
||||||
package handles
|
package handles
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/base32"
|
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/Xhofe/go-cache"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"path"
|
"path"
|
||||||
|
@ -21,19 +21,28 @@ import (
|
||||||
"github.com/coreos/go-oidc"
|
"github.com/coreos/go-oidc"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/go-resty/resty/v2"
|
"github.com/go-resty/resty/v2"
|
||||||
"github.com/pquerna/otp"
|
|
||||||
"github.com/pquerna/otp/totp"
|
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
var opts = totp.ValidateOpts{
|
const stateLength = 16
|
||||||
// state verify won't expire in 30 secs, which is quite enough for the callback
|
const stateExpire = time.Minute * 5
|
||||||
Period: 30,
|
|
||||||
Skew: 1,
|
var stateCache = cache.NewMemCache[string](cache.WithShards[string](stateLength))
|
||||||
// in some OIDC providers(such as Authelia), state parameter must be at least 8 characters
|
|
||||||
Digits: otp.DigitsEight,
|
func _keyState(clientID, state string) string {
|
||||||
Algorithm: otp.AlgorithmSHA1,
|
return fmt.Sprintf("%s_%s", clientID, state)
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateState(clientID, ip string) string {
|
||||||
|
state := random.String(stateLength)
|
||||||
|
stateCache.Set(_keyState(clientID, state), ip, cache.WithEx[string](stateExpire))
|
||||||
|
return state
|
||||||
|
}
|
||||||
|
|
||||||
|
func verifyState(clientID, ip, state string) bool {
|
||||||
|
value, ok := stateCache.Get(_keyState(clientID, state))
|
||||||
|
return ok && value == ip
|
||||||
}
|
}
|
||||||
|
|
||||||
func ssoRedirectUri(c *gin.Context, useCompatibility bool, method string) string {
|
func ssoRedirectUri(c *gin.Context, useCompatibility bool, method string) string {
|
||||||
|
@ -91,12 +100,7 @@ func SSOLoginRedirect(c *gin.Context) {
|
||||||
common.ErrorStrResp(c, err.Error(), 400)
|
common.ErrorStrResp(c, err.Error(), 400)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// generate state parameter
|
state := generateState(clientId, c.ClientIP())
|
||||||
state, err := totp.GenerateCodeCustom(base32.StdEncoding.EncodeToString([]byte(oauth2Config.ClientSecret)), time.Now(), opts)
|
|
||||||
if err != nil {
|
|
||||||
common.ErrorStrResp(c, err.Error(), 400)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
c.Redirect(http.StatusFound, oauth2Config.AuthCodeURL(state))
|
c.Redirect(http.StatusFound, oauth2Config.AuthCodeURL(state))
|
||||||
return
|
return
|
||||||
default:
|
default:
|
||||||
|
@ -192,13 +196,7 @@ func OIDCLoginCallback(c *gin.Context) {
|
||||||
common.ErrorResp(c, err, 400)
|
common.ErrorResp(c, err, 400)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// add state verify process
|
if !verifyState(clientId, c.ClientIP(), c.Query("state")) {
|
||||||
stateVerification, err := totp.ValidateCustom(c.Query("state"), base32.StdEncoding.EncodeToString([]byte(oauth2Config.ClientSecret)), time.Now(), opts)
|
|
||||||
if err != nil {
|
|
||||||
common.ErrorResp(c, err, 400)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if !stateVerification {
|
|
||||||
common.ErrorStrResp(c, "incorrect or expired state parameter", 400)
|
common.ErrorStrResp(c, "incorrect or expired state parameter", 400)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue