Clean up Vault renew tests and shutdown

pull/8560/head
Kyle Havlovitz 2020-09-11 08:41:05 -07:00
parent f40fb577fe
commit 49056fe70f
4 changed files with 49 additions and 53 deletions

View File

@ -2,13 +2,13 @@ package ca
import ( import (
"bytes" "bytes"
"context"
"crypto/x509" "crypto/x509"
"encoding/pem" "encoding/pem"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"strings" "strings"
"sync"
"github.com/hashicorp/consul/agent/connect" "github.com/hashicorp/consul/agent/connect"
"github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/agent/structs"
@ -27,9 +27,7 @@ type VaultProvider struct {
config *structs.VaultCAProviderConfig config *structs.VaultCAProviderConfig
client *vaultapi.Client client *vaultapi.Client
shutdown bool shutdown func()
shutdownCh chan struct{}
shutdownLock sync.RWMutex
isPrimary bool isPrimary bool
clusterID string clusterID string
@ -38,6 +36,10 @@ type VaultProvider struct {
logger hclog.Logger logger hclog.Logger
} }
func NewVaultProvider() *VaultProvider {
return &VaultProvider{shutdown: func() {}}
}
func vaultTLSConfig(config *structs.VaultCAProviderConfig) *vaultapi.TLSConfig { func vaultTLSConfig(config *structs.VaultCAProviderConfig) *vaultapi.TLSConfig {
return &vaultapi.TLSConfig{ return &vaultapi.TLSConfig{
CACert: config.CAFile, CACert: config.CAFile,
@ -74,7 +76,6 @@ func (v *VaultProvider) Configure(cfg ProviderConfig) error {
v.isPrimary = cfg.IsPrimary v.isPrimary = cfg.IsPrimary
v.clusterID = cfg.ClusterID v.clusterID = cfg.ClusterID
v.spiffeID = connect.SpiffeIDSigningForCluster(&structs.CAConfiguration{ClusterID: v.clusterID}) v.spiffeID = connect.SpiffeIDSigningForCluster(&structs.CAConfiguration{ClusterID: v.clusterID})
v.shutdownCh = make(chan struct{}, 0)
// Look up the token to see if we can auto-renew its lease. // Look up the token to see if we can auto-renew its lease.
secret, err := client.Auth().Token().Lookup(config.Token) secret, err := client.Auth().Token().Lookup(config.Token)
@ -99,25 +100,28 @@ func (v *VaultProvider) Configure(cfg ProviderConfig) error {
LeaseDuration: secret.LeaseDuration, LeaseDuration: secret.LeaseDuration,
}, },
}, },
Increment: int(token.TTL), Increment: token.TTL,
}) })
if err != nil { if err != nil {
return fmt.Errorf("Error beginning Vault provider token renewal: %v", err) return fmt.Errorf("Error beginning Vault provider token renewal: %v", err)
} }
go v.renewToken(renewer)
ctx, cancel := context.WithCancel(context.TODO())
v.shutdown = cancel
go v.renewToken(ctx, renewer)
} }
return nil return nil
} }
// renewToken uses a vaultapi.Renewer to repeatedly renew our token's lease. // renewToken uses a vaultapi.Renewer to repeatedly renew our token's lease.
func (v *VaultProvider) renewToken(renewer *vaultapi.Renewer) { func (v *VaultProvider) renewToken(ctx context.Context, renewer *vaultapi.Renewer) {
go renewer.Renew() go renewer.Renew()
defer renewer.Stop()
for { for {
select { select {
case <-v.shutdownCh: case <-ctx.Done():
renewer.Stop()
return return
case err := <-renewer.DoneCh(): case err := <-renewer.DoneCh():
@ -125,6 +129,9 @@ func (v *VaultProvider) renewToken(renewer *vaultapi.Renewer) {
v.logger.Error(fmt.Sprintf("Error renewing token for Vault provider: %v", err)) v.logger.Error(fmt.Sprintf("Error renewing token for Vault provider: %v", err))
} }
// Renewer routine has finished, so start it again.
go renewer.Renew()
case <-renewer.RenewCh(): case <-renewer.RenewCh():
v.logger.Error("Successfully renewed token for Vault provider") v.logger.Error("Successfully renewed token for Vault provider")
} }
@ -508,13 +515,7 @@ func (v *VaultProvider) Cleanup() error {
// Stop shuts down the token renew goroutine. // Stop shuts down the token renew goroutine.
func (v *VaultProvider) Stop() { func (v *VaultProvider) Stop() {
v.shutdownLock.Lock() v.shutdown()
defer v.shutdownLock.Unlock()
if !v.shutdown && v.shutdownCh != nil {
close(v.shutdownCh)
v.shutdown = true
}
} }
func ParseVaultCAConfig(raw map[string]interface{}) (*structs.VaultCAProviderConfig, error) { func ParseVaultCAConfig(raw map[string]interface{}) (*structs.VaultCAProviderConfig, error) {

View File

@ -55,14 +55,10 @@ func TestVaultCAProvider_SecondaryActiveIntermediate(t *testing.T) {
func TestVaultCAProvider_RenewToken(t *testing.T) { func TestVaultCAProvider_RenewToken(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
skipIfVaultNotPresent(t) skipIfVaultNotPresent(t)
testVault, err := runTestVault() testVault, err := runTestVault(t)
if err != nil { require.NoError(t, err)
t.Fatalf("err: %v", err)
}
testVault.WaitUntilReady(t) testVault.WaitUntilReady(t)
// Create a token with a short TTL to be renewed by the provider. // Create a token with a short TTL to be renewed by the provider.
@ -71,26 +67,26 @@ func TestVaultCAProvider_RenewToken(t *testing.T) {
TTL: ttl.String(), TTL: ttl.String(),
} }
secret, err := testVault.client.Auth().Token().Create(tcr) secret, err := testVault.client.Auth().Token().Create(tcr)
require.NoError(err) require.NoError(t, err)
providerToken := secret.Auth.ClientToken providerToken := secret.Auth.ClientToken
_, err = createVaultProvider(true, testVault.addr, providerToken, nil) _, err = createVaultProvider(t, true, testVault.addr, providerToken, nil)
require.NoError(err) require.NoError(t, err)
// Check the last renewal time. // Check the last renewal time.
secret, err = testVault.client.Auth().Token().Lookup(providerToken) secret, err = testVault.client.Auth().Token().Lookup(providerToken)
require.NoError(err) require.NoError(t, err)
firstRenewal, err := secret.Data["last_renewal_time"].(json.Number).Int64() firstRenewal, err := secret.Data["last_renewal_time"].(json.Number).Int64()
require.NoError(err) require.NoError(t, err)
time.Sleep(ttl * 2)
// Wait past the TTL and make sure the token has been renewed. // Wait past the TTL and make sure the token has been renewed.
retry.Run(t, func(r *retry.R) {
secret, err = testVault.client.Auth().Token().Lookup(providerToken) secret, err = testVault.client.Auth().Token().Lookup(providerToken)
require.NoError(err) require.NoError(r, err)
lastRenewal, err := secret.Data["last_renewal_time"].(json.Number).Int64() lastRenewal, err := secret.Data["last_renewal_time"].(json.Number).Int64()
require.NoError(err) require.NoError(r, err)
require.Greater(lastRenewal, firstRenewal) require.Greater(r, lastRenewal, firstRenewal)
})
} }
func TestVaultCAProvider_Bootstrap(t *testing.T) { func TestVaultCAProvider_Bootstrap(t *testing.T) {
@ -391,14 +387,14 @@ func testVaultProvider(t *testing.T) (*VaultProvider, *testVaultServer) {
} }
func testVaultProviderWithConfig(t *testing.T, isPrimary bool, rawConf map[string]interface{}) (*VaultProvider, *testVaultServer) { func testVaultProviderWithConfig(t *testing.T, isPrimary bool, rawConf map[string]interface{}) (*VaultProvider, *testVaultServer) {
testVault, err := runTestVault() testVault, err := runTestVault(t)
if err != nil { if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
testVault.WaitUntilReady(t) testVault.WaitUntilReady(t)
provider, err := createVaultProvider(isPrimary, testVault.addr, testVault.rootToken, rawConf) provider, err := createVaultProvider(t, isPrimary, testVault.addr, testVault.rootToken, rawConf)
if err != nil { if err != nil {
testVault.Stop() testVault.Stop()
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
@ -406,7 +402,7 @@ func testVaultProviderWithConfig(t *testing.T, isPrimary bool, rawConf map[strin
return provider, testVault return provider, testVault
} }
func createVaultProvider(isPrimary bool, addr, token string, rawConf map[string]interface{}) (*VaultProvider, error) { func createVaultProvider(t *testing.T, isPrimary bool, addr, token string, rawConf map[string]interface{}) (*VaultProvider, error) {
conf := map[string]interface{}{ conf := map[string]interface{}{
"Address": addr, "Address": addr,
"Token": token, "Token": token,
@ -419,7 +415,7 @@ func createVaultProvider(isPrimary bool, addr, token string, rawConf map[string]
conf[k] = v conf[k] = v
} }
provider := &VaultProvider{} provider := NewVaultProvider()
cfg := ProviderConfig{ cfg := ProviderConfig{
ClusterID: connect.TestClusterID, ClusterID: connect.TestClusterID,
@ -438,16 +434,11 @@ func createVaultProvider(isPrimary bool, addr, token string, rawConf map[string]
cfg.Datacenter = "dc2" cfg.Datacenter = "dc2"
} }
if err := provider.Configure(cfg); err != nil { require.NoError(t, provider.Configure(cfg))
return nil, err
}
if isPrimary { if isPrimary {
if err := provider.GenerateRoot(); err != nil { require.NoError(t, provider.GenerateRoot())
return nil, err _, err := provider.GenerateIntermediate()
} require.NoError(t, err)
if _, err := provider.GenerateIntermediate(); err != nil {
return nil, err
}
} }
return provider, nil return provider, nil
@ -469,7 +460,7 @@ func skipIfVaultNotPresent(t *testing.T) {
} }
} }
func runTestVault() (*testVaultServer, error) { func runTestVault(t *testing.T) (*testVaultServer, error) {
vaultBinaryName := os.Getenv("VAULT_BINARY_NAME") vaultBinaryName := os.Getenv("VAULT_BINARY_NAME")
if vaultBinaryName == "" { if vaultBinaryName == "" {
vaultBinaryName = "vault" vaultBinaryName = "vault"
@ -520,13 +511,17 @@ func runTestVault() (*testVaultServer, error) {
return nil, err return nil, err
} }
return &testVaultServer{ testVault := &testVaultServer{
rootToken: token, rootToken: token,
addr: "http://" + clientAddr, addr: "http://" + clientAddr,
cmd: cmd, cmd: cmd,
client: client, client: client,
returnPortsFn: returnPortsFn, returnPortsFn: returnPortsFn,
}, nil }
t.Cleanup(func() {
testVault.Stop()
})
return testVault, nil
} }
type testVaultServer struct { type testVaultServer struct {

View File

@ -158,7 +158,7 @@ func (s *ConnectCA) ConfigurationSet(
defer func() { defer func() {
if cleanupNewProvider { if cleanupNewProvider {
if err := newProvider.Cleanup(); err != nil { if err := newProvider.Cleanup(); err != nil {
s.logger.Warn("failed to clean up temporary new CA provider", "provider", newProvider) s.logger.Warn("failed to clean up CA provider while handling startup failure", "provider", newProvider, "error", err)
} }
} }
}() }()

View File

@ -116,7 +116,7 @@ func (s *Server) createCAProvider(conf *structs.CAConfiguration) (ca.Provider, e
case structs.ConsulCAProvider: case structs.ConsulCAProvider:
p = &ca.ConsulProvider{Delegate: &consulCADelegate{s}} p = &ca.ConsulProvider{Delegate: &consulCADelegate{s}}
case structs.VaultCAProvider: case structs.VaultCAProvider:
p = &ca.VaultProvider{} p = ca.NewVaultProvider()
case structs.AWSCAProvider: case structs.AWSCAProvider:
p = &ca.AWSProvider{} p = &ca.AWSProvider{}
default: default: