mirror of https://github.com/hashicorp/consul
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
210 lines
5.8 KiB
210 lines
5.8 KiB
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
|
|
|
package oidcauth |
|
|
|
import ( |
|
"context" |
|
"crypto/ecdsa" |
|
"crypto/ed25519" |
|
"crypto/rsa" |
|
"crypto/x509" |
|
"encoding/json" |
|
"encoding/pem" |
|
"errors" |
|
"fmt" |
|
"time" |
|
|
|
"gopkg.in/square/go-jose.v2/jwt" |
|
) |
|
|
|
// claimDefaultLeeway is the leeway, in seconds, used to derive a missing
// expiry or not-before claim when the corresponding configured leeway
// (ExpirationLeeway / NotBeforeLeeway) is zero.
const claimDefaultLeeway = 150
|
|
|
// ClaimsFromJWT is unrelated to the OIDC authorization code workflow. This |
|
// allows for a JWT to be directly validated and decoded into a set of claims. |
|
// |
|
// Requires the authenticator's config type be set to 'jwt'. |
|
func (a *Authenticator) ClaimsFromJWT(ctx context.Context, jwt string) (*Claims, error) { |
|
if a.config.authType() == authOIDCFlow { |
|
return nil, fmt.Errorf("ClaimsFromJWT is incompatible with type %q", TypeOIDC) |
|
} |
|
if jwt == "" { |
|
return nil, errors.New("missing jwt") |
|
} |
|
|
|
// Here is where things diverge. If it is using OIDC Discovery, validate that way; |
|
// otherwise validate against the locally configured or JWKS keys. Once things are |
|
// validated, we re-unify the request path when evaluating the claims. |
|
var ( |
|
allClaims map[string]interface{} |
|
err error |
|
) |
|
switch a.config.authType() { |
|
case authStaticKeys, authJWKS: |
|
allClaims, err = a.verifyVanillaJWT(ctx, jwt) |
|
if err != nil { |
|
return nil, err |
|
} |
|
|
|
case authOIDCDiscovery: |
|
allClaims, err = a.verifyOIDCToken(ctx, jwt) |
|
if err != nil { |
|
return nil, err |
|
} |
|
|
|
default: |
|
return nil, errors.New("unhandled case during login") |
|
} |
|
|
|
c, err := a.extractClaims(allClaims) |
|
if err != nil { |
|
return nil, err |
|
} |
|
|
|
if a.config.VerboseOIDCLogging && a.logger != nil { |
|
a.logger.Debug("OIDC provider response", "extracted_claims", c) |
|
} |
|
|
|
return c, nil |
|
} |
|
|
|
func (a *Authenticator) verifyVanillaJWT(ctx context.Context, loginToken string) (map[string]interface{}, error) { |
|
var ( |
|
allClaims = map[string]interface{}{} |
|
claims = jwt.Claims{} |
|
) |
|
// TODO(sso): handle JWTSupportedAlgs |
|
switch a.config.authType() { |
|
case authJWKS: |
|
// Verify signature (and only signature... other elements are checked later) |
|
payload, err := a.keySet.VerifySignature(ctx, loginToken) |
|
if err != nil { |
|
return nil, fmt.Errorf("error verifying token: %v", err) |
|
} |
|
|
|
// Unmarshal payload into two copies: public claims for library verification, and a set |
|
// of all received claims. |
|
if err := json.Unmarshal(payload, &claims); err != nil { |
|
return nil, fmt.Errorf("failed to unmarshal claims: %v", err) |
|
} |
|
if err := json.Unmarshal(payload, &allClaims); err != nil { |
|
return nil, fmt.Errorf("failed to unmarshal claims: %v", err) |
|
} |
|
case authStaticKeys: |
|
parsedJWT, err := jwt.ParseSigned(loginToken) |
|
if err != nil { |
|
return nil, fmt.Errorf("error parsing token: %v", err) |
|
} |
|
|
|
var valid bool |
|
for _, key := range a.parsedJWTPubKeys { |
|
if err := parsedJWT.Claims(key, &claims, &allClaims); err == nil { |
|
valid = true |
|
break |
|
} |
|
} |
|
if !valid { |
|
return nil, errors.New("no known key successfully validated the token signature") |
|
} |
|
default: |
|
return nil, fmt.Errorf("unsupported auth type for this verifyVanillaJWT: %d", a.config.authType()) |
|
} |
|
|
|
// We require notbefore or expiry; if only one is provided, we allow 5 minutes of leeway by default. |
|
// Configurable by ExpirationLeeway and NotBeforeLeeway |
|
if claims.IssuedAt == nil { |
|
claims.IssuedAt = new(jwt.NumericDate) |
|
} |
|
if claims.Expiry == nil { |
|
claims.Expiry = new(jwt.NumericDate) |
|
} |
|
if claims.NotBefore == nil { |
|
claims.NotBefore = new(jwt.NumericDate) |
|
} |
|
if *claims.IssuedAt == 0 && *claims.Expiry == 0 && *claims.NotBefore == 0 { |
|
return nil, errors.New("no issue time, notbefore, or expiration time encoded in token") |
|
} |
|
|
|
if *claims.Expiry == 0 { |
|
latestStart := *claims.IssuedAt |
|
if *claims.NotBefore > *claims.IssuedAt { |
|
latestStart = *claims.NotBefore |
|
} |
|
leeway := a.config.ExpirationLeeway.Seconds() |
|
if a.config.ExpirationLeeway.Seconds() < 0 { |
|
leeway = 0 |
|
} else if a.config.ExpirationLeeway.Seconds() == 0 { |
|
leeway = claimDefaultLeeway |
|
} |
|
*claims.Expiry = jwt.NumericDate(int64(latestStart) + int64(leeway)) |
|
} |
|
|
|
if *claims.NotBefore == 0 { |
|
if *claims.IssuedAt != 0 { |
|
*claims.NotBefore = *claims.IssuedAt |
|
} else { |
|
leeway := a.config.NotBeforeLeeway.Seconds() |
|
if a.config.NotBeforeLeeway.Seconds() < 0 { |
|
leeway = 0 |
|
} else if a.config.NotBeforeLeeway.Seconds() == 0 { |
|
leeway = claimDefaultLeeway |
|
} |
|
*claims.NotBefore = jwt.NumericDate(int64(*claims.Expiry) - int64(leeway)) |
|
} |
|
} |
|
|
|
expected := jwt.Expected{ |
|
Issuer: a.config.BoundIssuer, |
|
// Subject: a.config.BoundSubject, |
|
Time: time.Now(), |
|
} |
|
|
|
cksLeeway := a.config.ClockSkewLeeway |
|
if a.config.ClockSkewLeeway.Seconds() < 0 { |
|
cksLeeway = 0 |
|
} else if a.config.ClockSkewLeeway.Seconds() == 0 { |
|
cksLeeway = jwt.DefaultLeeway |
|
} |
|
|
|
if err := claims.ValidateWithLeeway(expected, cksLeeway); err != nil { |
|
return nil, fmt.Errorf("error validating claims: %v", err) |
|
} |
|
|
|
if err := validateAudience(a.config.BoundAudiences, claims.Audience, true); err != nil { |
|
return nil, fmt.Errorf("error validating claims: %v", err) |
|
} |
|
|
|
return allClaims, nil |
|
} |
|
|
|
// parsePublicKeyPEM decodes the first PEM block in data and returns the RSA,
// ECDSA, or Ed25519 public key it contains. Anything following the first PEM
// block is ignored.
//
// The block is first parsed as a PKIX public key. If that fails it is parsed
// as an X.509 certificate and the certificate's public key is used; if that
// fails as well, the certificate parsing error is returned.
//
// On success the result is a *rsa.PublicKey, *ecdsa.PublicKey, or
// ed25519.PublicKey. If data holds no PEM block, or the key is of any other
// type, a generic "no valid public keys" error is returned.
//
// Extracted from "github.com/hashicorp/vault/sdk/helper/certutil"
//
// go-sso added support for ed25519 (EdDSA)
func parsePublicKeyPEM(data []byte) (interface{}, error) {
	const noKeyMsg = "data does not contain any valid RSA, ECDSA, or ED25519 public keys"

	block, _ := pem.Decode(data)
	if block == nil {
		return nil, errors.New(noKeyMsg)
	}

	rawKey, err := x509.ParsePKIXPublicKey(block.Bytes)
	if err != nil {
		// Not a bare public key; fall back to treating it as a certificate.
		cert, certErr := x509.ParseCertificate(block.Bytes)
		if certErr != nil {
			return nil, certErr
		}
		rawKey = cert.PublicKey
	}

	switch key := rawKey.(type) {
	case *rsa.PublicKey:
		return key, nil
	case *ecdsa.PublicKey:
		return key, nil
	case ed25519.PublicKey:
		return key, nil
	default:
		return nil, errors.New(noKeyMsg)
	}
}
|
|
|