mirror of https://github.com/hashicorp/consul
367 lines
9.1 KiB
Go
367 lines
9.1 KiB
Go
|
// Copyright (c) HashiCorp, Inc.
|
||
|
// SPDX-License-Identifier: BUSL-1.1
|
||
|
|
||
|
package leafcert
|
||
|
|
||
|
import (
|
||
|
"bytes"
|
||
|
"context"
|
||
|
"crypto/rand"
|
||
|
"crypto/x509"
|
||
|
"encoding/pem"
|
||
|
"errors"
|
||
|
"fmt"
|
||
|
"math/big"
|
||
|
"sync"
|
||
|
"sync/atomic"
|
||
|
"testing"
|
||
|
"time"
|
||
|
|
||
|
"github.com/hashicorp/consul/agent/cacheshim"
|
||
|
"github.com/hashicorp/consul/agent/connect"
|
||
|
"github.com/hashicorp/consul/agent/structs"
|
||
|
"github.com/hashicorp/consul/sdk/testutil"
|
||
|
)
|
||
|
|
||
|
// NewTestManager builds a *Manager backed by a TestSigner, which signs leaf
// certs in-process instead of over RPC, and by the in-memory roots reader that
// the signer publishes CA changes through. If mut is non-nil it may adjust the
// Config before the Manager is constructed. The Manager is stopped
// automatically when the test finishes.
func NewTestManager(t *testing.T, mut func(*Config)) (*Manager, *TestSigner) {
	ts := newTestSigner(t, nil, nil)
	logger := testutil.Logger(t)

	cfg := Config{
		// Shrink the root-change spread so tests don't wait up to 20 seconds
		// to observe a root rotation. Tests that exercise the spread itself
		// can change it via mut. It is deliberately not 0, because 0 selects
		// the default delay.
		TestOverrideCAChangeInitialDelay: time.Microsecond,
	}
	if mut != nil {
		mut(&cfg)
	}

	mgr := NewManager(Deps{
		Logger:      logger,
		RootsReader: ts.RootsReader,
		CertSigner:  ts,
		Config:      cfg,
	})
	t.Cleanup(mgr.Stop)

	return mgr, ts
}
|
||
|
|
||
|
// TestSigner is a fake CertSigner that signs leaf certificates in-process
// using the CA most recently installed with UpdateCA, and publishes every
// root-set change through RootsReader. It can also record sign requests and
// inject errors into SignCert for tests.
type TestSigner struct {
	// caLock guards ca and prevRoots.
	caLock sync.Mutex
	// ca is the currently active signing root; nil until UpdateCA is called.
	ca *structs.CARoot
	// prevRoots holds inactive copies of every root installed so far, so
	// later root sets still advertise them.
	prevRoots []*structs.CARoot // remember prior ones

	// IDGenerator supplies raft indexes and certificate serial numbers.
	IDGenerator *atomic.Uint64
	// RootsReader receives each root set published by UpdateCA.
	RootsReader *testRootsReader

	// signCallLock guards the sign-call bookkeeping fields below.
	signCallLock sync.Mutex
	// signCallErrors is a queue of injected errors, consumed one per SignCert call.
	signCallErrors []error
	// signCallErrorCount counts injected errors that SignCert actually returned.
	signCallErrorCount uint64
	// signCallCapture records every request passed to SignCert, in call order.
	signCallCapture []*structs.CASignRequest
}
|
||
|
|
||
|
var _ CertSigner = (*TestSigner)(nil)
|
||
|
|
||
|
var ReplyWithExpiredCert = errors.New("reply with expired cert")
|
||
|
|
||
|
// newTestSigner constructs a TestSigner. A nil idGenerator or rootsReader is
// replaced by a fresh instance, so callers only pass the pieces they want to
// share with other fixtures.
func newTestSigner(t *testing.T, idGenerator *atomic.Uint64, rootsReader *testRootsReader) *TestSigner {
	ids := idGenerator
	if ids == nil {
		ids = new(atomic.Uint64)
	}

	reader := rootsReader
	if reader == nil {
		reader = newTestRootsReader(t)
	}

	return &TestSigner{IDGenerator: ids, RootsReader: reader}
}
|
||
|
|
||
|
func (s *TestSigner) SetSignCallErrors(errs ...error) {
|
||
|
s.signCallLock.Lock()
|
||
|
defer s.signCallLock.Unlock()
|
||
|
s.signCallErrors = append(s.signCallErrors, errs...)
|
||
|
}
|
||
|
|
||
|
func (s *TestSigner) GetSignCallErrorCount() uint64 {
|
||
|
s.signCallLock.Lock()
|
||
|
defer s.signCallLock.Unlock()
|
||
|
return s.signCallErrorCount
|
||
|
}
|
||
|
|
||
|
// UpdateCA installs ca as the active signing root and then publishes a new
// root set through RootsReader. If ca is nil, a fresh test CA is generated.
// The published set lists ca first, followed by inactive copies of every root
// installed by earlier calls. It returns the root that was installed.
func (s *TestSigner) UpdateCA(t *testing.T, ca *structs.CARoot) *structs.CARoot {
	if ca == nil {
		ca = connect.TestCA(t, nil)
	}
	idx := s.nextIndex()

	s.caLock.Lock()
	s.ca = ca
	published := make([]*structs.CARoot, 0, 1+len(s.prevRoots))
	published = append(published, ca)
	published = append(published, s.prevRoots...)
	// Keep an inactive copy so the next rotation still advertises this root.
	retired := ca.Clone()
	retired.Active = false
	s.prevRoots = append(s.prevRoots, retired)
	s.caLock.Unlock()

	// Publish only after the signer holds the new CA, so anything reacting to
	// the roots change signs with it.
	s.RootsReader.Set(&structs.IndexedCARoots{
		ActiveRootID: ca.ID,
		TrustDomain:  connect.TestTrustDomain,
		Roots:        published,
		QueryMeta:    structs.QueryMeta{Index: idx},
	})

	return ca
}
|
||
|
|
||
|
func (s *TestSigner) nextIndex() uint64 {
|
||
|
return s.IDGenerator.Add(1)
|
||
|
}
|
||
|
|
||
|
func (s *TestSigner) getCA() *structs.CARoot {
|
||
|
s.caLock.Lock()
|
||
|
defer s.caLock.Unlock()
|
||
|
return s.ca
|
||
|
}
|
||
|
|
||
|
func (s *TestSigner) GetCapture(idx int) *structs.CASignRequest {
|
||
|
s.signCallLock.Lock()
|
||
|
defer s.signCallLock.Unlock()
|
||
|
if len(s.signCallCapture) > idx {
|
||
|
return s.signCallCapture[idx]
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// SignCert implements CertSigner by signing req.CSR with the CA most recently
// installed via UpdateCA. Parts of this were inlined from CAManager and the
// connect CA provider. ctx is currently unused.
//
// Every request is recorded for GetCapture. Errors queued with
// SetSignCallErrors are consumed one per call:
//   - ReplyWithExpiredCert makes this call issue a certificate whose validity
//     window lies entirely in the past;
//   - a nil entry is consumed without effect;
//   - any other error is returned unchanged and counted by
//     GetSignCallErrorCount.
//
// The CSR's first URI SAN must be a SPIFFE ID for a service or a workload
// identity. The matching fields of the returned IssuedCert are populated.
func (s *TestSigner) SignCert(ctx context.Context, req *structs.CASignRequest) (*structs.IssuedCert, error) {
	useExpiredCert, err := s.consumeSignCallError(req)
	if err != nil {
		return nil, err
	}

	ca := s.getCA()
	if ca == nil {
		return nil, fmt.Errorf("must call UpdateCA at least once")
	}

	csr, err := connect.ParseCSR(req.CSR)
	if err != nil {
		return nil, fmt.Errorf("error parsing CSR: %w", err)
	}

	connect.HackSANExtensionForCSR(csr)

	// Without this guard a CSR lacking a URI SAN would panic on URIs[0].
	if len(csr.URIs) == 0 {
		return nil, errors.New("error parsing CSR URI: CSR has no URI SAN")
	}
	spiffeID, err := connect.ParseCertURI(csr.URIs[0])
	if err != nil {
		return nil, fmt.Errorf("error parsing CSR URI: %w", err)
	}

	// Exactly one of these is set once the switch succeeds.
	var (
		serviceID  *connect.SpiffeIDService
		workloadID *connect.SpiffeIDWorkloadIdentity
	)
	switch id := spiffeID.(type) {
	case *connect.SpiffeIDService:
		serviceID = id
	case *connect.SpiffeIDWorkloadIdentity:
		workloadID = id
	default:
		return nil, fmt.Errorf("unexpected spiffeID type %T", spiffeID)
	}

	signer, err := connect.ParseSigner(ca.SigningKey)
	if err != nil {
		return nil, fmt.Errorf("error parsing CA signing key: %w", err)
	}

	keyID, err := connect.KeyId(signer.Public())
	if err != nil {
		return nil, fmt.Errorf("error forming CA key id from public key: %w", err)
	}

	subjectKeyID, err := connect.KeyId(csr.PublicKey)
	if err != nil {
		return nil, fmt.Errorf("error forming subject key id from public key: %w", err)
	}

	caCert, err := connect.ParseCert(ca.RootCert)
	if err != nil {
		return nil, fmt.Errorf("error parsing CA root cert pem: %w", err)
	}

	const expiration = 10 * time.Minute

	now := time.Now()
	notBefore, notAfter := now, now.Add(expiration)
	if useExpiredCert {
		// Issue a cert whose whole validity window is already in the past.
		notBefore, notAfter = now.Add(-13*time.Hour), now.Add(-1*time.Hour)
	}

	template := x509.Certificate{
		SerialNumber: big.NewInt(int64(s.nextIndex())),
		URIs:         csr.URIs,
		Signature:    csr.Signature,
		// We use the correct signature algorithm for the CA key we are signing
		// with regardless of the algorithm used to sign the CSR signature above
		// since the leaf might use a different key type.
		SignatureAlgorithm:    connect.SigAlgoForKey(signer),
		PublicKeyAlgorithm:    csr.PublicKeyAlgorithm,
		PublicKey:             csr.PublicKey,
		BasicConstraintsValid: true,
		KeyUsage: x509.KeyUsageDataEncipherment |
			x509.KeyUsageKeyAgreement |
			x509.KeyUsageDigitalSignature |
			x509.KeyUsageKeyEncipherment,
		ExtKeyUsage: []x509.ExtKeyUsage{
			x509.ExtKeyUsageClientAuth,
			x509.ExtKeyUsageServerAuth,
		},
		NotAfter:       notAfter,
		NotBefore:      notBefore,
		AuthorityKeyId: keyID,
		SubjectKeyId:   subjectKeyID,
		DNSNames:       csr.DNSNames,
		IPAddresses:    csr.IPAddresses,
	}

	// Create the certificate, PEM encode it, then parse it back. The returned
	// metadata is read from the parsed cert.
	der, err := x509.CreateCertificate(rand.Reader, &template, caCert, csr.PublicKey, signer)
	if err != nil {
		return nil, fmt.Errorf("error creating cert pem from CSR: %w", err)
	}

	var buf bytes.Buffer
	if err := pem.Encode(&buf, &pem.Block{Type: "CERTIFICATE", Bytes: der}); err != nil {
		return nil, fmt.Errorf("error encoding cert pem into text: %w", err)
	}
	leafPEM := buf.String()

	leafCert, err := connect.ParseCert(leafPEM)
	if err != nil {
		return nil, fmt.Errorf("error parsing cert from generated leaf pem: %w", err)
	}

	index := s.nextIndex()
	issued := &structs.IssuedCert{
		SerialNumber: connect.EncodeSerialNumber(leafCert.SerialNumber),
		CertPEM:      leafPEM,
		ValidAfter:   leafCert.NotBefore,
		ValidBefore:  leafCert.NotAfter,
		RaftIndex: structs.RaftIndex{
			CreateIndex: index,
			ModifyIndex: index,
		},
	}

	uri := leafCert.URIs[0].String()
	if serviceID != nil {
		// Service SPIFFE ID case.
		issued.Service = serviceID.Service
		issued.ServiceURI = uri
	} else {
		// Workload identity SPIFFE ID case.
		issued.WorkloadIdentity = workloadID.WorkloadIdentity
		issued.WorkloadIdentityURI = uri
	}
	return issued, nil
}

// consumeSignCallError records req for GetCapture and pops the next error
// queued by SetSignCallErrors, if any. It reports useExpiredCert=true when the
// popped entry is ReplyWithExpiredCert. It returns and counts any other
// non-nil entry. A nil entry, or an empty queue, yields (false, nil).
func (s *TestSigner) consumeSignCallError(req *structs.CASignRequest) (useExpiredCert bool, err error) {
	s.signCallLock.Lock()
	defer s.signCallLock.Unlock()

	s.signCallCapture = append(s.signCallCapture, req)
	if len(s.signCallErrors) == 0 {
		return false, nil
	}

	injected := s.signCallErrors[0]
	s.signCallErrors = s.signCallErrors[1:]

	switch {
	case errors.Is(injected, ReplyWithExpiredCert):
		return true, nil
	case injected != nil:
		s.signCallErrorCount++
		return false, injected
	default:
		return false, nil
	}
}
|
||
|
|
||
|
// testRootsReader is an in-memory RootsReader. Set stores new roots and wakes
// pending Notify calls by closing the current watcher channel.
type testRootsReader struct {
	// mu guards every field below.
	mu sync.Mutex
	// index is reported alongside roots: roots.Index from the last Set,
	// 1 after Set(nil), and 0 before any Set.
	index uint64
	// roots is the value passed to the most recent Set.
	roots *structs.IndexedCARoots
	// watcher is closed (and replaced) on every Set to signal an update.
	watcher chan struct{}
}
|
||
|
|
||
|
// newTestRootsReader returns a testRootsReader with no roots. A cleanup hook
// closes the watcher that is current when the test ends, which wakes any
// goroutine still waiting in Notify.
func newTestRootsReader(t *testing.T) *testRootsReader {
	reader := new(testRootsReader)
	reader.watcher = make(chan struct{})

	t.Cleanup(func() {
		reader.mu.Lock()
		current := reader.watcher
		reader.mu.Unlock()

		close(current)
	})

	return reader
}
|
||
|
|
||
|
// Compile-time check that *testRootsReader satisfies RootsReader.
var _ RootsReader = (*testRootsReader)(nil)
|
||
|
|
||
|
func (r *testRootsReader) Set(roots *structs.IndexedCARoots) {
|
||
|
r.mu.Lock()
|
||
|
oldWatcher := r.watcher
|
||
|
r.watcher = make(chan struct{})
|
||
|
r.roots = roots
|
||
|
if roots == nil {
|
||
|
r.index = 1
|
||
|
} else {
|
||
|
r.index = roots.Index
|
||
|
}
|
||
|
r.mu.Unlock()
|
||
|
|
||
|
close(oldWatcher)
|
||
|
}
|
||
|
|
||
|
func (r *testRootsReader) Get() (*structs.IndexedCARoots, error) {
|
||
|
r.mu.Lock()
|
||
|
defer r.mu.Unlock()
|
||
|
return r.roots, nil
|
||
|
}
|
||
|
|
||
|
// Notify arranges for one cacheshim.UpdateEvent, carrying the current roots
// and index, to be sent on ch. The event is sent the next time Set is called,
// or when the test's cleanup closes the watcher. It fires at most once per
// call and always returns nil.
//
// The event is snapshotted under r.mu but delivered after releasing it. A
// consumer that stops draining ch therefore cannot wedge Set, Get, or the
// cleanup hook, all of which take r.mu. If ctx is canceled while waiting for
// an update or while delivering, the goroutine exits without sending, so it
// does not leak.
func (r *testRootsReader) Notify(ctx context.Context, correlationID string, ch chan<- cacheshim.UpdateEvent) error {
	r.mu.Lock()
	watcher := r.watcher
	r.mu.Unlock()

	go func() {
		select {
		case <-watcher:
		case <-ctx.Done():
			return
		}

		r.mu.Lock()
		event := cacheshim.UpdateEvent{
			CorrelationID: correlationID,
			Result:        r.roots,
			Meta:          cacheshim.ResultMeta{Index: r.index},
			Err:           nil,
		}
		r.mu.Unlock()

		select {
		case ch <- event:
		case <-ctx.Done():
		}
	}()
	return nil
}
|