mirror of https://github.com/hashicorp/consul
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
366 lines
9.1 KiB
366 lines
9.1 KiB
// Copyright (c) HashiCorp, Inc. |
|
// SPDX-License-Identifier: BUSL-1.1 |
|
|
|
package leafcert |
|
|
|
import ( |
|
"bytes" |
|
"context" |
|
"crypto/rand" |
|
"crypto/x509" |
|
"encoding/pem" |
|
"errors" |
|
"fmt" |
|
"math/big" |
|
"sync" |
|
"sync/atomic" |
|
"testing" |
|
"time" |
|
|
|
"github.com/hashicorp/consul/agent/cacheshim" |
|
"github.com/hashicorp/consul/agent/connect" |
|
"github.com/hashicorp/consul/agent/structs" |
|
"github.com/hashicorp/consul/sdk/testutil" |
|
) |
|
|
|
// NewTestManager returns a *Manager that is pre-configured to use a mock RPC |
|
// implementation that can sign certs, and an in-memory CA roots reader that |
|
// interacts well with it. |
|
func NewTestManager(t *testing.T, mut func(*Config)) (*Manager, *TestSigner) { |
|
signer := newTestSigner(t, nil, nil) |
|
|
|
deps := Deps{ |
|
Logger: testutil.Logger(t), |
|
RootsReader: signer.RootsReader, |
|
CertSigner: signer, |
|
Config: Config{ |
|
// Override the root-change spread so we don't have to wait up to 20 seconds |
|
// to see root changes work. Can be changed back for specific tests that |
|
// need to test this, Note it's not 0 since that used default but is |
|
// effectively the same. |
|
TestOverrideCAChangeInitialDelay: 1 * time.Microsecond, |
|
}, |
|
} |
|
if mut != nil { |
|
mut(&deps.Config) |
|
} |
|
|
|
m := NewManager(deps) |
|
t.Cleanup(m.Stop) |
|
|
|
return m, signer |
|
} |
|
|
|
// TestSigner is a test double that handles leaf signing operations locally.
// It signs with the CA most recently installed by UpdateCA and publishes root
// changes through RootsReader. Canned outcomes for SignCert can be queued with
// SetSignCallErrors, and every request it receives is recorded.
type TestSigner struct {
	// caLock guards ca and prevRoots.
	caLock sync.Mutex
	// ca is the CA used for signing; nil until UpdateCA is first called.
	ca *structs.CARoot
	// prevRoots holds inactive clones of every CA passed to UpdateCA so far;
	// they are appended to each newly published root set.
	prevRoots []*structs.CARoot // remember prior ones

	// IDGenerator is incremented to produce root-set indexes, certificate
	// serial numbers, and issued-cert raft indexes.
	IDGenerator *atomic.Uint64
	// RootsReader receives the root set produced by each UpdateCA call.
	RootsReader *testRootsReader

	// signCallLock guards the signCall* fields below.
	signCallLock sync.Mutex
	// signCallErrors is a FIFO of canned outcomes; SignCert pops one per call.
	signCallErrors []error
	// signCallErrorCount counts SignCert calls that returned a queued error.
	signCallErrorCount uint64
	// signCallCapture records every request passed to SignCert, in call order.
	signCallCapture []*structs.CASignRequest
}
|
|
|
// Compile-time check that *TestSigner satisfies the CertSigner interface.
var _ CertSigner = (*TestSigner)(nil)
|
|
|
// ReplyWithExpiredCert is a sentinel that can be queued with
// TestSigner.SetSignCallErrors. Rather than failing the call, SignCert consumes
// it and returns a certificate whose validity window ended an hour ago.
var ReplyWithExpiredCert = errors.New("reply with expired cert")
|
|
|
func newTestSigner(t *testing.T, idGenerator *atomic.Uint64, rootsReader *testRootsReader) *TestSigner { |
|
if idGenerator == nil { |
|
idGenerator = &atomic.Uint64{} |
|
} |
|
if rootsReader == nil { |
|
rootsReader = newTestRootsReader(t) |
|
} |
|
s := &TestSigner{ |
|
IDGenerator: idGenerator, |
|
RootsReader: rootsReader, |
|
} |
|
return s |
|
} |
|
|
|
func (s *TestSigner) SetSignCallErrors(errs ...error) { |
|
s.signCallLock.Lock() |
|
defer s.signCallLock.Unlock() |
|
s.signCallErrors = append(s.signCallErrors, errs...) |
|
} |
|
|
|
func (s *TestSigner) GetSignCallErrorCount() uint64 { |
|
s.signCallLock.Lock() |
|
defer s.signCallLock.Unlock() |
|
return s.signCallErrorCount |
|
} |
|
|
|
func (s *TestSigner) UpdateCA(t *testing.T, ca *structs.CARoot) *structs.CARoot { |
|
if ca == nil { |
|
ca = connect.TestCA(t, nil) |
|
} |
|
roots := &structs.IndexedCARoots{ |
|
ActiveRootID: ca.ID, |
|
TrustDomain: connect.TestTrustDomain, |
|
Roots: []*structs.CARoot{ca}, |
|
QueryMeta: structs.QueryMeta{Index: s.nextIndex()}, |
|
} |
|
|
|
// Update the signer first. |
|
s.caLock.Lock() |
|
{ |
|
s.ca = ca |
|
roots.Roots = append(roots.Roots, s.prevRoots...) |
|
// Remember for the next rotation. |
|
dup := ca.Clone() |
|
dup.Active = false |
|
s.prevRoots = append(s.prevRoots, dup) |
|
} |
|
s.caLock.Unlock() |
|
|
|
// Then trigger an event when updating the roots. |
|
s.RootsReader.Set(roots) |
|
|
|
return ca |
|
} |
|
|
|
func (s *TestSigner) nextIndex() uint64 { |
|
return s.IDGenerator.Add(1) |
|
} |
|
|
|
func (s *TestSigner) getCA() *structs.CARoot { |
|
s.caLock.Lock() |
|
defer s.caLock.Unlock() |
|
return s.ca |
|
} |
|
|
|
func (s *TestSigner) GetCapture(idx int) *structs.CASignRequest { |
|
s.signCallLock.Lock() |
|
defer s.signCallLock.Unlock() |
|
if len(s.signCallCapture) > idx { |
|
return s.signCallCapture[idx] |
|
} |
|
|
|
return nil |
|
} |
|
|
|
func (s *TestSigner) SignCert(ctx context.Context, req *structs.CASignRequest) (*structs.IssuedCert, error) { |
|
useExpiredCert := false |
|
s.signCallLock.Lock() |
|
s.signCallCapture = append(s.signCallCapture, req) |
|
if len(s.signCallErrors) > 0 { |
|
err := s.signCallErrors[0] |
|
s.signCallErrors = s.signCallErrors[1:] |
|
if err == ReplyWithExpiredCert { |
|
useExpiredCert = true |
|
} else if err != nil { |
|
s.signCallErrorCount++ |
|
s.signCallLock.Unlock() |
|
return nil, err |
|
} |
|
} |
|
s.signCallLock.Unlock() |
|
|
|
// parts of this were inlined from CAManager and the connect ca provider |
|
ca := s.getCA() |
|
if ca == nil { |
|
return nil, fmt.Errorf("must call UpdateCA at least once") |
|
} |
|
|
|
csr, err := connect.ParseCSR(req.CSR) |
|
if err != nil { |
|
return nil, fmt.Errorf("error parsing CSR: %w", err) |
|
} |
|
|
|
connect.HackSANExtensionForCSR(csr) |
|
|
|
spiffeID, err := connect.ParseCertURI(csr.URIs[0]) |
|
if err != nil { |
|
return nil, fmt.Errorf("error parsing CSR URI: %w", err) |
|
} |
|
|
|
var isService bool |
|
var serviceID *connect.SpiffeIDService |
|
var workloadID *connect.SpiffeIDWorkloadIdentity |
|
|
|
switch spiffeID.(type) { |
|
case *connect.SpiffeIDService: |
|
isService = true |
|
serviceID = spiffeID.(*connect.SpiffeIDService) |
|
case *connect.SpiffeIDWorkloadIdentity: |
|
workloadID = spiffeID.(*connect.SpiffeIDWorkloadIdentity) |
|
default: |
|
return nil, fmt.Errorf("unexpected spiffeID type %T", spiffeID) |
|
} |
|
|
|
signer, err := connect.ParseSigner(ca.SigningKey) |
|
if err != nil { |
|
return nil, fmt.Errorf("error parsing CA signing key: %w", err) |
|
} |
|
|
|
keyId, err := connect.KeyId(signer.Public()) |
|
if err != nil { |
|
return nil, fmt.Errorf("error forming CA key id from public key: %w", err) |
|
} |
|
|
|
subjectKeyID, err := connect.KeyId(csr.PublicKey) |
|
if err != nil { |
|
return nil, fmt.Errorf("error forming subject key id from public key: %w", err) |
|
} |
|
|
|
caCert, err := connect.ParseCert(ca.RootCert) |
|
if err != nil { |
|
return nil, fmt.Errorf("error parsing CA root cert pem: %w", err) |
|
} |
|
|
|
const expiration = 10 * time.Minute |
|
|
|
now := time.Now() |
|
template := x509.Certificate{ |
|
SerialNumber: big.NewInt(int64(s.nextIndex())), |
|
URIs: csr.URIs, |
|
Signature: csr.Signature, |
|
// We use the correct signature algorithm for the CA key we are signing with |
|
// regardless of the algorithm used to sign the CSR signature above since |
|
// the leaf might use a different key type. |
|
SignatureAlgorithm: connect.SigAlgoForKey(signer), |
|
PublicKeyAlgorithm: csr.PublicKeyAlgorithm, |
|
PublicKey: csr.PublicKey, |
|
BasicConstraintsValid: true, |
|
KeyUsage: x509.KeyUsageDataEncipherment | |
|
x509.KeyUsageKeyAgreement | |
|
x509.KeyUsageDigitalSignature | |
|
x509.KeyUsageKeyEncipherment, |
|
ExtKeyUsage: []x509.ExtKeyUsage{ |
|
x509.ExtKeyUsageClientAuth, |
|
x509.ExtKeyUsageServerAuth, |
|
}, |
|
NotAfter: now.Add(expiration), |
|
NotBefore: now, |
|
AuthorityKeyId: keyId, |
|
SubjectKeyId: subjectKeyID, |
|
DNSNames: csr.DNSNames, |
|
IPAddresses: csr.IPAddresses, |
|
} |
|
|
|
if useExpiredCert { |
|
template.NotBefore = time.Now().Add(-13 * time.Hour) |
|
template.NotAfter = time.Now().Add(-1 * time.Hour) |
|
} |
|
|
|
// Create the certificate, PEM encode it and return that value. |
|
var buf bytes.Buffer |
|
bs, err := x509.CreateCertificate( |
|
rand.Reader, &template, caCert, csr.PublicKey, signer) |
|
if err != nil { |
|
return nil, fmt.Errorf("error creating cert pem from CSR: %w", err) |
|
} |
|
|
|
err = pem.Encode(&buf, &pem.Block{Type: "CERTIFICATE", Bytes: bs}) |
|
if err != nil { |
|
return nil, fmt.Errorf("error encoding cert pem into text: %w", err) |
|
} |
|
|
|
leafPEM := buf.String() |
|
|
|
leafCert, err := connect.ParseCert(leafPEM) |
|
if err != nil { |
|
return nil, fmt.Errorf("error parsing cert from generated leaf pem: %w", err) |
|
} |
|
|
|
index := s.nextIndex() |
|
if isService { |
|
// Service Spiffe ID case |
|
return &structs.IssuedCert{ |
|
SerialNumber: connect.EncodeSerialNumber(leafCert.SerialNumber), |
|
CertPEM: leafPEM, |
|
Service: serviceID.Service, |
|
ServiceURI: leafCert.URIs[0].String(), |
|
ValidAfter: leafCert.NotBefore, |
|
ValidBefore: leafCert.NotAfter, |
|
RaftIndex: structs.RaftIndex{ |
|
CreateIndex: index, |
|
ModifyIndex: index, |
|
}, |
|
}, nil |
|
} else { |
|
// Workload identity Spiffe ID case |
|
return &structs.IssuedCert{ |
|
SerialNumber: connect.EncodeSerialNumber(leafCert.SerialNumber), |
|
CertPEM: leafPEM, |
|
WorkloadIdentity: workloadID.WorkloadIdentity, |
|
WorkloadIdentityURI: leafCert.URIs[0].String(), |
|
ValidAfter: leafCert.NotBefore, |
|
ValidBefore: leafCert.NotAfter, |
|
RaftIndex: structs.RaftIndex{ |
|
CreateIndex: index, |
|
ModifyIndex: index, |
|
}, |
|
}, nil |
|
} |
|
} |
|
|
|
// testRootsReader is an in-memory RootsReader used by tests. Set replaces the
// stored roots and wakes goroutines started by Notify.
type testRootsReader struct {
	// mu guards every field below.
	mu sync.Mutex
	// index is the event index handed to Notify listeners: roots.Index after a
	// Set with non-nil roots, 1 after a Set with nil, and 0 before any Set.
	index uint64
	// roots is the value most recently passed to Set (nil until then).
	roots *structs.IndexedCARoots
	// watcher is closed (and, in Set, replaced) to wake pending Notify goroutines.
	watcher chan struct{}
}
|
|
|
func newTestRootsReader(t *testing.T) *testRootsReader { |
|
r := &testRootsReader{ |
|
watcher: make(chan struct{}), |
|
} |
|
t.Cleanup(func() { |
|
r.mu.Lock() |
|
watcher := r.watcher |
|
r.mu.Unlock() |
|
close(watcher) |
|
}) |
|
return r |
|
} |
|
|
|
// Compile-time check that *testRootsReader satisfies the RootsReader interface.
var _ RootsReader = (*testRootsReader)(nil)
|
|
|
func (r *testRootsReader) Set(roots *structs.IndexedCARoots) { |
|
r.mu.Lock() |
|
oldWatcher := r.watcher |
|
r.watcher = make(chan struct{}) |
|
r.roots = roots |
|
if roots == nil { |
|
r.index = 1 |
|
} else { |
|
r.index = roots.Index |
|
} |
|
r.mu.Unlock() |
|
|
|
close(oldWatcher) |
|
} |
|
|
|
func (r *testRootsReader) Get() (*structs.IndexedCARoots, error) { |
|
r.mu.Lock() |
|
defer r.mu.Unlock() |
|
return r.roots, nil |
|
} |
|
|
|
func (r *testRootsReader) Notify(ctx context.Context, correlationID string, ch chan<- cacheshim.UpdateEvent) error { |
|
r.mu.Lock() |
|
watcher := r.watcher |
|
r.mu.Unlock() |
|
|
|
go func() { |
|
<-watcher |
|
|
|
r.mu.Lock() |
|
defer r.mu.Unlock() |
|
|
|
ch <- cacheshim.UpdateEvent{ |
|
CorrelationID: correlationID, |
|
Result: r.roots, |
|
Meta: cacheshim.ResultMeta{Index: r.index}, |
|
Err: nil, |
|
} |
|
}() |
|
return nil |
|
}
|
|
|