consul/agent/connect/ca/provider_vault_test.go

487 lines
12 KiB
Go

package ca
import (
"crypto/x509"
"fmt"
"io/ioutil"
"os"
"os/exec"
"sync"
"testing"
"time"
"github.com/hashicorp/consul/agent/connect"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/sdk/freeport"
"github.com/hashicorp/consul/sdk/testutil/retry"
vaultapi "github.com/hashicorp/vault/api"
"github.com/stretchr/testify/require"
)
func TestVaultCAProvider_VaultTLSConfig(t *testing.T) {
config := &structs.VaultCAProviderConfig{
CAFile: "/capath/ca.pem",
CAPath: "/capath/",
CertFile: "/certpath/cert.pem",
KeyFile: "/certpath/key.pem",
TLSServerName: "server.name",
TLSSkipVerify: true,
}
tlsConfig := vaultTLSConfig(config)
require := require.New(t)
require.Equal(config.CAFile, tlsConfig.CACert)
require.Equal(config.CAPath, tlsConfig.CAPath)
require.Equal(config.CertFile, tlsConfig.ClientCert)
require.Equal(config.KeyFile, tlsConfig.ClientKey)
require.Equal(config.TLSServerName, tlsConfig.TLSServerName)
require.Equal(config.TLSSkipVerify, tlsConfig.Insecure)
}
func TestVaultCAProvider_Bootstrap(t *testing.T) {
t.Parallel()
if skipIfVaultNotPresent(t) {
return
}
provider, testVault := testVaultProvider(t)
defer testVault.Stop()
client := testVault.client
require := require.New(t)
cases := []struct {
certFunc func() (string, error)
backendPath string
}{
{
certFunc: provider.ActiveRoot,
backendPath: "pki-root/",
},
{
certFunc: provider.ActiveIntermediate,
backendPath: "pki-intermediate/",
},
}
// Verify the root and intermediate certs match the ones in the vault backends
for _, tc := range cases {
cert, err := tc.certFunc()
require.NoError(err)
req := client.NewRequest("GET", "/v1/"+tc.backendPath+"ca/pem")
resp, err := client.RawRequest(req)
require.NoError(err)
bytes, err := ioutil.ReadAll(resp.Body)
require.NoError(err)
require.Equal(cert, string(bytes))
// Should be a valid CA cert
parsed, err := connect.ParseCert(cert)
require.NoError(err)
require.True(parsed.IsCA)
require.Len(parsed.URIs, 1)
require.Equal(parsed.URIs[0].String(), fmt.Sprintf("spiffe://%s.consul", provider.clusterId))
}
}
func assertCorrectKeyType(t *testing.T, want, certPEM string) {
t.Helper()
cert, err := connect.ParseCert(certPEM)
require.NoError(t, err)
switch want {
case "ec":
require.Equal(t, x509.ECDSA, cert.PublicKeyAlgorithm)
case "rsa":
require.Equal(t, x509.RSA, cert.PublicKeyAlgorithm)
default:
t.Fatal("test doesn't support key type")
}
}
func TestVaultCAProvider_SignLeaf(t *testing.T) {
t.Parallel()
if skipIfVaultNotPresent(t) {
return
}
for _, tc := range KeyTestCases {
tc := tc
t.Run(tc.Desc, func(t *testing.T) {
require := require.New(t)
provider, testVault := testVaultProviderWithConfig(t, true, map[string]interface{}{
"LeafCertTTL": "1h",
"PrivateKeyType": tc.KeyType,
"PrivateKeyBits": tc.KeyBits,
})
defer testVault.Stop()
spiffeService := &connect.SpiffeIDService{
Host: "node1",
Namespace: "default",
Datacenter: "dc1",
Service: "foo",
}
rootPEM, err := provider.ActiveRoot()
require.NoError(err)
assertCorrectKeyType(t, tc.KeyType, rootPEM)
intPEM, err := provider.ActiveIntermediate()
require.NoError(err)
assertCorrectKeyType(t, tc.KeyType, intPEM)
// Generate a leaf cert for the service.
var firstSerial uint64
{
raw, _ := connect.TestCSR(t, spiffeService)
csr, err := connect.ParseCSR(raw)
require.NoError(err)
cert, err := provider.Sign(csr)
require.NoError(err)
parsed, err := connect.ParseCert(cert)
require.NoError(err)
require.Equal(parsed.URIs[0], spiffeService.URI())
firstSerial = parsed.SerialNumber.Uint64()
// Ensure the cert is valid now and expires within the correct limit.
now := time.Now()
require.True(parsed.NotAfter.Sub(now) < time.Hour)
require.True(parsed.NotBefore.Before(now))
// Make sure we can validate the cert as expected.
require.NoError(connect.ValidateLeaf(rootPEM, cert, []string{intPEM}))
}
// Generate a new cert for another service and make sure
// the serial number is unique.
spiffeService.Service = "bar"
{
raw, _ := connect.TestCSR(t, spiffeService)
csr, err := connect.ParseCSR(raw)
require.NoError(err)
cert, err := provider.Sign(csr)
require.NoError(err)
parsed, err := connect.ParseCert(cert)
require.NoError(err)
require.Equal(parsed.URIs[0], spiffeService.URI())
require.NotEqual(firstSerial, parsed.SerialNumber.Uint64())
// Ensure the cert is valid now and expires within the correct limit.
require.True(time.Until(parsed.NotAfter) < time.Hour)
require.True(parsed.NotBefore.Before(time.Now()))
// Make sure we can validate the cert as expected.
require.NoError(connect.ValidateLeaf(rootPEM, cert, []string{intPEM}))
}
})
}
}
func TestVaultCAProvider_CrossSignCA(t *testing.T) {
t.Parallel()
if skipIfVaultNotPresent(t) {
return
}
tests := CASigningKeyTypeCases()
for _, tc := range tests {
tc := tc
t.Run(tc.Desc, func(t *testing.T) {
require := require.New(t)
if tc.SigningKeyType != tc.CSRKeyType {
// See https://github.com/hashicorp/vault/issues/7709
t.Skip("Vault doesn't support cross-signing different key types yet.")
}
provider1, testVault1 := testVaultProviderWithConfig(t, true, map[string]interface{}{
"LeafCertTTL": "1h",
"PrivateKeyType": tc.SigningKeyType,
"PrivateKeyBits": tc.SigningKeyBits,
})
defer testVault1.Stop()
{
rootPEM, err := provider1.ActiveRoot()
require.NoError(err)
assertCorrectKeyType(t, tc.SigningKeyType, rootPEM)
intPEM, err := provider1.ActiveIntermediate()
require.NoError(err)
assertCorrectKeyType(t, tc.SigningKeyType, intPEM)
}
provider2, testVault2 := testVaultProviderWithConfig(t, true, map[string]interface{}{
"LeafCertTTL": "1h",
"PrivateKeyType": tc.CSRKeyType,
"PrivateKeyBits": tc.CSRKeyBits,
})
defer testVault2.Stop()
{
rootPEM, err := provider2.ActiveRoot()
require.NoError(err)
assertCorrectKeyType(t, tc.CSRKeyType, rootPEM)
intPEM, err := provider2.ActiveIntermediate()
require.NoError(err)
assertCorrectKeyType(t, tc.CSRKeyType, intPEM)
}
testCrossSignProviders(t, provider1, provider2)
})
}
}
func TestVaultProvider_SignIntermediate(t *testing.T) {
t.Parallel()
if skipIfVaultNotPresent(t) {
return
}
tests := CASigningKeyTypeCases()
for _, tc := range tests {
tc := tc
t.Run(tc.Desc, func(t *testing.T) {
provider1, testVault1 := testVaultProviderWithConfig(t, true, map[string]interface{}{
"LeafCertTTL": "1h",
"PrivateKeyType": tc.SigningKeyType,
"PrivateKeyBits": tc.SigningKeyBits,
})
defer testVault1.Stop()
provider2, testVault2 := testVaultProviderWithConfig(t, false, map[string]interface{}{
"LeafCertTTL": "1h",
"PrivateKeyType": tc.CSRKeyType,
"PrivateKeyBits": tc.CSRKeyBits,
})
defer testVault2.Stop()
testSignIntermediateCrossDC(t, provider1, provider2)
})
}
}
func TestVaultProvider_SignIntermediateConsul(t *testing.T) {
t.Parallel()
if skipIfVaultNotPresent(t) {
return
}
// primary = Vault, secondary = Consul
t.Run("pri=vault,sec=consul", func(t *testing.T) {
provider1, testVault1 := testVaultProviderWithConfig(t, true, nil)
defer testVault1.Stop()
conf := testConsulCAConfig()
delegate := newMockDelegate(t, conf)
provider2 := &ConsulProvider{Delegate: delegate}
require.NoError(t, provider2.Configure(conf.ClusterID, false, conf.Config))
testSignIntermediateCrossDC(t, provider1, provider2)
})
// primary = Consul, secondary = Vault
t.Run("pri=consul,sec=vault", func(t *testing.T) {
conf := testConsulCAConfig()
delegate := newMockDelegate(t, conf)
provider1 := &ConsulProvider{Delegate: delegate}
require.NoError(t, provider1.Configure(conf.ClusterID, true, conf.Config))
require.NoError(t, provider1.GenerateRoot())
provider2, testVault2 := testVaultProviderWithConfig(t, false, nil)
defer testVault2.Stop()
testSignIntermediateCrossDC(t, provider1, provider2)
})
}
func testVaultProvider(t *testing.T) (*VaultProvider, *testVaultServer) {
return testVaultProviderWithConfig(t, true, nil)
}
func testVaultProviderWithConfig(t *testing.T, isRoot bool, rawConf map[string]interface{}) (*VaultProvider, *testVaultServer) {
testVault, err := runTestVault()
if err != nil {
t.Fatalf("err: %v", err)
}
testVault.WaitUntilReady(t)
conf := map[string]interface{}{
"Address": testVault.addr,
"Token": testVault.rootToken,
"RootPKIPath": "pki-root/",
"IntermediatePKIPath": "pki-intermediate/",
// Tests duration parsing after msgpack type mangling during raft apply.
"LeafCertTTL": []uint8("72h"),
}
for k, v := range rawConf {
conf[k] = v
}
provider := &VaultProvider{}
if err := provider.Configure(connect.TestClusterID, isRoot, conf); err != nil {
testVault.Stop()
t.Fatalf("err: %v", err)
}
if isRoot {
if err = provider.GenerateRoot(); err != nil {
testVault.Stop()
t.Fatalf("err: %v", err)
}
if _, err := provider.GenerateIntermediate(); err != nil {
testVault.Stop()
t.Fatalf("err: %v", err)
}
}
return provider, testVault
}
var printedVaultVersion sync.Once
// skipIfVaultNotPresent skips the test and returns true if vault is not found
func skipIfVaultNotPresent(t *testing.T) bool {
vaultBinaryName := os.Getenv("VAULT_BINARY_NAME")
if vaultBinaryName == "" {
vaultBinaryName = "vault"
}
path, err := exec.LookPath(vaultBinaryName)
if err != nil || path == "" {
t.Skipf("%q not found on $PATH - download and install to run this test", vaultBinaryName)
return true
}
return false
}
func runTestVault() (*testVaultServer, error) {
vaultBinaryName := os.Getenv("VAULT_BINARY_NAME")
if vaultBinaryName == "" {
vaultBinaryName = "vault"
}
path, err := exec.LookPath(vaultBinaryName)
if err != nil || path == "" {
return nil, fmt.Errorf("%q not found on $PATH", vaultBinaryName)
}
ports := freeport.MustTake(2)
returnPortsFn := func() {
freeport.Return(ports)
}
var (
clientAddr = fmt.Sprintf("127.0.0.1:%d", ports[0])
clusterAddr = fmt.Sprintf("127.0.0.1:%d", ports[1])
)
const token = "root"
client, err := vaultapi.NewClient(&vaultapi.Config{
Address: "http://" + clientAddr,
})
if err != nil {
returnPortsFn()
return nil, err
}
client.SetToken(token)
args := []string{
"server",
"-dev",
"-dev-root-token-id",
token,
"-dev-listen-address",
clientAddr,
"-address",
clusterAddr,
}
cmd := exec.Command(vaultBinaryName, args...)
cmd.Stdout = ioutil.Discard
cmd.Stderr = ioutil.Discard
if err := cmd.Start(); err != nil {
returnPortsFn()
return nil, err
}
return &testVaultServer{
rootToken: token,
addr: "http://" + clientAddr,
cmd: cmd,
client: client,
returnPortsFn: returnPortsFn,
}, nil
}
type testVaultServer struct {
rootToken string
addr string
cmd *exec.Cmd
client *vaultapi.Client
// returnPortsFn will put the ports claimed for the test back into the
returnPortsFn func()
}
func (v *testVaultServer) WaitUntilReady(t *testing.T) {
var version string
retry.Run(t, func(r *retry.R) {
resp, err := v.client.Sys().Health()
if err != nil {
r.Fatalf("err: %v", err)
}
if !resp.Initialized {
r.Fatalf("vault server is not initialized")
}
if resp.Sealed {
r.Fatalf("vault server is sealed")
}
version = resp.Version
})
printedVaultVersion.Do(func() {
fmt.Fprintf(os.Stderr, "[INFO] agent/connect/ca: testing with vault server version: %s\n", version)
})
}
func (v *testVaultServer) Stop() error {
// There was no process
if v.cmd == nil {
return nil
}
if v.cmd.Process != nil {
if err := v.cmd.Process.Signal(os.Interrupt); err != nil {
return fmt.Errorf("failed to kill vault server: %v", err)
}
}
// wait for the process to exit to be sure that the data dir can be
// deleted on all platforms.
if err := v.cmd.Wait(); err != nil {
return err
}
if v.returnPortsFn != nil {
v.returnPortsFn()
}
return nil
}