Ensure certificates retrieved through the cache get persisted with auto-config (#8409)

pull/8412/head
Matt Keeler 2020-07-30 11:37:18 -04:00 committed by GitHub
parent dbb461a5d3
commit 1a78cf9b4c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 198 additions and 30 deletions

View File

@ -544,7 +544,8 @@ func New(options ...AgentOption) (*Agent, error) {
WithNodeName(a.config.NodeName). WithNodeName(a.config.NodeName).
WithFallback(a.autoConfigFallbackTLS). WithFallback(a.autoConfigFallbackTLS).
WithLogger(a.logger.Named(logging.AutoConfig)). WithLogger(a.logger.Named(logging.AutoConfig)).
WithTokens(a.tokens) WithTokens(a.tokens).
WithPersistence(a.autoConfigPersist)
acCertMon, err := certmon.New(cmConf) acCertMon, err := certmon.New(cmConf)
if err != nil { if err != nil {
return nil, err return nil, err
@ -888,9 +889,19 @@ func (a *Agent) autoEncryptInitialCertificate(ctx context.Context) (*structs.Sig
} }
func (a *Agent) autoConfigFallbackTLS(ctx context.Context) (*structs.SignedResponse, error) { func (a *Agent) autoConfigFallbackTLS(ctx context.Context) (*structs.SignedResponse, error) {
if a.autoConf == nil {
return nil, fmt.Errorf("AutoConfig manager has not been created yet")
}
return a.autoConf.FallbackTLS(ctx) return a.autoConf.FallbackTLS(ctx)
} }
func (a *Agent) autoConfigPersist(resp *structs.SignedResponse) error {
if a.autoConf == nil {
return fmt.Errorf("AutoConfig manager has not been created yet")
}
return a.autoConf.RecordUpdatedCerts(resp)
}
func (a *Agent) listenAndServeGRPC() error { func (a *Agent) listenAndServeGRPC() error {
if len(a.config.GRPCAddrs) < 1 { if len(a.config.GRPCAddrs) < 1 {
return nil return nil

View File

@ -21,6 +21,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/golang/protobuf/jsonpb"
"github.com/google/tcpproxy" "github.com/google/tcpproxy"
"github.com/hashicorp/consul/agent/cache" "github.com/hashicorp/consul/agent/cache"
cachetype "github.com/hashicorp/consul/agent/cache-types" cachetype "github.com/hashicorp/consul/agent/cache-types"
@ -32,6 +33,7 @@ import (
"github.com/hashicorp/consul/internal/go-sso/oidcauth/oidcauthtest" "github.com/hashicorp/consul/internal/go-sso/oidcauth/oidcauthtest"
"github.com/hashicorp/consul/ipaddr" "github.com/hashicorp/consul/ipaddr"
"github.com/hashicorp/consul/lib" "github.com/hashicorp/consul/lib"
"github.com/hashicorp/consul/proto/pbautoconf"
"github.com/hashicorp/consul/sdk/freeport" "github.com/hashicorp/consul/sdk/freeport"
"github.com/hashicorp/consul/sdk/testutil" "github.com/hashicorp/consul/sdk/testutil"
"github.com/hashicorp/consul/sdk/testutil/retry" "github.com/hashicorp/consul/sdk/testutil/retry"
@ -4722,7 +4724,13 @@ func TestAutoConfig_Integration(t *testing.T) {
}) })
require.NoError(t, err) require.NoError(t, err)
client := StartTestAgent(t, TestAgent{Name: "test-client", HCL: ` client := StartTestAgent(t, TestAgent{Name: "test-client",
Overrides: `
connect {
test_ca_leaf_root_change_spread = "1ns"
}
`,
HCL: `
bootstrap = false bootstrap = false
server = false server = false
ca_file = "` + caFile + `" ca_file = "` + caFile + `"
@ -4736,7 +4744,8 @@ func TestAutoConfig_Integration(t *testing.T) {
enabled = true enabled = true
intro_token = "` + token + `" intro_token = "` + token + `"
server_addresses = ["` + srv.Config.RPCBindAddr.String() + `"] server_addresses = ["` + srv.Config.RPCBindAddr.String() + `"]
}`}) }`,
})
defer client.Shutdown() defer client.Shutdown()
@ -4776,6 +4785,21 @@ func TestAutoConfig_Integration(t *testing.T) {
// ensure that a new cert gets generated and pushed into the TLS configurator // ensure that a new cert gets generated and pushed into the TLS configurator
retry.Run(t, func(r *retry.R) { retry.Run(t, func(r *retry.R) {
require.NotEqual(r, cert1, client.Agent.tlsConfigurator.Cert()) require.NotEqual(r, cert1, client.Agent.tlsConfigurator.Cert())
// check that the on disk certs match expectations
data, err := ioutil.ReadFile(filepath.Join(client.DataDir, "auto-config.json"))
require.NoError(r, err)
rdr := strings.NewReader(string(data))
var resp pbautoconf.AutoConfigResponse
pbUnmarshaler := &jsonpb.Unmarshaler{
AllowUnknownFields: false,
}
require.NoError(r, pbUnmarshaler.Unmarshal(rdr, &resp), "data: %s", data)
actual, err := tls.X509KeyPair([]byte(resp.Certificate.CertPEM), []byte(resp.Certificate.PrivateKeyPEM))
require.NoError(r, err)
require.Equal(r, client.Agent.tlsConfigurator.Cert(), &actual)
}) })
// spot check that we now have an ACL token // spot check that we now have an ACL token

View File

@ -62,6 +62,7 @@ type AutoConfig struct {
overrides []config.Source overrides []config.Source
certMonitor CertMonitor certMonitor CertMonitor
config *config.RuntimeConfig config *config.RuntimeConfig
autoConfigResponse *pbautoconf.AutoConfigResponse
autoConfigData string autoConfigData string
cancel context.CancelFunc cancel context.CancelFunc
} }
@ -493,6 +494,8 @@ func (ac *AutoConfig) generateCSR() (csr string, key string, err error) {
// config data to be used during a call to ReadConfig, updating the // config data to be used during a call to ReadConfig, updating the
// tls Configurator and prepopulating the cache. // tls Configurator and prepopulating the cache.
func (ac *AutoConfig) update(resp *pbautoconf.AutoConfigResponse) error { func (ac *AutoConfig) update(resp *pbautoconf.AutoConfigResponse) error {
ac.autoConfigResponse = resp
if err := ac.updateConfigFromResponse(resp); err != nil { if err := ac.updateConfigFromResponse(resp); err != nil {
return err return err
} }
@ -591,3 +594,18 @@ func (ac *AutoConfig) FallbackTLS(ctx context.Context) (*structs.SignedResponse,
return extractSignedResponse(resp) return extractSignedResponse(resp)
} }
func (ac *AutoConfig) RecordUpdatedCerts(resp *structs.SignedResponse) error {
var err error
ac.autoConfigResponse.ExtraCACertificates = resp.ManualCARoots
ac.autoConfigResponse.CARoots, err = translateCARootsToProtobuf(&resp.ConnectCARoots)
if err != nil {
return err
}
ac.autoConfigResponse.Certificate, err = translateIssuedCertToProtobuf(&resp.IssuedCert)
if err != nil {
return err
}
return ac.recordResponse(ac.autoConfigResponse)
}

View File

@ -226,3 +226,34 @@ func mapstructureTranslateToStructs(in interface{}, out interface{}) error {
return decoder.Decode(in) return decoder.Decode(in)
} }
func translateCARootsToProtobuf(in *structs.IndexedCARoots) (*pbconnect.CARoots, error) {
var out pbconnect.CARoots
if err := mapstructureTranslateToProtobuf(in, &out); err != nil {
return nil, fmt.Errorf("Failed to re-encode CA Roots: %w", err)
}
return &out, nil
}
func translateIssuedCertToProtobuf(in *structs.IssuedCert) (*pbconnect.IssuedCert, error) {
var out pbconnect.IssuedCert
if err := mapstructureTranslateToProtobuf(in, &out); err != nil {
return nil, fmt.Errorf("Failed to re-encode CA Roots: %w", err)
}
return &out, nil
}
func mapstructureTranslateToProtobuf(in interface{}, out interface{}) error {
decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
DecodeHook: proto.HookTimeToPBTimestamp,
Result: out,
})
if err != nil {
return err
}
return decoder.Decode(in)
}

View File

@ -40,6 +40,7 @@ type CertMonitor struct {
tokens *token.Store tokens *token.Store
leafReq cachetype.ConnectCALeafRequest leafReq cachetype.ConnectCALeafRequest
rootsReq structs.DCSpecificRequest rootsReq structs.DCSpecificRequest
persist PersistFunc
fallback FallbackFunc fallback FallbackFunc
fallbackLeeway time.Duration fallbackLeeway time.Duration
fallbackRetry time.Duration fallbackRetry time.Duration
@ -66,6 +67,11 @@ type CertMonitor struct {
// events from the token store when the Agent // events from the token store when the Agent
// token is updated. // token is updated.
tokenUpdates token.Notifier tokenUpdates token.Notifier
// this is used to keep a local copy of the certs
// keys and ca certs. It will be used to persist
// all of the local state at once.
certs structs.SignedResponse
} }
// New creates a new CertMonitor for automatically rotating // New creates a new CertMonitor for automatically rotating
@ -115,6 +121,7 @@ func New(config *Config) (*CertMonitor, error) {
cache: config.Cache, cache: config.Cache,
tokens: config.Tokens, tokens: config.Tokens,
tlsConfigurator: config.TLSConfigurator, tlsConfigurator: config.TLSConfigurator,
persist: config.Persist,
fallback: config.Fallback, fallback: config.Fallback,
fallbackLeeway: config.FallbackLeeway, fallbackLeeway: config.FallbackLeeway,
fallbackRetry: config.FallbackRetry, fallbackRetry: config.FallbackRetry,
@ -135,6 +142,8 @@ func (m *CertMonitor) Update(certs *structs.SignedResponse) error {
return nil return nil
} }
m.certs = *certs
if err := m.populateCache(certs); err != nil { if err := m.populateCache(certs); err != nil {
return fmt.Errorf("error populating cache with certificates: %w", err) return fmt.Errorf("error populating cache with certificates: %w", err)
} }
@ -306,6 +315,8 @@ func (m *CertMonitor) handleCacheEvent(u cache.UpdateEvent) error {
return fmt.Errorf("invalid type for roots watch response: %T", u.Result) return fmt.Errorf("invalid type for roots watch response: %T", u.Result)
} }
m.certs.ConnectCARoots = *roots
var pems []string var pems []string
for _, root := range roots.Roots { for _, root := range roots.Roots {
pems = append(pems, root.RootCert) pems = append(pems, root.RootCert)
@ -314,6 +325,13 @@ func (m *CertMonitor) handleCacheEvent(u cache.UpdateEvent) error {
if err := m.tlsConfigurator.UpdateAutoTLSCA(pems); err != nil { if err := m.tlsConfigurator.UpdateAutoTLSCA(pems); err != nil {
return fmt.Errorf("failed to update Connect CA certificates: %w", err) return fmt.Errorf("failed to update Connect CA certificates: %w", err)
} }
if m.persist != nil {
copy := m.certs
if err := m.persist(&copy); err != nil {
return fmt.Errorf("failed to persist certificate package: %w", err)
}
}
case leafWatchID: case leafWatchID:
m.logger.Debug("leaf certificate watch fired - updating TLS certificate") m.logger.Debug("leaf certificate watch fired - updating TLS certificate")
if u.Err != nil { if u.Err != nil {
@ -324,9 +342,19 @@ func (m *CertMonitor) handleCacheEvent(u cache.UpdateEvent) error {
if !ok { if !ok {
return fmt.Errorf("invalid type for agent leaf cert watch response: %T", u.Result) return fmt.Errorf("invalid type for agent leaf cert watch response: %T", u.Result)
} }
m.certs.IssuedCert = *leaf
if err := m.tlsConfigurator.UpdateAutoTLSCert(leaf.CertPEM, leaf.PrivateKeyPEM); err != nil { if err := m.tlsConfigurator.UpdateAutoTLSCert(leaf.CertPEM, leaf.PrivateKeyPEM); err != nil {
return fmt.Errorf("failed to update the agent leaf cert: %w", err) return fmt.Errorf("failed to update the agent leaf cert: %w", err)
} }
if m.persist != nil {
copy := m.certs
if err := m.persist(&copy); err != nil {
return fmt.Errorf("failed to persist certificate package: %w", err)
}
}
} }
return nil return nil
@ -380,6 +408,11 @@ func (m *CertMonitor) handleFallback(ctx context.Context) error {
return fmt.Errorf("error when getting new agent certificate: %w", err) return fmt.Errorf("error when getting new agent certificate: %w", err)
} }
if m.persist != nil {
if err := m.persist(reply); err != nil {
return fmt.Errorf("failed to persist certificate package: %w", err)
}
}
return m.Update(reply) return m.Update(reply)
} }

View File

@ -33,6 +33,14 @@ func (m *mockFallback) fallback(ctx context.Context) (*structs.SignedResponse, e
return resp, ret.Error(1) return resp, ret.Error(1)
} }
type mockPersist struct {
mock.Mock
}
func (m *mockPersist) persist(resp *structs.SignedResponse) error {
return m.Called(resp).Error(0)
}
type mockWatcher struct { type mockWatcher struct {
ch chan<- cache.UpdateEvent ch chan<- cache.UpdateEvent
done <-chan struct{} done <-chan struct{}
@ -159,6 +167,7 @@ type testCertMonitor struct {
tls *tlsutil.Configurator tls *tlsutil.Configurator
tokens *token.Store tokens *token.Store
fallback *mockFallback fallback *mockFallback
persist *mockPersist
extraCACerts []string extraCACerts []string
initialCert *structs.IssuedCert initialCert *structs.IssuedCert
@ -210,8 +219,10 @@ func newTestCertMonitor(t *testing.T) testCertMonitor {
dnsSANs := []string{"test.dev"} dnsSANs := []string{"test.dev"}
ipSANs := []net.IP{net.IPv4(198, 18, 0, 1)} ipSANs := []net.IP{net.IPv4(198, 18, 0, 1)}
// this chan should be unbuffered so we can detect when the fallback func has been called.
fallback := &mockFallback{} fallback := &mockFallback{}
fallback.Test(t)
persist := &mockPersist{}
persist.Test(t)
mcache := newMockCache(t) mcache := newMockCache(t)
rootRes := cache.FetchResult{Value: &indexedRoots, Index: 1} rootRes := cache.FetchResult{Value: &indexedRoots, Index: 1}
@ -246,7 +257,8 @@ func newTestCertMonitor(t *testing.T) testCertMonitor {
WithDatacenter("foo"). WithDatacenter("foo").
WithNodeName("node"). WithNodeName("node").
WithFallbackLeeway(time.Nanosecond). WithFallbackLeeway(time.Nanosecond).
WithFallbackRetry(time.Millisecond) WithFallbackRetry(time.Millisecond).
WithPersistence(persist.persist)
monitor, err := New(cfg) monitor, err := New(cfg)
require.NoError(t, err) require.NoError(t, err)
@ -259,6 +271,7 @@ func newTestCertMonitor(t *testing.T) testCertMonitor {
tls: tlsConfigurator, tls: tlsConfigurator,
tokens: tokens, tokens: tokens,
mcache: mcache, mcache: mcache,
persist: persist,
fallback: fallback, fallback: fallback,
extraCACerts: []string{manualCA.RootCert}, extraCACerts: []string{manualCA.RootCert},
initialCert: issued, initialCert: issued,
@ -298,6 +311,7 @@ func (cm *testCertMonitor) initialCACerts() []string {
func (cm *testCertMonitor) assertExpectations(t *testing.T) { func (cm *testCertMonitor) assertExpectations(t *testing.T) {
cm.mcache.AssertExpectations(t) cm.mcache.AssertExpectations(t)
cm.fallback.AssertExpectations(t) cm.fallback.AssertExpectations(t)
cm.persist.AssertExpectations(t)
} }
func TestCertMonitor_InitialCerts(t *testing.T) { func TestCertMonitor_InitialCerts(t *testing.T) {
@ -473,6 +487,13 @@ func TestCertMonitor_RootsUpdate(t *testing.T) {
}, },
} }
cm.persist.On("persist", &structs.SignedResponse{
IssuedCert: *cm.initialCert,
ManualCARoots: cm.extraCACerts,
ConnectCARoots: secondRoots,
VerifyServerHostname: cm.verifyServerHostname,
}).Return(nil).Once()
// assert value of the CA certs prior to updating // assert value of the CA certs prior to updating
require.ElementsMatch(t, cm.initialCACerts(), cm.tls.CAPems()) require.ElementsMatch(t, cm.initialCACerts(), cm.tls.CAPems())
@ -500,6 +521,13 @@ func TestCertMonitor_CertUpdate(t *testing.T) {
secondCert := newLeaf(t, cm.initialRoots.Roots[0], 100, 10*time.Minute) secondCert := newLeaf(t, cm.initialRoots.Roots[0], 100, 10*time.Minute)
cm.persist.On("persist", &structs.SignedResponse{
IssuedCert: *secondCert,
ManualCARoots: cm.extraCACerts,
ConnectCARoots: *cm.initialRoots,
VerifyServerHostname: cm.verifyServerHostname,
}).Return(nil).Once()
// assert value of cert prior to updating the leaf // assert value of cert prior to updating the leaf
require.Equal(t, cm.initialTLSCertificate(t), cm.tls.Cert()) require.Equal(t, cm.initialTLSCertificate(t), cm.tls.Cert())
@ -549,13 +577,23 @@ func TestCertMonitor_Fallback(t *testing.T) {
// inject a fallback routine error to check that we rerun it quickly // inject a fallback routine error to check that we rerun it quickly
cm.fallback.On("fallback").Return(nil, fmt.Errorf("induced error")).Once() cm.fallback.On("fallback").Return(nil, fmt.Errorf("induced error")).Once()
// expect the fallback routine to be executed and setup the return fallbackResp := &structs.SignedResponse{
cm.fallback.On("fallback").Return(&structs.SignedResponse{
ConnectCARoots: secondRoots, ConnectCARoots: secondRoots,
IssuedCert: *thirdCert, IssuedCert: *thirdCert,
ManualCARoots: cm.extraCACerts, ManualCARoots: cm.extraCACerts,
VerifyServerHostname: true, VerifyServerHostname: true,
}, nil).Once() }
// expect the fallback routine to be executed and setup the return
cm.fallback.On("fallback").Return(fallbackResp, nil).Once()
cm.persist.On("persist", &structs.SignedResponse{
IssuedCert: *secondCert,
ConnectCARoots: *cm.initialRoots,
ManualCARoots: cm.extraCACerts,
VerifyServerHostname: cm.verifyServerHostname,
}).Return(nil).Once()
cm.persist.On("persist", fallbackResp).Return(nil).Once()
// Add another roots cache prepopulation expectation which should happen // Add another roots cache prepopulation expectation which should happen
// in response to executing the fallback mechanism // in response to executing the fallback mechanism

View File

@ -16,6 +16,9 @@ import (
// method of updating the certificate is required. // method of updating the certificate is required.
type FallbackFunc func(context.Context) (*structs.SignedResponse, error) type FallbackFunc func(context.Context) (*structs.SignedResponse, error)
// PersistFunc is used to persist the data from a signed response
type PersistFunc func(*structs.SignedResponse) error
type Config struct { type Config struct {
// Logger is the logger to be used while running. If not set // Logger is the logger to be used while running. If not set
// then no logging will be performed. // then no logging will be performed.
@ -34,6 +37,9 @@ type Config struct {
// This field is required. // This field is required.
Tokens *token.Store Tokens *token.Store
// Persist is a function to run when there are new certs or keys
Persist PersistFunc
// Fallback is a function to run when the normal cache updating of the // Fallback is a function to run when the normal cache updating of the
// agent's certificates has failed to work for one reason or another. // agent's certificates has failed to work for one reason or another.
// This field is required. // This field is required.
@ -135,3 +141,10 @@ func (cfg *Config) WithFallbackRetry(after time.Duration) *Config {
cfg.FallbackRetry = after cfg.FallbackRetry = after
return cfg return cfg
} }
// WithPersistence will configure the CertMonitor to use this callback for persisting
// a new TLS configuration.
func (cfg *Config) WithPersistence(persist PersistFunc) *Config {
cfg.Persist = persist
return cfg
}