From 2ec94b027e4ee73ad542b6ffcd2b797a96f6ef2e Mon Sep 17 00:00:00 2001 From: Kyle Havlovitz Date: Wed, 30 Sep 2020 12:31:21 -0700 Subject: [PATCH] connect: Enable renewing the intermediate cert in the primary DC --- agent/connect/ca/provider.go | 7 ++ agent/connect/ca/provider_vault.go | 23 ++-- agent/connect/ca/provider_vault_test.go | 131 +--------------------- agent/connect/ca/testing.go | 142 ++++++++++++++++++++++++ agent/consul/leader_connect.go | 59 ++++++++-- agent/consul/leader_connect_test.go | 113 ++++++++++++++++++- agent/consul/server.go | 2 +- 7 files changed, 330 insertions(+), 147 deletions(-) diff --git a/agent/connect/ca/provider.go b/agent/connect/ca/provider.go index f9e637c419..1dd1408cf2 100644 --- a/agent/connect/ca/provider.go +++ b/agent/connect/ca/provider.go @@ -16,6 +16,13 @@ import ( // on servers and CA provider. var ErrRateLimited = errors.New("operation rate limited by CA provider") +// PrimaryIntermediateProviders is a list of CA providers that make use use of an +// intermediate cert in the primary datacenter as well as the secondary. This is used +// when determining whether to run the intermediate renewal routine in the primary. +var PrimaryIntermediateProviders = map[string]struct{}{ + "vault": struct{}{}, +} + // ProviderConfig encapsulates all the data Consul passes to `Configure` on a // new provider instance. The provider must treat this as read-only and make // copies of any map or slice if it might modify them internally. diff --git a/agent/connect/ca/provider_vault.go b/agent/connect/ca/provider_vault.go index d3a9ac7146..7c5fe47c8b 100644 --- a/agent/connect/ca/provider_vault.go +++ b/agent/connect/ca/provider_vault.go @@ -92,7 +92,7 @@ func (v *VaultProvider) Configure(cfg ProviderConfig) error { // Set up a renewer to renew the token automatically, if supported. if token.Renewable { - lifetimeWatcher, err := client.NewLifetimeWatcher(&vaultapi.LifetimeWatcherInput{ + renewer, err := client.NewRenewer(&vaultapi.RenewerInput{ Secret: &vaultapi.Secret{ Auth: &vaultapi.SecretAuth{ ClientToken: config.Token, @@ -100,8 +100,7 @@ func (v *VaultProvider) Configure(cfg ProviderConfig) error { LeaseDuration: secret.LeaseDuration, }, }, - Increment: token.TTL, - RenewBehavior: vaultapi.RenewBehaviorIgnoreErrors, + Increment: token.TTL, }) if err != nil { return fmt.Errorf("Error beginning Vault provider token renewal: %v", err) @@ -109,31 +108,31 @@ func (v *VaultProvider) Configure(cfg ProviderConfig) error { ctx, cancel := context.WithCancel(context.TODO()) v.shutdown = cancel - go v.renewToken(ctx, lifetimeWatcher) + go v.renewToken(ctx, renewer) } return nil } // renewToken uses a vaultapi.Renewer to repeatedly renew our token's lease. -func (v *VaultProvider) renewToken(ctx context.Context, watcher *vaultapi.LifetimeWatcher) { - go watcher.Start() - defer watcher.Stop() +func (v *VaultProvider) renewToken(ctx context.Context, renewer *vaultapi.Renewer) { + go renewer.Renew() + defer renewer.Stop() for { select { case <-ctx.Done(): return - case err := <-watcher.DoneCh(): + case err := <-renewer.DoneCh(): if err != nil { v.logger.Error("Error renewing token for Vault provider", "error", err) } - // Watcher routine has finished, so start it again. - go watcher.Start() + // Renewer routine has finished, so start it again. + go renewer.Renew() - case <-watcher.RenewCh(): + case <-renewer.RenewCh(): v.logger.Error("Successfully renewed token for Vault provider") } } @@ -384,6 +383,7 @@ func (v *VaultProvider) GenerateIntermediate() (string, error) { "csr": csr, "use_csr_values": true, "format": "pem_bundle", + "ttl": v.config.IntermediateCertTTL.String(), }) if err != nil { return "", err @@ -456,6 +456,7 @@ func (v *VaultProvider) SignIntermediate(csr *x509.CertificateRequest) (string, "use_csr_values": true, "format": "pem_bundle", "max_path_length": 0, + "ttl": v.config.IntermediateCertTTL.String(), }) if err != nil { return "", err diff --git a/agent/connect/ca/provider_vault_test.go b/agent/connect/ca/provider_vault_test.go index 3094cb092f..b1890e08eb 100644 --- a/agent/connect/ca/provider_vault_test.go +++ b/agent/connect/ca/provider_vault_test.go @@ -7,13 +7,11 @@ import ( "io/ioutil" "os" "os/exec" - "sync" "testing" "time" "github.com/hashicorp/consul/agent/connect" "github.com/hashicorp/consul/agent/structs" - "github.com/hashicorp/consul/sdk/freeport" "github.com/hashicorp/consul/sdk/testutil/retry" "github.com/hashicorp/go-hclog" vaultapi "github.com/hashicorp/vault/api" @@ -70,7 +68,7 @@ func TestVaultCAProvider_RenewToken(t *testing.T) { require.NoError(t, err) providerToken := secret.Auth.ClientToken - _, err = createVaultProvider(t, true, testVault.addr, providerToken, nil) + _, err = createVaultProvider(t, true, testVault.Addr, providerToken, nil) require.NoError(t, err) // Check the last renewal time. @@ -382,11 +380,11 @@ func getIntermediateCertTTL(t *testing.T, caConf *structs.CAConfiguration) time. return dur } -func testVaultProvider(t *testing.T) (*VaultProvider, *testVaultServer) { +func testVaultProvider(t *testing.T) (*VaultProvider, *TestVaultServer) { return testVaultProviderWithConfig(t, true, nil) } -func testVaultProviderWithConfig(t *testing.T, isPrimary bool, rawConf map[string]interface{}) (*VaultProvider, *testVaultServer) { +func testVaultProviderWithConfig(t *testing.T, isPrimary bool, rawConf map[string]interface{}) (*VaultProvider, *TestVaultServer) { testVault, err := runTestVault(t) if err != nil { t.Fatalf("err: %v", err) @@ -394,7 +392,7 @@ func testVaultProviderWithConfig(t *testing.T, isPrimary bool, rawConf map[strin testVault.WaitUntilReady(t) - provider, err := createVaultProvider(t, isPrimary, testVault.addr, testVault.rootToken, rawConf) + provider, err := createVaultProvider(t, isPrimary, testVault.Addr, testVault.RootToken, rawConf) if err != nil { testVault.Stop() t.Fatalf("err: %v", err) @@ -459,124 +457,3 @@ func skipIfVaultNotPresent(t *testing.T) { t.Skipf("%q not found on $PATH - download and install to run this test", vaultBinaryName) } } - -func runTestVault(t *testing.T) (*testVaultServer, error) { - vaultBinaryName := os.Getenv("VAULT_BINARY_NAME") - if vaultBinaryName == "" { - vaultBinaryName = "vault" - } - - path, err := exec.LookPath(vaultBinaryName) - if err != nil || path == "" { - return nil, fmt.Errorf("%q not found on $PATH", vaultBinaryName) - } - - ports := freeport.MustTake(2) - returnPortsFn := func() { - freeport.Return(ports) - } - - var ( - clientAddr = fmt.Sprintf("127.0.0.1:%d", ports[0]) - clusterAddr = fmt.Sprintf("127.0.0.1:%d", ports[1]) - ) - - const token = "root" - - client, err := vaultapi.NewClient(&vaultapi.Config{ - Address: "http://" + clientAddr, - }) - if err != nil { - returnPortsFn() - return nil, err - } - client.SetToken(token) - - args := []string{ - "server", - "-dev", - "-dev-root-token-id", - token, - "-dev-listen-address", - clientAddr, - "-address", - clusterAddr, - } - - cmd := exec.Command(vaultBinaryName, args...) - cmd.Stdout = ioutil.Discard - cmd.Stderr = ioutil.Discard - if err := cmd.Start(); err != nil { - returnPortsFn() - return nil, err - } - - testVault := &testVaultServer{ - rootToken: token, - addr: "http://" + clientAddr, - cmd: cmd, - client: client, - returnPortsFn: returnPortsFn, - } - t.Cleanup(func() { - testVault.Stop() - }) - return testVault, nil -} - -type testVaultServer struct { - rootToken string - addr string - cmd *exec.Cmd - client *vaultapi.Client - - // returnPortsFn will put the ports claimed for the test back into the - returnPortsFn func() -} - -var printedVaultVersion sync.Once - -func (v *testVaultServer) WaitUntilReady(t *testing.T) { - var version string - retry.Run(t, func(r *retry.R) { - resp, err := v.client.Sys().Health() - if err != nil { - r.Fatalf("err: %v", err) - } - if !resp.Initialized { - r.Fatalf("vault server is not initialized") - } - if resp.Sealed { - r.Fatalf("vault server is sealed") - } - version = resp.Version - }) - printedVaultVersion.Do(func() { - fmt.Fprintf(os.Stderr, "[INFO] agent/connect/ca: testing with vault server version: %s\n", version) - }) -} - -func (v *testVaultServer) Stop() error { - // There was no process - if v.cmd == nil { - return nil - } - - if v.cmd.Process != nil { - if err := v.cmd.Process.Signal(os.Interrupt); err != nil { - return fmt.Errorf("failed to kill vault server: %v", err) - } - } - - // wait for the process to exit to be sure that the data dir can be - // deleted on all platforms. - if err := v.cmd.Wait(); err != nil { - return err - } - - if v.returnPortsFn != nil { - v.returnPortsFn() - } - - return nil -} diff --git a/agent/connect/ca/testing.go b/agent/connect/ca/testing.go index 9a637843e0..81995275fe 100644 --- a/agent/connect/ca/testing.go +++ b/agent/connect/ca/testing.go @@ -3,9 +3,15 @@ package ca import ( "fmt" "io/ioutil" + "os" + "os/exec" + "sync" "github.com/hashicorp/consul/agent/connect" + "github.com/hashicorp/consul/sdk/freeport" + "github.com/hashicorp/consul/sdk/testutil/retry" "github.com/hashicorp/go-hclog" + vaultapi "github.com/hashicorp/vault/api" "github.com/mitchellh/go-testing-interface" ) @@ -76,3 +82,139 @@ func TestConsulProvider(t testing.T, d ConsulProviderStateDelegate) *ConsulProvi provider.SetLogger(logger) return provider } + +func NewTestVaultServer(t testing.T) *TestVaultServer { + testVault, err := runTestVault(t) + if err != nil { + t.Fatalf("err: %v", err) + } + + testVault.WaitUntilReady(t) + + return testVault +} + +func runTestVault(t testing.T) (*TestVaultServer, error) { + vaultBinaryName := os.Getenv("VAULT_BINARY_NAME") + if vaultBinaryName == "" { + vaultBinaryName = "vault" + } + + path, err := exec.LookPath(vaultBinaryName) + if err != nil || path == "" { + return nil, fmt.Errorf("%q not found on $PATH", vaultBinaryName) + } + + ports := freeport.MustTake(2) + returnPortsFn := func() { + freeport.Return(ports) + } + + var ( + clientAddr = fmt.Sprintf("127.0.0.1:%d", ports[0]) + clusterAddr = fmt.Sprintf("127.0.0.1:%d", ports[1]) + ) + + const token = "root" + + client, err := vaultapi.NewClient(&vaultapi.Config{ + Address: "http://" + clientAddr, + }) + if err != nil { + returnPortsFn() + return nil, err + } + client.SetToken(token) + + args := []string{ + "server", + "-dev", + "-dev-root-token-id", + token, + "-dev-listen-address", + clientAddr, + "-address", + clusterAddr, + } + + cmd := exec.Command(vaultBinaryName, args...) + cmd.Stdout = ioutil.Discard + cmd.Stderr = ioutil.Discard + if err := cmd.Start(); err != nil { + returnPortsFn() + return nil, err + } + + testVault := &TestVaultServer{ + RootToken: token, + Addr: "http://" + clientAddr, + cmd: cmd, + client: client, + returnPortsFn: returnPortsFn, + } + t.Cleanup(func() { + testVault.Stop() + }) + return testVault, nil +} + +type TestVaultServer struct { + RootToken string + Addr string + cmd *exec.Cmd + client *vaultapi.Client + + // returnPortsFn will put the ports claimed for the test back into the + returnPortsFn func() +} + +var printedVaultVersion sync.Once + +func (v *TestVaultServer) Client() *vaultapi.Client { + return v.client +} + +func (v *TestVaultServer) WaitUntilReady(t testing.T) { + var version string + retry.Run(t, func(r *retry.R) { + resp, err := v.client.Sys().Health() + if err != nil { + r.Fatalf("err: %v", err) + } + if !resp.Initialized { + r.Fatalf("vault server is not initialized") + } + if resp.Sealed { + r.Fatalf("vault server is sealed") + } + version = resp.Version + }) + printedVaultVersion.Do(func() { + fmt.Fprintf(os.Stderr, "[INFO] agent/connect/ca: testing with vault server version: %s\n", version) + }) +} + +func (v *TestVaultServer) Stop() error { + // There was no process + if v.cmd == nil { + return nil + } + + if v.cmd.Process != nil { + if err := v.cmd.Process.Signal(os.Interrupt); err != nil { + return fmt.Errorf("failed to kill vault server: %v", err) + } + } + + // wait for the process to exit to be sure that the data dir can be + // deleted on all platforms. + if err := v.cmd.Wait(); err != nil { + return err + } + + if v.returnPortsFn != nil { + v.returnPortsFn() + } + + return nil +} diff --git a/agent/consul/leader_connect.go b/agent/consul/leader_connect.go index d512cc2760..6befad04dd 100644 --- a/agent/consul/leader_connect.go +++ b/agent/consul/leader_connect.go @@ -510,6 +510,31 @@ func (s *Server) persistNewRoot(provider ca.Provider, newActiveRoot *structs.CAR return nil } +// getIntermediateCAPrimary regenerates the intermediate cert in the primary datacenter. +// This is only run for CAs that require an intermediary in the primary DC, such as Vault. +// This function is being called while holding caProviderReconfigurationLock +// which means it must never take that lock itself or call anything that does. +func (s *Server) getIntermediateCAPrimary(provider ca.Provider, newActiveRoot *structs.CARoot) error { + connectLogger := s.loggers.Named(logging.Connect) + intermediatePEM, err := provider.GenerateIntermediate() + if err != nil { + return fmt.Errorf("error generating new intermediate cert: %v", err) + } + + intermediateCert, err := connect.ParseCert(intermediatePEM) + if err != nil { + return fmt.Errorf("error parsing intermediate cert: %v", err) + } + + // Append the new intermediate to our local active root entry. This is + // where the root representations start to diverge. + newActiveRoot.IntermediateCerts = append(newActiveRoot.IntermediateCerts, intermediatePEM) + newActiveRoot.SigningKeyID = connect.EncodeSigningKeyID(intermediateCert.SubjectKeyId) + + connectLogger.Info("generated new intermediate certificate in primary datacenter") + return nil +} + // getIntermediateCASigned is being called while holding caProviderReconfigurationLock // which means it must never take that lock itself or call anything that does. func (s *Server) getIntermediateCASigned(provider ca.Provider, newActiveRoot *structs.CARoot) error { @@ -558,10 +583,10 @@ func (s *Server) startConnectLeader() { if s.config.ConnectEnabled && s.config.Datacenter != s.config.PrimaryDatacenter { s.leaderRoutineManager.Start(secondaryCARootWatchRoutineName, s.secondaryCARootWatch) s.leaderRoutineManager.Start(intentionReplicationRoutineName, s.replicateIntentions) - s.leaderRoutineManager.Start(secondaryCertRenewWatchRoutineName, s.secondaryIntermediateCertRenewalWatch) s.startConnectLeaderEnterprise() } + s.leaderRoutineManager.Start(intermediateCertRenewWatchRoutineName, s.intermediateCertRenewalWatch) s.leaderRoutineManager.Start(caRootPruningRoutineName, s.runCARootPruning) } @@ -652,11 +677,12 @@ func (s *Server) pruneCARoots() error { return nil } -// secondaryIntermediateCertRenewalWatch checks the intermediate cert for +// intermediateCertRenewalWatch checks the intermediate cert for // expiration. As soon as more than half the time a cert is valid has passed, // it will try to renew it. -func (s *Server) secondaryIntermediateCertRenewalWatch(ctx context.Context) error { +func (s *Server) intermediateCertRenewalWatch(ctx context.Context) error { connectLogger := s.loggers.Named(logging.Connect) + isPrimary := s.config.Datacenter == s.config.PrimaryDatacenter for { select { @@ -672,7 +698,8 @@ func (s *Server) secondaryIntermediateCertRenewalWatch(ctx context.Context) erro // this happens when leadership is being revoked and this go routine will be stopped return nil } - if !s.configuredSecondaryCA() { + // If this isn't the primary, make sure the CA has been initialized. + if !isPrimary && !s.configuredSecondaryCA() { return fmt.Errorf("secondary CA is not yet configured.") } @@ -682,13 +709,26 @@ func (s *Server) secondaryIntermediateCertRenewalWatch(ctx context.Context) erro return err } + // If this is the primary, check if this is a provider that uses an intermediate cert. If + // it isn't, we don't need to check for a renewal. + if isPrimary { + _, config, err := state.CAConfig(nil) + if err != nil { + return err + } + + if _, ok := ca.PrimaryIntermediateProviders[config.Provider]; !ok { + return nil + } + } + activeIntermediate, err := provider.ActiveIntermediate() if err != nil { return err } if activeIntermediate == "" { - return fmt.Errorf("secondary datacenter doesn't have an active intermediate.") + return fmt.Errorf("datacenter doesn't have an active intermediate.") } intermediateCert, err := connect.ParseCert(activeIntermediate) @@ -698,10 +738,15 @@ func (s *Server) secondaryIntermediateCertRenewalWatch(ctx context.Context) erro if lessThanHalfTimePassed(time.Now(), intermediateCert.NotBefore.Add(ca.CertificateTimeDriftBuffer), intermediateCert.NotAfter) { + //connectLogger.Info("checked time passed", intermediateCert.NotBefore.Add(ca.CertificateTimeDriftBuffer), intermediateCert.NotAfter) return nil } - if err := s.getIntermediateCASigned(provider, activeRoot); err != nil { + renewalFunc := s.getIntermediateCAPrimary + if !isPrimary { + renewalFunc = s.getIntermediateCASigned + } + if err := renewalFunc(provider, activeRoot); err != nil { return err } @@ -713,7 +758,7 @@ func (s *Server) secondaryIntermediateCertRenewalWatch(ctx context.Context) erro return nil }, func(err error) { connectLogger.Error("error renewing intermediate certs", - "routine", secondaryCertRenewWatchRoutineName, + "routine", intermediateCertRenewWatchRoutineName, "error", err, ) }) diff --git a/agent/consul/leader_connect_test.go b/agent/consul/leader_connect_test.go index 1f22f27349..8690f3c994 100644 --- a/agent/consul/leader_connect_test.go +++ b/agent/consul/leader_connect_test.go @@ -13,7 +13,7 @@ import ( "time" "github.com/hashicorp/consul/agent/connect" - ca "github.com/hashicorp/consul/agent/connect/ca" + "github.com/hashicorp/consul/agent/connect/ca" "github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/agent/token" tokenStore "github.com/hashicorp/consul/agent/token" @@ -181,6 +181,117 @@ func getCAProviderWithLock(s *Server) (ca.Provider, *structs.CARoot) { return s.getCAProvider() } +func TestLeader_PrimaryCA_IntermediateRenew(t *testing.T) { + // no parallel execution because we change globals + origInterval := structs.IntermediateCertRenewInterval + origMinTTL := structs.MinLeafCertTTL + defer func() { + structs.IntermediateCertRenewInterval = origInterval + structs.MinLeafCertTTL = origMinTTL + }() + + structs.IntermediateCertRenewInterval = time.Millisecond + structs.MinLeafCertTTL = time.Second + require := require.New(t) + + testVault := ca.NewTestVaultServer(t) + defer testVault.Stop() + + dir1, s1 := testServerWithConfig(t, func(c *Config) { + c.Build = "1.6.0" + c.PrimaryDatacenter = "dc1" + c.CAConfig = &structs.CAConfiguration{ + Provider: "vault", + Config: map[string]interface{}{ + "Address": testVault.Addr, + "Token": testVault.RootToken, + "RootPKIPath": "pki-root/", + "IntermediatePKIPath": "pki-intermediate/", + "LeafCertTTL": "5s", + // The retry loop only retries for 7sec max and + // the ttl needs to be below so that it + // triggers definitely. + // Since certs are created so that they are + // valid from 1minute in the past, we need to + // account for that, otherwise it will be + // expired immediately. + "IntermediateCertTTL": "15s", + }, + } + }) + defer os.RemoveAll(dir1) + defer s1.Shutdown() + + testrpc.WaitForLeader(t, s1.RPC, "dc1") + + // Capture the current root + var originalRoot *structs.CARoot + { + rootList, activeRoot, err := getTestRoots(s1, "dc1") + require.NoError(err) + require.Len(rootList.Roots, 1) + originalRoot = activeRoot + } + + // Get the original intermediate + waitForActiveCARoot(t, s1, originalRoot) + provider, _ := getCAProviderWithLock(s1) + intermediatePEM, err := provider.ActiveIntermediate() + require.NoError(err) + cert, err := connect.ParseCert(intermediatePEM) + require.NoError(err) + + // Wait for dc1's intermediate to be refreshed. + // It is possible that test fails when the blocking query doesn't return. + retry.Run(t, func(r *retry.R) { + provider, _ = getCAProviderWithLock(s1) + newIntermediatePEM, err := provider.ActiveIntermediate() + r.Check(err) + _, err = connect.ParseCert(intermediatePEM) + r.Check(err) + if newIntermediatePEM == intermediatePEM { + r.Fatal("not a renewed intermediate") + } + intermediatePEM = newIntermediatePEM + }) + require.NoError(err) + + // Get the root from dc1 and validate a chain of: + // dc1 leaf -> dc1 intermediate -> dc1 root + provider, caRoot := getCAProviderWithLock(s1) + + // Have the new intermediate sign a leaf cert and make sure the chain is correct. + spiffeService := &connect.SpiffeIDService{ + Host: "node1", + Namespace: "default", + Datacenter: "dc1", + Service: "foo", + } + raw, _ := connect.TestCSR(t, spiffeService) + + leafCsr, err := connect.ParseCSR(raw) + require.NoError(err) + + leafPEM, err := provider.Sign(leafCsr) + require.NoError(err) + + cert, err = connect.ParseCert(leafPEM) + require.NoError(err) + + // Check that the leaf signed by the new intermediate can be verified using the + // returned cert chain (signed intermediate + remote root). + intermediatePool := x509.NewCertPool() + intermediatePool.AppendCertsFromPEM([]byte(intermediatePEM)) + rootPool := x509.NewCertPool() + rootPool.AppendCertsFromPEM([]byte(caRoot.RootCert)) + + _, err = cert.Verify(x509.VerifyOptions{ + Intermediates: intermediatePool, + Roots: rootPool, + }) + require.NoError(err) +} + func TestLeader_SecondaryCA_IntermediateRenew(t *testing.T) { // no parallel execution because we change globals origInterval := structs.IntermediateCertRenewInterval diff --git a/agent/consul/server.go b/agent/consul/server.go index a478d0c396..d7c299ad51 100644 --- a/agent/consul/server.go +++ b/agent/consul/server.go @@ -102,7 +102,7 @@ const ( federationStatePruningRoutineName = "federation state pruning" intentionReplicationRoutineName = "intention replication" secondaryCARootWatchRoutineName = "secondary CA roots watch" - secondaryCertRenewWatchRoutineName = "secondary cert renew watch" + intermediateCertRenewWatchRoutineName = "intermediate cert renew watch" ) var (