mirror of https://github.com/hashicorp/consul
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
511 lines
12 KiB
511 lines
12 KiB
package oidcauth |
|
|
|
import ( |
|
"context" |
|
"errors" |
|
"net/url" |
|
"strings" |
|
"testing" |
|
"time" |
|
|
|
"github.com/hashicorp/consul/internal/go-sso/oidcauth/oidcauthtest" |
|
"github.com/hashicorp/go-hclog" |
|
"github.com/stretchr/testify/assert" |
|
"github.com/stretchr/testify/require" |
|
"gopkg.in/square/go-jose.v2/jwt" |
|
) |
|
|
|
func setupForOIDC(t *testing.T) (*Authenticator, *oidcauthtest.Server) { |
|
t.Helper() |
|
|
|
srv := oidcauthtest.Start(t) |
|
srv.SetClientCreds("abc", "def") |
|
|
|
config := &Config{ |
|
Type: TypeOIDC, |
|
OIDCDiscoveryURL: srv.Addr(), |
|
OIDCDiscoveryCACert: srv.CACert(), |
|
OIDCClientID: "abc", |
|
OIDCClientSecret: "def", |
|
OIDCACRValues: []string{"acr1", "acr2"}, |
|
JWTSupportedAlgs: []string{"ES256"}, |
|
BoundAudiences: []string{"abc"}, |
|
AllowedRedirectURIs: []string{"https://example.com"}, |
|
ClaimMappings: map[string]string{ |
|
"COLOR": "color", |
|
"/nested/Size": "size", |
|
"Age": "age", |
|
"Admin": "is_admin", |
|
"/nested/division": "division", |
|
"/nested/remote": "is_remote", |
|
"flavor": "flavor", // userinfo |
|
}, |
|
ListClaimMappings: map[string]string{ |
|
"/nested/Groups": "groups", |
|
}, |
|
} |
|
|
|
require.NoError(t, config.Validate()) |
|
|
|
oa, err := New(config, hclog.NewNullLogger()) |
|
require.NoError(t, err) |
|
t.Cleanup(oa.Stop) |
|
|
|
return oa, srv |
|
} |
|
|
|
func TestOIDC_AuthURL(t *testing.T) { |
|
t.Run("normal case", func(t *testing.T) { |
|
t.Parallel() |
|
|
|
oa, _ := setupForOIDC(t) |
|
|
|
authURL, err := oa.GetAuthCodeURL( |
|
context.Background(), |
|
"https://example.com", |
|
map[string]string{"foo": "bar"}, |
|
) |
|
require.NoError(t, err) |
|
|
|
require.True(t, strings.HasPrefix(authURL, oa.config.OIDCDiscoveryURL+"/auth?")) |
|
|
|
expected := map[string]string{ |
|
"client_id": "abc", |
|
"redirect_uri": "https://example.com", |
|
"response_type": "code", |
|
"scope": "openid", |
|
// optional values |
|
"acr_values": "acr1 acr2", |
|
} |
|
|
|
au, err := url.Parse(authURL) |
|
require.NoError(t, err) |
|
|
|
for k, v := range expected { |
|
assert.Equal(t, v, au.Query().Get(k), "key %q is incorrect", k) |
|
} |
|
|
|
assert.Regexp(t, `^[a-z0-9]{40}$`, au.Query().Get("nonce")) |
|
assert.Regexp(t, `^[a-z0-9]{40}$`, au.Query().Get("state")) |
|
|
|
}) |
|
|
|
t.Run("invalid RedirectURI", func(t *testing.T) { |
|
t.Parallel() |
|
|
|
oa, _ := setupForOIDC(t) |
|
|
|
_, err := oa.GetAuthCodeURL( |
|
context.Background(), |
|
"http://bitc0in-4-less.cx", |
|
map[string]string{"foo": "bar"}, |
|
) |
|
requireErrorContains(t, err, "unauthorized redirect_uri: http://bitc0in-4-less.cx") |
|
}) |
|
|
|
t.Run("missing RedirectURI", func(t *testing.T) { |
|
t.Parallel() |
|
|
|
oa, _ := setupForOIDC(t) |
|
|
|
_, err := oa.GetAuthCodeURL( |
|
context.Background(), |
|
"", |
|
map[string]string{"foo": "bar"}, |
|
) |
|
requireErrorContains(t, err, "missing redirect_uri") |
|
}) |
|
} |
|
|
|
func TestOIDC_JWT_Functions_Fail(t *testing.T) { |
|
oa, srv := setupForOIDC(t) |
|
|
|
cl := jwt.Claims{ |
|
Subject: "r3qXcK2bix9eFECzsU3Sbmh0K16fatW6@clients", |
|
Issuer: srv.Addr(), |
|
NotBefore: jwt.NewNumericDate(time.Now().Add(-5 * time.Second)), |
|
Audience: jwt.Audience{"https://go-sso.test"}, |
|
} |
|
|
|
privateCl := struct { |
|
User string `json:"https://go-sso/user"` |
|
Groups []string `json:"https://go-sso/groups"` |
|
}{ |
|
"jeff", |
|
[]string{"foo", "bar"}, |
|
} |
|
|
|
jwtData, err := oidcauthtest.SignJWT("", cl, privateCl) |
|
require.NoError(t, err) |
|
|
|
_, err = oa.ClaimsFromJWT(context.Background(), jwtData) |
|
requireErrorContains(t, err, `ClaimsFromJWT is incompatible with type "oidc"`) |
|
} |
|
|
|
func TestOIDC_ClaimsFromAuthCode(t *testing.T) { |
|
requireProviderError := func(t *testing.T, err error) { |
|
var provErr *ProviderLoginFailedError |
|
if !errors.As(err, &provErr) { |
|
t.Fatalf("error was not a *ProviderLoginFailedError") |
|
} |
|
} |
|
requireTokenVerificationError := func(t *testing.T, err error) { |
|
var tokErr *TokenVerificationFailedError |
|
if !errors.As(err, &tokErr) { |
|
t.Fatalf("error was not a *TokenVerificationFailedError") |
|
} |
|
} |
|
|
|
t.Run("successful login", func(t *testing.T) { |
|
oa, srv := setupForOIDC(t) |
|
|
|
origPayload := map[string]string{"foo": "bar"} |
|
authURL, err := oa.GetAuthCodeURL( |
|
context.Background(), |
|
"https://example.com", |
|
origPayload, |
|
) |
|
require.NoError(t, err) |
|
|
|
state := getQueryParam(t, authURL, "state") |
|
nonce := getQueryParam(t, authURL, "nonce") |
|
|
|
// set provider claims that will be returned by the mock server |
|
srv.SetCustomClaims(sampleClaims(nonce)) |
|
|
|
// set mock provider's expected code |
|
srv.SetExpectedAuthCode("abc") |
|
|
|
claims, payload, err := oa.ClaimsFromAuthCode( |
|
context.Background(), |
|
state, "abc", |
|
) |
|
require.NoError(t, err) |
|
|
|
require.Equal(t, origPayload, payload) |
|
|
|
expectedClaims := &Claims{ |
|
Values: map[string]string{ |
|
"color": "green", |
|
"size": "medium", |
|
"age": "85", |
|
"is_admin": "true", |
|
"division": "3", |
|
"is_remote": "true", |
|
"flavor": "umami", // from userinfo |
|
}, |
|
Lists: map[string][]string{ |
|
"groups": {"a", "b"}, |
|
}, |
|
} |
|
|
|
require.Equal(t, expectedClaims, claims) |
|
}) |
|
|
|
t.Run("failed login unusable claims", func(t *testing.T) { |
|
oa, srv := setupForOIDC(t) |
|
|
|
origPayload := map[string]string{"foo": "bar"} |
|
authURL, err := oa.GetAuthCodeURL( |
|
context.Background(), |
|
"https://example.com", |
|
origPayload, |
|
) |
|
require.NoError(t, err) |
|
|
|
state := getQueryParam(t, authURL, "state") |
|
nonce := getQueryParam(t, authURL, "nonce") |
|
|
|
// set provider claims that will be returned by the mock server |
|
customClaims := sampleClaims(nonce) |
|
customClaims["COLOR"] = []interface{}{"yellow"} |
|
srv.SetCustomClaims(customClaims) |
|
|
|
// set mock provider's expected code |
|
srv.SetExpectedAuthCode("abc") |
|
|
|
_, _, err = oa.ClaimsFromAuthCode( |
|
context.Background(), |
|
state, "abc", |
|
) |
|
requireErrorContains(t, err, "error converting claim 'COLOR' to string from unknown type []interface {}") |
|
requireTokenVerificationError(t, err) |
|
}) |
|
|
|
t.Run("successful login - no userinfo", func(t *testing.T) { |
|
oa, srv := setupForOIDC(t) |
|
|
|
srv.DisableUserInfo() |
|
|
|
origPayload := map[string]string{"foo": "bar"} |
|
authURL, err := oa.GetAuthCodeURL( |
|
context.Background(), |
|
"https://example.com", |
|
origPayload, |
|
) |
|
require.NoError(t, err) |
|
|
|
state := getQueryParam(t, authURL, "state") |
|
nonce := getQueryParam(t, authURL, "nonce") |
|
|
|
// set provider claims that will be returned by the mock server |
|
srv.SetCustomClaims(sampleClaims(nonce)) |
|
|
|
// set mock provider's expected code |
|
srv.SetExpectedAuthCode("abc") |
|
|
|
claims, payload, err := oa.ClaimsFromAuthCode( |
|
context.Background(), |
|
state, "abc", |
|
) |
|
require.NoError(t, err) |
|
|
|
require.Equal(t, origPayload, payload) |
|
|
|
expectedClaims := &Claims{ |
|
Values: map[string]string{ |
|
"color": "green", |
|
"size": "medium", |
|
"age": "85", |
|
"is_admin": "true", |
|
"division": "3", |
|
"is_remote": "true", |
|
// "flavor": "umami", // from userinfo |
|
}, |
|
Lists: map[string][]string{ |
|
"groups": {"a", "b"}, |
|
}, |
|
} |
|
|
|
require.Equal(t, expectedClaims, claims) |
|
}) |
|
|
|
t.Run("failed login - bad nonce", func(t *testing.T) { |
|
t.Parallel() |
|
|
|
oa, srv := setupForOIDC(t) |
|
|
|
origPayload := map[string]string{"foo": "bar"} |
|
authURL, err := oa.GetAuthCodeURL( |
|
context.Background(), |
|
"https://example.com", |
|
origPayload, |
|
) |
|
require.NoError(t, err) |
|
|
|
state := getQueryParam(t, authURL, "state") |
|
|
|
srv.SetCustomClaims(sampleClaims("bad nonce")) |
|
|
|
// set mock provider's expected code |
|
srv.SetExpectedAuthCode("abc") |
|
|
|
_, _, err = oa.ClaimsFromAuthCode( |
|
context.Background(), |
|
state, "abc", |
|
) |
|
requireErrorContains(t, err, "Invalid ID token nonce") |
|
requireTokenVerificationError(t, err) |
|
}) |
|
|
|
t.Run("missing state", func(t *testing.T) { |
|
oa, _ := setupForOIDC(t) |
|
|
|
origPayload := map[string]string{"foo": "bar"} |
|
_, err := oa.GetAuthCodeURL( |
|
context.Background(), |
|
"https://example.com", |
|
origPayload, |
|
) |
|
require.NoError(t, err) |
|
|
|
_, _, err = oa.ClaimsFromAuthCode( |
|
context.Background(), |
|
"", "abc", |
|
) |
|
requireErrorContains(t, err, "Expired or missing OAuth state") |
|
requireProviderError(t, err) |
|
}) |
|
|
|
t.Run("unknown state", func(t *testing.T) { |
|
oa, _ := setupForOIDC(t) |
|
|
|
origPayload := map[string]string{"foo": "bar"} |
|
_, err := oa.GetAuthCodeURL( |
|
context.Background(), |
|
"https://example.com", |
|
origPayload, |
|
) |
|
require.NoError(t, err) |
|
|
|
_, _, err = oa.ClaimsFromAuthCode( |
|
context.Background(), |
|
"not_a_state", "abc", |
|
) |
|
requireErrorContains(t, err, "Expired or missing OAuth state") |
|
requireProviderError(t, err) |
|
}) |
|
|
|
t.Run("valid state, missing code", func(t *testing.T) { |
|
oa, _ := setupForOIDC(t) |
|
|
|
origPayload := map[string]string{"foo": "bar"} |
|
authURL, err := oa.GetAuthCodeURL( |
|
context.Background(), |
|
"https://example.com", |
|
origPayload, |
|
) |
|
require.NoError(t, err) |
|
|
|
state := getQueryParam(t, authURL, "state") |
|
|
|
_, _, err = oa.ClaimsFromAuthCode( |
|
context.Background(), |
|
state, "", |
|
) |
|
requireErrorContains(t, err, "OAuth code parameter not provided") |
|
requireProviderError(t, err) |
|
}) |
|
|
|
t.Run("failed code exchange", func(t *testing.T) { |
|
oa, srv := setupForOIDC(t) |
|
|
|
origPayload := map[string]string{"foo": "bar"} |
|
authURL, err := oa.GetAuthCodeURL( |
|
context.Background(), |
|
"https://example.com", |
|
origPayload, |
|
) |
|
require.NoError(t, err) |
|
|
|
state := getQueryParam(t, authURL, "state") |
|
|
|
// set mock provider's expected code |
|
srv.SetExpectedAuthCode("abc") |
|
|
|
_, _, err = oa.ClaimsFromAuthCode( |
|
context.Background(), |
|
state, "wrong_code", |
|
) |
|
requireErrorContains(t, err, "cannot fetch token") |
|
requireProviderError(t, err) |
|
}) |
|
|
|
t.Run("no id_token returned", func(t *testing.T) { |
|
oa, srv := setupForOIDC(t) |
|
|
|
origPayload := map[string]string{"foo": "bar"} |
|
authURL, err := oa.GetAuthCodeURL( |
|
context.Background(), |
|
"https://example.com", |
|
origPayload, |
|
) |
|
require.NoError(t, err) |
|
|
|
state := getQueryParam(t, authURL, "state") |
|
nonce := getQueryParam(t, authURL, "nonce") |
|
|
|
// set provider claims that will be returned by the mock server |
|
srv.SetCustomClaims(sampleClaims(nonce)) |
|
|
|
// set mock provider's expected code |
|
srv.SetExpectedAuthCode("abc") |
|
|
|
srv.OmitIDTokens() |
|
|
|
_, _, err = oa.ClaimsFromAuthCode( |
|
context.Background(), |
|
state, "abc", |
|
) |
|
requireErrorContains(t, err, "No id_token found in response") |
|
requireTokenVerificationError(t, err) |
|
}) |
|
|
|
t.Run("no response from provider", func(t *testing.T) { |
|
oa, srv := setupForOIDC(t) |
|
|
|
origPayload := map[string]string{"foo": "bar"} |
|
authURL, err := oa.GetAuthCodeURL( |
|
context.Background(), |
|
"https://example.com", |
|
origPayload, |
|
) |
|
require.NoError(t, err) |
|
|
|
state := getQueryParam(t, authURL, "state") |
|
|
|
// close the server prematurely |
|
srv.Stop() |
|
srv.SetExpectedAuthCode("abc") |
|
|
|
_, _, err = oa.ClaimsFromAuthCode( |
|
context.Background(), |
|
state, "abc", |
|
) |
|
requireErrorContains(t, err, "connection refused") |
|
requireProviderError(t, err) |
|
}) |
|
|
|
t.Run("invalid bound audience", func(t *testing.T) { |
|
oa, srv := setupForOIDC(t) |
|
|
|
srv.SetClientCreds("not_gonna_match", "def") |
|
|
|
origPayload := map[string]string{"foo": "bar"} |
|
authURL, err := oa.GetAuthCodeURL( |
|
context.Background(), |
|
"https://example.com", |
|
origPayload, |
|
) |
|
require.NoError(t, err) |
|
|
|
state := getQueryParam(t, authURL, "state") |
|
nonce := getQueryParam(t, authURL, "nonce") |
|
|
|
// set provider claims that will be returned by the mock server |
|
srv.SetCustomClaims(sampleClaims(nonce)) |
|
|
|
// set mock provider's expected code |
|
srv.SetExpectedAuthCode("abc") |
|
|
|
_, _, err = oa.ClaimsFromAuthCode( |
|
context.Background(), |
|
state, "abc", |
|
) |
|
requireErrorContains(t, err, `error validating signature: oidc: expected audience "abc" got ["not_gonna_match"]`) |
|
requireTokenVerificationError(t, err) |
|
}) |
|
} |
|
|
|
// sampleClaims builds a fresh claim set for the mock provider to return. The
// supplied nonce is stored under "nonce"; every other entry is fixed. A new
// map is allocated on each call, so callers may modify the result freely.
func sampleClaims(nonce string) map[string]interface{} {
	nested := map[string]interface{}{
		"Size":        "medium",
		"division":    3,
		"remote":      true,
		"Groups":      []string{"a", "b"},
		"secret_code": "bar",
	}

	claims := make(map[string]interface{}, 8)
	claims["nonce"] = nonce
	claims["email"] = "bob@example.com"
	claims["COLOR"] = "green"
	claims["sk"] = "42"
	claims["Age"] = 85
	claims["Admin"] = true
	claims["nested"] = nested
	claims["password"] = "foo"
	return claims
}
|
|
|
// getQueryParam returns the first value of the query parameter named param in
// inputURL. inputURL may be a full URL or a bare query string. The test is
// failed immediately if the input cannot be parsed or the parameter is absent.
func getQueryParam(t *testing.T, inputURL, param string) string {
	t.Helper()

	// Decode only the query component. Handing a whole URL to url.ParseQuery
	// folds the scheme, host and path into the first key, which makes the
	// first parameter of the URL impossible to look up by name.
	rawQuery := inputURL
	if strings.Contains(inputURL, "?") {
		u, err := url.Parse(inputURL)
		if err != nil {
			t.Fatal(err)
		}
		rawQuery = u.RawQuery
	}

	values, err := url.ParseQuery(rawQuery)
	if err != nil {
		t.Fatal(err)
	}

	v, ok := values[param]
	if !ok || len(v) == 0 {
		t.Fatalf("query param %q not found", param)
	}
	return v[0]
}
|
|
|