mirror of https://github.com/hashicorp/consul
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
289 lines
8.9 KiB
289 lines
8.9 KiB
// Copyright (c) HashiCorp, Inc. |
|
// SPDX-License-Identifier: BUSL-1.1 |
|
|
|
package oidcauth |
|
|
|
import ( |
|
"context" |
|
"encoding/json" |
|
"errors" |
|
"fmt" |
|
"strings" |
|
"time" |
|
|
|
"github.com/coreos/go-oidc/v3/oidc" |
|
"github.com/hashicorp/go-uuid" |
|
"golang.org/x/oauth2" |
|
) |
|
|
|
// Timing knobs for in-flight OIDC login state. They are package-level
// variables rather than constants, so they can be reassigned at runtime.
// NOTE(review): presumably these configure the expiring store behind
// Authenticator.oidcStates; the store's construction is not in this file.
var (
	// oidcStateTimeout is ten minutes.
	oidcStateTimeout = 10 * time.Minute

	// oidcStateCleanupInterval is one minute.
	oidcStateCleanupInterval = time.Minute
)
|
|
|
// GetAuthCodeURL is the first part of the OIDC authorization code workflow. |
|
// The statePayload field is stored in the Authenticator instance keyed by the |
|
// "state" key so it can be returned during a future call to |
|
// ClaimsFromAuthCode. |
|
// |
|
// Requires the authenticator's config type be set to 'oidc'. |
|
func (a *Authenticator) GetAuthCodeURL(ctx context.Context, redirectURI string, statePayload interface{}) (string, error) { |
|
if a.config.authType() != authOIDCFlow { |
|
return "", fmt.Errorf("GetAuthCodeURL is incompatible with type %q", TypeJWT) |
|
} |
|
if redirectURI == "" { |
|
return "", errors.New("missing redirect_uri") |
|
} |
|
|
|
if !validRedirect(redirectURI, a.config.AllowedRedirectURIs) { |
|
return "", fmt.Errorf("unauthorized redirect_uri: %s", redirectURI) |
|
} |
|
|
|
// "openid" is a required scope for OpenID Connect flows |
|
scopes := append([]string{oidc.ScopeOpenID}, a.config.OIDCScopes...) |
|
|
|
// Configure an OpenID Connect aware OAuth2 client |
|
oauth2Config := oauth2.Config{ |
|
ClientID: a.config.OIDCClientID, |
|
ClientSecret: a.config.OIDCClientSecret, |
|
RedirectURL: redirectURI, |
|
Endpoint: a.provider.Endpoint(), |
|
Scopes: scopes, |
|
} |
|
|
|
stateID, nonce, err := a.createOIDCState(redirectURI, statePayload) |
|
if err != nil { |
|
return "", fmt.Errorf("error generating OAuth state: %v", err) |
|
} |
|
|
|
authCodeOpts := []oauth2.AuthCodeOption{ |
|
oidc.Nonce(nonce), |
|
} |
|
if len(a.config.OIDCACRValues) > 0 { |
|
authCodeOpts = append(authCodeOpts, oauth2.SetAuthURLParam("acr_values", strings.Join(a.config.OIDCACRValues, " "))) |
|
} |
|
|
|
return oauth2Config.AuthCodeURL(stateID, authCodeOpts...), nil |
|
} |
|
|
|
// ClaimsFromAuthCode is the second part of the OIDC authorization code |
|
// workflow. The interface{} return value is the statePayload previously passed |
|
// via GetAuthCodeURL. |
|
// |
|
// The error may be of type *ProviderLoginFailedError or |
|
// *TokenVerificationFailedError which can be detected via errors.As(). |
|
// |
|
// Requires the authenticator's config type be set to 'oidc'. |
|
func (a *Authenticator) ClaimsFromAuthCode(ctx context.Context, stateParam, code string) (*Claims, interface{}, error) { |
|
if a.config.authType() != authOIDCFlow { |
|
return nil, nil, fmt.Errorf("ClaimsFromAuthCode is incompatible with type %q", TypeJWT) |
|
} |
|
|
|
// TODO(sso): this could be because we ACTUALLY are getting OIDC error responses and |
|
// should handle them elsewhere! |
|
if code == "" { |
|
return nil, nil, &ProviderLoginFailedError{ |
|
Err: fmt.Errorf("OAuth code parameter not provided"), |
|
} |
|
} |
|
|
|
state := a.verifyOIDCState(stateParam) |
|
if state == nil { |
|
return nil, nil, &ProviderLoginFailedError{ |
|
Err: fmt.Errorf("Expired or missing OAuth state."), |
|
} |
|
} |
|
|
|
oidcCtx := contextWithHttpClient(ctx, a.httpClient) |
|
|
|
var oauth2Config = oauth2.Config{ |
|
ClientID: a.config.OIDCClientID, |
|
ClientSecret: a.config.OIDCClientSecret, |
|
RedirectURL: state.redirectURI, |
|
Endpoint: a.provider.Endpoint(), |
|
Scopes: []string{oidc.ScopeOpenID}, |
|
} |
|
|
|
oauth2Token, err := oauth2Config.Exchange(oidcCtx, code) |
|
if err != nil { |
|
return nil, nil, &ProviderLoginFailedError{ |
|
Err: fmt.Errorf("Error exchanging oidc code: %w", err), |
|
} |
|
} |
|
|
|
// Extract the ID Token from OAuth2 token. |
|
rawToken, ok := oauth2Token.Extra("id_token").(string) |
|
if !ok { |
|
return nil, nil, &TokenVerificationFailedError{ |
|
Err: errors.New("No id_token found in response."), |
|
} |
|
} |
|
|
|
if a.config.VerboseOIDCLogging && a.logger != nil { |
|
a.logger.Debug("OIDC provider response", "ID token", rawToken) |
|
} |
|
|
|
// Parse and verify ID Token payload. |
|
allClaims, err := a.verifyOIDCToken(ctx, rawToken) // TODO(sso): should this use oidcCtx? |
|
if err != nil { |
|
return nil, nil, &TokenVerificationFailedError{ |
|
Err: err, |
|
} |
|
} |
|
|
|
if allClaims["nonce"] != state.nonce { // TODO(sso): does this need a cast? |
|
return nil, nil, &TokenVerificationFailedError{ |
|
Err: errors.New("Invalid ID token nonce."), |
|
} |
|
} |
|
delete(allClaims, "nonce") |
|
|
|
// Attempt to fetch information from the /userinfo endpoint and merge it with |
|
// the existing claims data. A failure to fetch additional information from this |
|
// endpoint will not invalidate the authorization flow. |
|
if userinfo, err := a.provider.UserInfo(oidcCtx, oauth2.StaticTokenSource(oauth2Token)); err == nil { |
|
_ = userinfo.Claims(&allClaims) |
|
} else { |
|
if a.logger != nil { |
|
logFunc := a.logger.Warn |
|
if strings.Contains(err.Error(), "user info endpoint is not supported") { |
|
logFunc = a.logger.Info |
|
} |
|
logFunc("error reading /userinfo endpoint", "error", err) |
|
} |
|
} |
|
|
|
if a.config.VerboseOIDCLogging && a.logger != nil { |
|
if c, err := json.Marshal(allClaims); err == nil { |
|
a.logger.Debug("OIDC provider response", "claims", string(c)) |
|
} else { |
|
a.logger.Debug("OIDC provider response", "marshalling error", err.Error()) |
|
} |
|
} |
|
|
|
c, err := a.extractClaims(allClaims) |
|
if err != nil { |
|
return nil, nil, &TokenVerificationFailedError{ |
|
Err: err, |
|
} |
|
} |
|
|
|
if a.config.VerboseOIDCLogging && a.logger != nil { |
|
a.logger.Debug("OIDC provider response", "extracted_claims", c) |
|
} |
|
|
|
return c, state.payload, nil |
|
} |
|
|
|
// ProviderLoginFailedError is one of the error types ClaimsFromAuthCode() may
// return.
//
// It signals that the authorization code workflow could not be completed with
// the provider, for example because required OIDC parameters were lost or the
// id_token could not be fetched.
//
// Detect it with errors.As().
type ProviderLoginFailedError struct {
	// Err is the underlying cause; Unwrap exposes it.
	Err error
}

// Error implements the error interface by prefixing the underlying cause.
func (e *ProviderLoginFailedError) Error() string {
	return "Provider login failed: " + fmt.Sprint(e.Err)
}

// Unwrap returns the underlying cause.
func (e *ProviderLoginFailedError) Unwrap() error {
	return e.Err
}
|
|
|
// TokenVerificationFailedError is one of the error types ClaimsFromAuthCode()
// may return.
//
// It signals that the OIDC credentials returned by the provider could not be
// vetted, for example because the id_token failed verification or carried a
// mismatched nonce.
//
// Detect it with errors.As().
type TokenVerificationFailedError struct {
	// Err is the underlying cause; Unwrap exposes it.
	Err error
}

// Error implements the error interface by prefixing the underlying cause.
func (e *TokenVerificationFailedError) Error() string {
	return "Token verification failed: " + fmt.Sprint(e.Err)
}

// Unwrap returns the underlying cause.
func (e *TokenVerificationFailedError) Unwrap() error {
	return e.Err
}
|
|
|
func (a *Authenticator) verifyOIDCToken(ctx context.Context, rawToken string) (map[string]interface{}, error) { |
|
allClaims := make(map[string]interface{}) |
|
|
|
oidcConfig := &oidc.Config{ |
|
SupportedSigningAlgs: a.config.JWTSupportedAlgs, |
|
} |
|
switch a.config.authType() { |
|
case authOIDCFlow: |
|
oidcConfig.ClientID = a.config.OIDCClientID |
|
case authOIDCDiscovery: |
|
oidcConfig.SkipClientIDCheck = true |
|
default: |
|
return nil, fmt.Errorf("unsupported auth type for this verifyOIDCToken: %d", a.config.authType()) |
|
} |
|
|
|
verifier := a.provider.Verifier(oidcConfig) |
|
|
|
idToken, err := verifier.Verify(ctx, rawToken) |
|
if err != nil { |
|
return nil, fmt.Errorf("error validating signature: %v", err) |
|
} |
|
|
|
if err := idToken.Claims(&allClaims); err != nil { |
|
return nil, fmt.Errorf("unable to successfully parse all claims from token: %v", err) |
|
} |
|
// Follows behavior of hashicorp/vault-plugin-auth-jwt (non-strict validation). |
|
// See https://developer.hashicorp.com/consul/docs/security/acl/auth-methods/oidc#oidc-configuration-troubleshooting. |
|
if err := validateAudience(a.config.BoundAudiences, idToken.Audience, false); err != nil { |
|
return nil, fmt.Errorf("error validating claims: %v", err) |
|
} |
|
|
|
return allClaims, nil |
|
} |
|
|
|
// verifyOIDCState tests whether the provided state ID is valid and returns the |
|
// associated state object if so. A nil state is returned if the ID is not found |
|
// or expired. The state should only ever be retrieved once and is deleted as |
|
// part of this request. |
|
func (a *Authenticator) verifyOIDCState(stateID string) *oidcState { |
|
defer a.oidcStates.Delete(stateID) |
|
|
|
if stateRaw, ok := a.oidcStates.Get(stateID); ok { |
|
return stateRaw.(*oidcState) |
|
} |
|
|
|
return nil |
|
} |
|
|
|
// createOIDCState make an expiring state object, associated with a random state ID |
|
// that is passed throughout the OAuth process. A nonce is also included in the |
|
// auth process, and for simplicity will be identical in length/format as the state ID. |
|
func (a *Authenticator) createOIDCState(redirectURI string, payload interface{}) (string, string, error) { |
|
// Get enough bytes for 2 160-bit IDs (per rfc6749#section-10.10) |
|
bytes, err := uuid.GenerateRandomBytes(2 * 20) |
|
if err != nil { |
|
return "", "", err |
|
} |
|
|
|
stateID := fmt.Sprintf("%x", bytes[:20]) |
|
nonce := fmt.Sprintf("%x", bytes[20:]) |
|
|
|
a.oidcStates.SetDefault(stateID, &oidcState{ |
|
nonce: nonce, |
|
redirectURI: redirectURI, |
|
payload: payload, |
|
}) |
|
|
|
return stateID, nonce, nil |
|
} |
|
|
|
// oidcState is the record created whenever an auth URL is requested. It is
// stored under the state identifier, which is passed throughout the OAuth
// process.
type oidcState struct {
	// nonce is the value sent with the auth request and later compared
	// against the ID token's "nonce" claim.
	nonce string

	// redirectURI is the redirect URI supplied when the auth URL was requested.
	redirectURI string

	// payload is the caller-supplied value handed back once the flow completes.
	payload interface{}
}
|
|
|