// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package oidcauth
import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/json"
"errors"
"fmt"
"net/http"
"strconv"
"strings"
"github.com/hashicorp/go-cleanhttp"
"github.com/hashicorp/go-hclog"
"github.com/mitchellh/pointerstructure"
"golang.org/x/oauth2"
"github.com/hashicorp/consul/internal/go-sso/oidcauth/internal/strutil"
)
// contextWithHttpClient returns a child of ctx that carries client under the
// oauth2.HTTPClient context key. Code that consults that key when issuing
// requests with the returned context (such as the oauth2 package) will use
// client instead of its default.
func contextWithHttpClient(ctx context.Context, client *http.Client) context.Context {
	key := oauth2.HTTPClient
	derived := context.WithValue(ctx, key, client)
	return derived
}
// createHTTPClient builds an *http.Client backed by cleanhttp's pooled
// transport.
//
// When caCert is empty the transport is used unmodified. Otherwise caCert
// must hold one or more PEM-encoded certificates. These become the only
// trusted roots (tls.Config.RootCAs) for TLS connections made by the client.
// An error is returned if no certificate could be parsed from caCert.
//
// NOTE(review): no http.Client.Timeout is set here. Callers presumably bound
// requests via their context; confirm at call sites.
func createHTTPClient(caCert string) (*http.Client, error) {
	transport := cleanhttp.DefaultPooledTransport()

	if caCert == "" {
		return &http.Client{Transport: transport}, nil
	}

	pool := x509.NewCertPool()
	if !pool.AppendCertsFromPEM([]byte(caCert)) {
		return nil, errors.New("could not parse CA PEM value successfully")
	}
	transport.TLSClientConfig = &tls.Config{RootCAs: pool}

	return &http.Client{Transport: transport}, nil
}
// extractClaims projects allClaims through the authenticator's configured
// mappings:
//   - a.config.ClaimMappings selects scalar claims, which are stored in
//     Claims.Values.
//   - a.config.ListClaimMappings selects list claims, which are stored in
//     Claims.Lists.
//
// The first conversion error encountered is returned together with a nil
// *Claims.
func (a *Authenticator) extractClaims(allClaims map[string]interface{}) (*Claims, error) {
	var (
		out Claims
		err error
	)

	if out.Values, err = extractStringMetadata(a.logger, allClaims, a.config.ClaimMappings); err != nil {
		return nil, err
	}
	if out.Lists, err = extractListMetadata(a.logger, allClaims, a.config.ListClaimMappings); err != nil {
		return nil, err
	}

	return &out, nil
}
// extractStringMetadata builds a metadata map of string values from a set of
// claims and claim mappings.
//
// Each key of claimMappings names a claim, looked up via getClaim: a key
// starting with "/" is a JSON pointer, anything else is a top-level claim
// name. The corresponding value is the metadata key the claim is copied to:
//
//	{
//	    "/some/claim/pointer": "metadata_key1",
//	    "another_claim":       "metadata_key2",
//	    ...
//	}
//
// Claims that resolve to nil are skipped. Present claims must be a string,
// bool, json.Number or built-in numeric type (see stringifyMetadataValue).
// Any other type causes an error, and the returned map is nil.
func extractStringMetadata(logger hclog.Logger, allClaims map[string]interface{}, claimMappings map[string]string) (map[string]string, error) {
	result := make(map[string]string)

	for claimName, metadataKey := range claimMappings {
		if raw := getClaim(logger, allClaims, claimName); raw != nil {
			str, ok := stringifyMetadataValue(raw)
			if !ok {
				return nil, fmt.Errorf("error converting claim '%s' to string from unknown type %T", claimName, raw)
			}
			result[metadataKey] = str
		}
	}

	return result, nil
}
// extractListMetadata builds a metadata map of string-list values from a set
// of claims and list claim mappings. The mappings have the same shape as the
// scalar claim mappings, with claim name or pointer mapped to metadata key:
//
//	{
//	    "/some/claim/pointer": "metadata_key1",
//	    "another_claim":       "metadata_key2",
//	    ...
//	}
//
// Each referenced claim may be a list or a single scalar; a scalar is treated
// as a one-element list (see normalizeList).
//
// Every element must be convertible by stringifyMetadataValue. Elements that
// stringify to "" are dropped. Claims that resolve to nil are skipped. On any
// conversion failure an error is returned and the map is nil.
func extractListMetadata(logger hclog.Logger, allClaims map[string]interface{}, listClaimMappings map[string]string) (map[string][]string, error) {
	result := make(map[string][]string)

	for claimName, metadataKey := range listClaimMappings {
		raw := getClaim(logger, allClaims, claimName)
		if raw == nil {
			continue
		}

		items, ok := normalizeList(raw)
		if !ok {
			return nil, fmt.Errorf("%q list claim could not be converted to string list", claimName)
		}

		values := make([]string, 0, len(items))
		for _, item := range items {
			str, ok := stringifyMetadataValue(item)
			switch {
			case !ok:
				return nil, fmt.Errorf("value %v in %q list claim could not be parsed as string", item, claimName)
			case str != "":
				values = append(values, str)
			}
		}
		result[metadataKey] = values
	}

	return result, nil
}
// getClaim looks up a single claim value in allClaims.
//
// A claim string that begins with "/" is treated as a JSON Pointer and
// resolved with pointerstructure.Get. Any other string is used directly as a
// top-level key.
//
// nil is returned in three cases: the key is absent, its value is nil, or the
// pointer cannot be resolved. An unresolvable pointer is also logged at warn
// level when logger is non-nil.
//
// The returned value is not type-converted; callers handle that.
func getClaim(logger hclog.Logger, allClaims map[string]interface{}, claim string) interface{} {
	if strings.HasPrefix(claim, "/") {
		val, err := pointerstructure.Get(allClaims, claim)
		if err == nil {
			return val
		}
		if logger != nil {
			logger.Warn("unable to locate claim", "claim", claim, "error", err)
		}
		return nil
	}

	return allClaims[claim]
}
// normalizeList coerces raw into a slice.
//
// A []interface{} is returned unchanged. A single scalar of a type that
// stringifyMetadataValue understands is wrapped in a one-element slice. Those
// types are string, bool, json.Number and every built-in integer or float
// type. This accommodates providers that collapse a one-item list claim into
// a bare value. Any other input, including nil, yields (nil, false).
//
// The scalar types accepted here must stay in sync with
// stringifyMetadataValue. Elements are returned unconverted; callers do the
// per-element fixup.
func normalizeList(raw interface{}) ([]interface{}, bool) {
	if list, isList := raw.([]interface{}); isList {
		return list, true
	}

	switch raw.(type) {
	case string, bool, json.Number,
		float32, float64,
		int, int8, int16, int32, int64,
		uint, uint8, uint16, uint32, uint64:
		return []interface{}{raw}, true
	}

	return nil, false
}
// stringifyMetadataValue converts rawValue into a faithful string
// representation according to these rules:
//
//   - string       => unchanged
//   - bool         => "true" / "false"
//   - json.Number  => String()
//   - float32/64   => truncated toward zero to an int64, then formatted in base 10
//   - intXX        => converted to int64, then formatted in base 10
//   - uintXX       => converted to uint64, then formatted in base 10
//
// On success the string and true are returned. Any other type, including nil,
// yields "" and false.
//
// The set of accepted types must stay in sync with normalizeList.
func stringifyMetadataValue(rawValue interface{}) (string, bool) {
	switch v := rawValue.(type) {
	case string:
		return v, true
	case bool:
		return strconv.FormatBool(v), true
	case json.Number:
		return v.String(), true
	case float64:
		// The claims unmarshalled by go-oidc don't use UseNumber, so
		// they'll come in as float64 instead of an integer or json.Number.
		// NOTE(review): float values outside the int64 range convert to an
		// implementation-dependent int64 per the Go spec.
		return strconv.FormatInt(int64(v), 10), true

	// The numeric cases below exist for type completeness. Floats are
	// truncated to an integer before being stringified.
	case float32:
		return strconv.FormatInt(int64(v), 10), true

	// Signed integers fit in int64 without loss.
	case int8:
		return strconv.FormatInt(int64(v), 10), true
	case int16:
		return strconv.FormatInt(int64(v), 10), true
	case int32:
		return strconv.FormatInt(int64(v), 10), true
	case int64:
		return strconv.FormatInt(v, 10), true
	case int:
		return strconv.FormatInt(int64(v), 10), true

	// Unsigned integers are formatted as unsigned. Casting uint64/uint to
	// int64 would wrap values above math.MaxInt64 into negative numbers.
	case uint8:
		return strconv.FormatUint(uint64(v), 10), true
	case uint16:
		return strconv.FormatUint(uint64(v), 10), true
	case uint32:
		return strconv.FormatUint(uint64(v), 10), true
	case uint64:
		return strconv.FormatUint(v, 10), true
	case uint:
		return strconv.FormatUint(uint64(v), 10), true
	default:
		return "", false
	}
}
// validateAudience reports whether the audiences in audClaim are acceptable
// given boundAudiences:
//   - When boundAudiences is non-empty, at least one bound audience must
//     appear in audClaim.
//   - When boundAudiences is empty, any audClaim is accepted unless strict is
//     true, in which case a non-empty audClaim is an error.
//
// A nil return means the audience check passed.
func validateAudience(boundAudiences, audClaim []string, strict bool) error {
	if len(boundAudiences) == 0 {
		if strict && len(audClaim) > 0 {
			return errors.New("audience claim found in JWT but no audiences are bound")
		}
		return nil
	}

	for _, bound := range boundAudiences {
		if strutil.StrListContains(audClaim, bound) {
			return nil
		}
	}
	return errors.New("aud claim does not match any bound audience")
}