mirror of https://github.com/Xhofe/alist
fix(oidc): use TOTP as state verification to replace the static 'state' parameter (#4665)
parent
89832c296f
commit
12f40608e6
|
@ -1,11 +1,13 @@
|
||||||
package handles
|
package handles
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/base32"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/alist-org/alist/v3/internal/conf"
|
"github.com/alist-org/alist/v3/internal/conf"
|
||||||
"github.com/alist-org/alist/v3/internal/db"
|
"github.com/alist-org/alist/v3/internal/db"
|
||||||
|
@ -15,9 +17,20 @@ import (
|
||||||
"github.com/coreos/go-oidc"
|
"github.com/coreos/go-oidc"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/go-resty/resty/v2"
|
"github.com/go-resty/resty/v2"
|
||||||
|
"github.com/pquerna/otp"
|
||||||
|
"github.com/pquerna/otp/totp"
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var opts = totp.ValidateOpts{
|
||||||
|
// state verify won't expire in 30 secs, which is quite enough for the callback
|
||||||
|
Period: 30,
|
||||||
|
Skew: 1,
|
||||||
|
// in some OIDC providers(such as Authelia), state parameter must be at least 8 characters
|
||||||
|
Digits: otp.DigitsEight,
|
||||||
|
Algorithm: otp.AlgorithmSHA1,
|
||||||
|
}
|
||||||
|
|
||||||
func SSOLoginRedirect(c *gin.Context) {
|
func SSOLoginRedirect(c *gin.Context) {
|
||||||
method := c.Query("method")
|
method := c.Query("method")
|
||||||
enabled := setting.GetBool(conf.SSOLoginEnabled)
|
enabled := setting.GetBool(conf.SSOLoginEnabled)
|
||||||
|
@ -62,7 +75,13 @@ func SSOLoginRedirect(c *gin.Context) {
|
||||||
common.ErrorStrResp(c, err.Error(), 400)
|
common.ErrorStrResp(c, err.Error(), 400)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.Redirect(http.StatusFound, oauth2Config.AuthCodeURL("state"))
|
// generate state parameter
|
||||||
|
state,err := totp.GenerateCodeCustom(base32.StdEncoding.EncodeToString([]byte(oauth2Config.ClientSecret)), time.Now(), opts)
|
||||||
|
if err != nil {
|
||||||
|
common.ErrorStrResp(c, err.Error(), 400)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.Redirect(http.StatusFound, oauth2Config.AuthCodeURL(state))
|
||||||
return
|
return
|
||||||
default:
|
default:
|
||||||
common.ErrorStrResp(c, "invalid platform", 400)
|
common.ErrorStrResp(c, "invalid platform", 400)
|
||||||
|
@ -117,6 +136,17 @@ func OIDCLoginCallback(c *gin.Context) {
|
||||||
common.ErrorResp(c, err, 400)
|
common.ErrorResp(c, err, 400)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
// add state verify process
|
||||||
|
stateVerification, err := totp.ValidateCustom(c.Query("state"), base32.StdEncoding.EncodeToString([]byte(oauth2Config.ClientSecret)), time.Now(), opts)
|
||||||
|
if err != nil {
|
||||||
|
common.ErrorResp(c, err, 400)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !stateVerification {
|
||||||
|
common.ErrorStrResp(c, "incorrect or expired state parameter", 400)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
oauth2Token, err := oauth2Config.Exchange(c, c.Query("code"))
|
oauth2Token, err := oauth2Config.Exchange(c, c.Query("code"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.ErrorResp(c, err, 400)
|
common.ErrorResp(c, err, 400)
|
||||||
|
|
Loading…
Reference in New Issue