consul/internal/go-sso/oidcauth/oidcauthtest/testing.go

511 lines
13 KiB
Go
Raw Normal View History

// Copyright (c) HashiCorp, Inc.
[COMPLIANCE] License changes (#18443) * Adding explicit MPL license for sub-package This directory and its subdirectories (packages) contain files licensed with the MPLv2 `LICENSE` file in this directory and are intentionally licensed separately from the BSL `LICENSE` file at the root of this repository. * Adding explicit MPL license for sub-package This directory and its subdirectories (packages) contain files licensed with the MPLv2 `LICENSE` file in this directory and are intentionally licensed separately from the BSL `LICENSE` file at the root of this repository. * Updating the license from MPL to Business Source License Going forward, this project will be licensed under the Business Source License v1.1. Please see our blog post for more details at <Blog URL>, FAQ at www.hashicorp.com/licensing-faq, and details of the license at www.hashicorp.com/bsl. * add missing license headers * Update copyright file headers to BUSL-1.1 * Update copyright file headers to BUSL-1.1 * Update copyright file headers to BUSL-1.1 * Update copyright file headers to BUSL-1.1 * Update copyright file headers to BUSL-1.1 * Update copyright file headers to BUSL-1.1 * Update copyright file headers to BUSL-1.1 * Update copyright file headers to BUSL-1.1 * Update copyright file headers to BUSL-1.1 * Update copyright file headers to BUSL-1.1 * Update copyright file headers to BUSL-1.1 * Update copyright file headers to BUSL-1.1 * Update copyright file headers to BUSL-1.1 * Update copyright file headers to BUSL-1.1 * Update copyright file headers to BUSL-1.1 --------- Co-authored-by: hashicorp-copywrite[bot] <110428419+hashicorp-copywrite[bot]@users.noreply.github.com>
2023-08-11 13:12:13 +00:00
// SPDX-License-Identifier: BUSL-1.1
2020-05-12 01:59:29 +00:00
// Package oidcauthtest exposes tools to assist in writing unit tests of OIDC
// and JWT authentication workflows.
//
// When the package is loaded it will randomly generate an ECDSA signing
// keypair used to sign JWTs both via the Server and the SignJWT method.
package oidcauthtest
import (
"bytes"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/x509"
"encoding/json"
"encoding/pem"
"fmt"
"io"
2020-05-12 01:59:29 +00:00
"log"
"net"
"net/http"
"net/http/httptest"
"net/url"
"strconv"
"sync"
"time"
2024-03-22 14:54:58 +00:00
"github.com/go-jose/go-jose/v3"
"github.com/go-jose/go-jose/v3/jwt"
2020-05-12 01:59:29 +00:00
"github.com/stretchr/testify/require"
"github.com/hashicorp/consul/internal/go-sso/oidcauth/internal/strutil"
2020-05-12 01:59:29 +00:00
)
// Server is a local server that mocks the endpoints used by the OIDC and
// JWKS process.
//
// The configuration fields listed after mu are written by the Set*/Omit*/
// Disable* methods and read by ServeHTTP, all while holding mu.
type Server struct {
// httpServer is the TLS test server created by Start; it serves this Server
// as its handler.
httpServer *httptest.Server
// caCert is the PEM-encoded certificate of httpServer, returned by CACert.
caCert string
// NOTE(review): returnFunc is not referenced by any code in this file;
// confirm whether it is still needed.
returnFunc func()
// jwks is the key set served from /certs.
jwks *jose.JSONWebKeySet
// allowedRedirectURIs lists the redirect_uri values accepted by /token.
allowedRedirectURIs []string
// replySubject is the subject claim placed in tokens issued by /token.
replySubject string
// replyUserinfo is the JSON document served from /userinfo.
replyUserinfo map[string]interface{}
// mu is held for the whole of ServeHTTP and by every setter.
mu sync.Mutex
// clientID is the default audience of issued tokens (see customAudience).
clientID string
// clientSecret is stored by SetClientCreds.
// NOTE(review): no handler in this file checks it.
clientSecret string
// expectedAuthCode is returned by /auth and required by /token. When empty,
// /auth answers with access_denied.
expectedAuthCode string
// expectedAuthNonce, when non-empty, must match the nonce sent to /auth.
expectedAuthNonce string
// customClaims are the private claims added to tokens issued by /token.
customClaims map[string]interface{}
// customAudience, when non-empty, replaces clientID as the token audience.
customAudience string
// omitIDToken makes /token leave id_token out of its response.
omitIDToken bool
// disableUserInfo makes /userinfo return 404 and removes userinfo_endpoint
// from the discovery document.
disableUserInfo bool
}
type TestingT interface {
require.TestingT
Helper()
Cleanup(func())
2020-05-12 01:59:29 +00:00
}
// Start creates a disposable Server. If the port provided is
// zero it will bind to a random free port, otherwise the provided port is
// used.
func Start(t TestingT) *Server {
t.Helper()
2020-05-12 01:59:29 +00:00
s := &Server{
allowedRedirectURIs: []string{
"https://example.com",
},
replySubject: "r3qXcK2bix9eFECzsU3Sbmh0K16fatW6@clients",
replyUserinfo: map[string]interface{}{
"color": "red",
"temperature": "76",
"flavor": "umami",
},
}
jwks, err := newJWKS(ecdsaPublicKey)
require.NoError(t, err)
s.jwks = jwks
s.httpServer = httptest.NewUnstartedServer(s)
s.httpServer.Config.ErrorLog = log.New(io.Discard, "", 0)
2020-05-12 01:59:29 +00:00
s.httpServer.StartTLS()
t.Cleanup(s.httpServer.Close)
cert := s.httpServer.Certificate()
var buf bytes.Buffer
require.NoError(t, pem.Encode(&buf, &pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw}))
s.caCert = buf.String()
return s
}
// SetClientCreds records the client ID and client secret used by the OIDC
// workflows. The client ID becomes the default audience of issued tokens.
func (s *Server) SetClientCreds(clientID, clientSecret string) {
	s.mu.Lock()
	s.clientID, s.clientSecret = clientID, clientSecret
	s.mu.Unlock()
}
// SetExpectedAuthCode sets the auth code that /auth hands back and that
// /token requires. While the code is empty, /auth answers access_denied.
func (s *Server) SetExpectedAuthCode(code string) {
	s.mu.Lock()
	s.expectedAuthCode = code
	s.mu.Unlock()
}
// SetExpectedAuthNonce sets the nonce that requests to /auth must carry. An
// empty nonce disables the check.
func (s *Server) SetExpectedAuthNonce(nonce string) {
	s.mu.Lock()
	s.expectedAuthNonce = nonce
	s.mu.Unlock()
}
// SetAllowedRedirectURIs replaces the list of redirect URIs accepted by the
// OIDC workflow. A freshly started Server accepts only
// "https://example.com".
func (s *Server) SetAllowedRedirectURIs(uris []string) {
	s.mu.Lock()
	s.allowedRedirectURIs = uris
	s.mu.Unlock()
}
// SetCustomClaims sets the private claims embedded in the JWT issued by the
// OIDC workflow.
func (s *Server) SetCustomClaims(customClaims map[string]interface{}) {
	s.mu.Lock()
	s.customClaims = customClaims
	s.mu.Unlock()
}
// SetCustomAudience overrides the audience embedded in the JWT issued by the
// OIDC workflow. When left empty the client ID is used instead.
func (s *Server) SetCustomAudience(customAudience string) {
	s.mu.Lock()
	s.customAudience = customAudience
	s.mu.Unlock()
}
// OmitIDTokens puts the server into an error state in which responses from
// the /token endpoint carry no id_token.
func (s *Server) OmitIDTokens() {
	s.mu.Lock()
	s.omitIDToken = true
	s.mu.Unlock()
}
// DisableUserInfo turns the userinfo endpoint off: /userinfo answers 404 and
// the discovery document no longer advertises it.
func (s *Server) DisableUserInfo() {
	s.mu.Lock()
	s.disableUserInfo = true
	s.mu.Unlock()
}
// Stop stops the running Server by closing its underlying httptest server.
// Start also registers that same Close call with t.Cleanup.
func (s *Server) Stop() {
s.httpServer.Close()
}
// Addr reports the base URL of the running webserver. The same value is used
// as the token issuer and as the prefix of every advertised endpoint.
func (s *Server) Addr() string {
	return s.httpServer.URL
}
// CACert reports the PEM-encoded CA certificate used by the HTTPS server, as
// captured by Start.
func (s *Server) CACert() string {
	return s.caCert
}
// SigningKeys reports the PEM-encoded keys used to sign JWTs. It forwards to
// the package-level SigningKeys function.
func (s *Server) SigningKeys() (pub, priv string) {
	pub, priv = SigningKeys()
	return pub, priv
}
// ServeHTTP implements http.Handler. It holds the Server mutex for the whole
// request, marks the response as JSON, and dispatches on the request path:
//
//	/.well-known/openid-configuration  discovery document
//	/auth                              authorization endpoint (redirects)
//	/certs                             JWKS document
//	/certs_missing                     always 404
//	/certs_invalid                     200 with a body that is not a key set
//	/token                             token endpoint
//	/userinfo                          userinfo document
//
// Any other path yields 404.
func (s *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
	s.mu.Lock()
	defer s.mu.Unlock()

	w.Header().Set("Content-Type", "application/json")

	switch req.URL.Path {
	case "/.well-known/openid-configuration":
		s.serveDiscovery(w, req)
	case "/auth":
		s.serveAuth(w, req)
	case "/certs":
		s.serveCerts(w, req)
	case "/certs_missing":
		w.WriteHeader(http.StatusNotFound)
	case "/certs_invalid":
		// Best effort: nothing useful can be done if the client has gone away.
		_, _ = w.Write([]byte("It's not a keyset!"))
	case "/token":
		s.serveToken(w, req)
	case "/userinfo":
		s.serveUserinfo(w, req)
	default:
		w.WriteHeader(http.StatusNotFound)
	}
}

// serveDiscovery writes the OIDC discovery document. The caller must hold
// s.mu.
func (s *Server) serveDiscovery(w http.ResponseWriter, req *http.Request) {
	if req.Method != http.MethodGet {
		w.WriteHeader(http.StatusMethodNotAllowed)
		return
	}

	reply := struct {
		Issuer           string `json:"issuer"`
		AuthEndpoint     string `json:"authorization_endpoint"`
		TokenEndpoint    string `json:"token_endpoint"`
		JWKSURI          string `json:"jwks_uri"`
		UserinfoEndpoint string `json:"userinfo_endpoint,omitempty"`
	}{
		Issuer:           s.Addr(),
		AuthEndpoint:     s.Addr() + "/auth",
		TokenEndpoint:    s.Addr() + "/token",
		JWKSURI:          s.Addr() + "/certs",
		UserinfoEndpoint: s.Addr() + "/userinfo",
	}
	if s.disableUserInfo {
		// omitempty then drops the field from the document entirely.
		reply.UserinfoEndpoint = ""
	}

	// Best effort: an encode failure means the client went away.
	_ = writeJSON(w, &reply)
}

// serveAuth validates an authorization request and redirects back to the
// supplied redirect_uri with either the expected auth code or an error. The
// caller must hold s.mu.
func (s *Server) serveAuth(w http.ResponseWriter, req *http.Request) {
	if req.Method != http.MethodGet {
		w.WriteHeader(http.StatusMethodNotAllowed)
		return
	}

	qv := req.URL.Query()

	if qv.Get("response_type") != "code" {
		writeAuthErrorResponse(w, req, "unsupported_response_type", "")
		return
	}
	if qv.Get("scope") != "openid" {
		writeAuthErrorResponse(w, req, "invalid_scope", "")
		return
	}

	// No auth code configured means the test wants the login to be denied.
	if s.expectedAuthCode == "" {
		writeAuthErrorResponse(w, req, "access_denied", "")
		return
	}

	// The nonce is only enforced when one has been configured.
	nonce := qv.Get("nonce")
	if s.expectedAuthNonce != "" && s.expectedAuthNonce != nonce {
		writeAuthErrorResponse(w, req, "access_denied", "")
		return
	}

	state := qv.Get("state")
	if state == "" {
		writeAuthErrorResponse(w, req, "invalid_request", "missing state parameter")
		return
	}

	redirectURI := qv.Get("redirect_uri")
	if redirectURI == "" {
		writeAuthErrorResponse(w, req, "invalid_request", "missing redirect_uri parameter")
		return
	}

	redirectURI += "?state=" + url.QueryEscape(state) +
		"&code=" + url.QueryEscape(s.expectedAuthCode)

	http.Redirect(w, req, redirectURI, http.StatusFound)
}

// serveCerts writes the JWKS document. The caller must hold s.mu.
func (s *Server) serveCerts(w http.ResponseWriter, req *http.Request) {
	if req.Method != http.MethodGet {
		w.WriteHeader(http.StatusMethodNotAllowed)
		return
	}

	// Best effort: an encode failure means the client went away.
	_ = writeJSON(w, s.jwks)
}

// serveToken exchanges the expected auth code for a freshly signed JWT that
// is returned as both access_token and (unless disabled) id_token. The caller
// must hold s.mu.
func (s *Server) serveToken(w http.ResponseWriter, req *http.Request) {
	if req.Method != http.MethodPost {
		w.WriteHeader(http.StatusMethodNotAllowed)
		return
	}

	switch {
	case req.FormValue("grant_type") != "authorization_code":
		_ = writeTokenErrorResponse(w, req, http.StatusBadRequest, "invalid_request", "bad grant_type")
		return
	case !strutil.StrListContains(s.allowedRedirectURIs, req.FormValue("redirect_uri")):
		_ = writeTokenErrorResponse(w, req, http.StatusBadRequest, "invalid_request", "redirect_uri is not allowed")
		return
	case req.FormValue("code") != s.expectedAuthCode:
		_ = writeTokenErrorResponse(w, req, http.StatusUnauthorized, "invalid_grant", "unexpected auth code")
		return
	}

	// The token is valid from five seconds ago until five seconds from now.
	now := time.Now()
	stdClaims := jwt.Claims{
		Subject:   s.replySubject,
		Issuer:    s.Addr(),
		NotBefore: jwt.NewNumericDate(now.Add(-5 * time.Second)),
		Expiry:    jwt.NewNumericDate(now.Add(5 * time.Second)),
		Audience:  jwt.Audience{s.clientID},
	}
	if s.customAudience != "" {
		stdClaims.Audience = jwt.Audience{s.customAudience}
	}

	// An empty key makes SignJWT use the package's default signing key.
	jwtData, err := SignJWT("", stdClaims, s.customClaims)
	if err != nil {
		_ = writeTokenErrorResponse(w, req, http.StatusInternalServerError, "server_error", err.Error())
		return
	}

	reply := struct {
		AccessToken string `json:"access_token"`
		IDToken     string `json:"id_token,omitempty"`
	}{
		AccessToken: jwtData,
		IDToken:     jwtData,
	}
	if s.omitIDToken {
		// omitempty then drops id_token from the response.
		reply.IDToken = ""
	}

	// Best effort: an encode failure means the client went away.
	_ = writeJSON(w, &reply)
}

// serveUserinfo writes the configured userinfo document, or 404 when the
// endpoint has been disabled. The caller must hold s.mu.
func (s *Server) serveUserinfo(w http.ResponseWriter, req *http.Request) {
	// The disabled check deliberately precedes the method check, so a
	// disabled endpoint is 404 for every method.
	if s.disableUserInfo {
		w.WriteHeader(http.StatusNotFound)
		return
	}
	if req.Method != http.MethodGet {
		w.WriteHeader(http.StatusMethodNotAllowed)
		return
	}

	// Best effort: an encode failure means the client went away.
	_ = writeJSON(w, s.replyUserinfo)
}
// writeAuthErrorResponse answers an /auth request with a 302 redirect to the
// request's own redirect_uri, echoing its state parameter and attaching the
// given error code. errorMessage, when non-empty, is sent along as
// error_description.
func writeAuthErrorResponse(w http.ResponseWriter, req *http.Request, errorCode, errorMessage string) {
	query := req.URL.Query()

	target := query.Get("redirect_uri")
	target += "?state=" + url.QueryEscape(query.Get("state"))
	target += "&error=" + url.QueryEscape(errorCode)
	if errorMessage != "" {
		target += "&error_description=" + url.QueryEscape(errorMessage)
	}

	http.Redirect(w, req, target, http.StatusFound)
}
// writeTokenErrorResponse answers a /token request with the given HTTP status
// and a JSON error body holding "error" and, when errorMessage is non-empty,
// "error_description". It returns whatever error writing the body produced.
func writeTokenErrorResponse(w http.ResponseWriter, req *http.Request, statusCode int, errorCode, errorMessage string) error {
	type tokenError struct {
		Code string `json:"error"`
		Desc string `json:"error_description,omitempty"`
	}

	// The status line has to go out before any of the body is written.
	w.WriteHeader(statusCode)

	payload := tokenError{Code: errorCode, Desc: errorMessage}
	return writeJSON(w, &payload)
}
// newJWKS converts a pem-encoded public key into JWKS data suitable for a
// verification endpoint response
func newJWKS(pubKey string) (*jose.JSONWebKeySet, error) {
block, _ := pem.Decode([]byte(pubKey))
if block == nil {
return nil, fmt.Errorf("unable to decode public key")
}
input := block.Bytes
pub, err := x509.ParsePKIXPublicKey(input)
if err != nil {
return nil, err
}
return &jose.JSONWebKeySet{
Keys: []jose.JSONWebKey{
{
2020-05-12 01:59:29 +00:00
Key: pub,
},
},
}, nil
}
// writeJSON streams out to w as a single JSON document followed by a newline
// and reports any encoding or write error.
func writeJSON(w http.ResponseWriter, out interface{}) error {
	return json.NewEncoder(w).Encode(out)
}
// SignJWT bundles the provided claims into a JWT signed with ES256 and
// returns its compact serialization. The provided key must be a PEM-encoded
// EC private key.
//
// If no private key is provided, the default package keys are used. These can
// be retrieved via the SigningKeys() method.
//
// An error is returned when privKey holds no PEM block, when the block is not
// a valid EC private key, or when signing or serialization fails.
func SignJWT(privKey string, claims jwt.Claims, privateClaims interface{}) (string, error) {
	if privKey == "" {
		privKey = ecdsaPrivateKey
	}

	// Reject undecodable input here rather than handing a nil key to the
	// signer, which would fail later with a less specific error.
	block, _ := pem.Decode([]byte(privKey))
	if block == nil {
		return "", fmt.Errorf("unable to decode private key")
	}
	key, err := x509.ParseECPrivateKey(block.Bytes)
	if err != nil {
		return "", fmt.Errorf("parsing EC private key: %w", err)
	}

	sig, err := jose.NewSigner(
		jose.SigningKey{Algorithm: jose.ES256, Key: key},
		(&jose.SignerOptions{}).WithType("JWT"),
	)
	if err != nil {
		return "", err
	}

	raw, err := jwt.Signed(sig).
		Claims(claims).
		Claims(privateClaims).
		CompactSerialize()
	if err != nil {
		return "", err
	}
	return raw, nil
}
// httptestNewUnstartedServerWithPort behaves like
// httptest.NewUnstartedServer() except that a non-zero port pins the listener
// to 127.0.0.1 on that port. A zero port defers entirely to httptest. It
// panics when the requested port cannot be bound.
func httptestNewUnstartedServerWithPort(handler http.Handler, port int) *httptest.Server {
	if port != 0 {
		listener, err := net.Listen("tcp", net.JoinHostPort("127.0.0.1", strconv.Itoa(port)))
		if err != nil {
			panic(fmt.Sprintf("httptest: failed to listen on a port: %v", err))
		}

		srv := new(httptest.Server)
		srv.Listener = listener
		srv.Config = &http.Server{Handler: handler}
		return srv
	}

	return httptest.NewUnstartedServer(handler)
}
// SigningKeys reports the PEM-encoded public and private keys that are used
// to sign JWTs by default.
func SigningKeys() (pub, priv string) {
	pub, priv = ecdsaPublicKey, ecdsaPrivateKey
	return pub, priv
}
// Default PEM-encoded signing keypair for the package. Both values are
// assigned by init from GenerateKey and are returned by SigningKeys.
var (
// ecdsaPublicKey is the public half; Start passes it to newJWKS.
ecdsaPublicKey string
// ecdsaPrivateKey is the private half; SignJWT uses it when called with an
// empty key.
ecdsaPrivateKey string
)
func init() {
// Each time we run tests we generate a unique set of keys for use in the
// test. These are cached between runs but do not persist between restarts
// of the test binary.
var err error
ecdsaPublicKey, ecdsaPrivateKey, err = GenerateKey()
2020-05-12 01:59:29 +00:00
if err != nil {
panic(err)
}
}
// GenerateKey creates a new random ECDSA P-256 keypair and returns both
// halves PEM-encoded: pub as a PKIX "PUBLIC KEY" block and priv as an
// "EC PRIVATE KEY" block.
//
// On failure both strings are empty and err wraps the underlying cause.
func GenerateKey() (pub, priv string, err error) {
	privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		return "", "", fmt.Errorf("error generating private key: %w", err)
	}

	privDER, err := x509.MarshalECPrivateKey(privateKey)
	if err != nil {
		return "", "", fmt.Errorf("error marshaling private key: %w", err)
	}

	pubDER, err := x509.MarshalPKIXPublicKey(privateKey.Public())
	if err != nil {
		return "", "", fmt.Errorf("error marshaling public key: %w", err)
	}

	priv = string(pem.EncodeToMemory(&pem.Block{
		Type:  "EC PRIVATE KEY",
		Bytes: privDER,
	}))
	pub = string(pem.EncodeToMemory(&pem.Block{
		Type:  "PUBLIC KEY",
		Bytes: pubDER,
	}))

	return pub, priv, nil
}