diff --git a/agent/connect/ca/provider_vault.go b/agent/connect/ca/provider_vault.go index 8415d07b9a..60627b5314 100644 --- a/agent/connect/ca/provider_vault.go +++ b/agent/connect/ca/provider_vault.go @@ -518,7 +518,7 @@ func (v *VaultProvider) ActiveLeafSigningCert() (string, error) { // because the endpoint only returns the raw PEM contents of the CA cert // and not the typical format of the secrets endpoints. func (v *VaultProvider) getCA(namespace, path string) (string, error) { - resp, err := v.client.WithNamespace(namespace).Logical().ReadRaw(path + "/ca/pem") + resp, err := v.client.WithNamespace(v.getNamespace(namespace)).Logical().ReadRaw(path + "/ca/pem") if resp != nil { defer resp.Body.Close() } @@ -544,7 +544,7 @@ func (v *VaultProvider) getCA(namespace, path string) (string, error) { // TODO: refactor to remove duplication with getCA func (v *VaultProvider) getCAChain(namespace, path string) (string, error) { - resp, err := v.client.WithNamespace(namespace).Logical().ReadRaw(path + "/ca_chain") + resp, err := v.client.WithNamespace(v.getNamespace(namespace)).Logical().ReadRaw(path + "/ca_chain") if resp != nil { defer resp.Body.Close() } @@ -851,27 +851,34 @@ func (v *VaultProvider) Stop() { // We use raw path here func (v *VaultProvider) mountNamespaced(namespace, path string, mountInfo *vaultapi.MountInput) error { - return v.client.WithNamespace(namespace).Sys().Mount(path, mountInfo) + return v.client.WithNamespace(v.getNamespace(namespace)).Sys().Mount(path, mountInfo) } func (v *VaultProvider) tuneMountNamespaced(namespace, path string, mountConfig *vaultapi.MountConfigInput) error { - return v.client.WithNamespace(namespace).Sys().TuneMount(path, *mountConfig) + return v.client.WithNamespace(v.getNamespace(namespace)).Sys().TuneMount(path, *mountConfig) } func (v *VaultProvider) unmountNamespaced(namespace, path string) error { - return v.client.WithNamespace(namespace).Sys().Unmount(path) + return v.client.WithNamespace(v.getNamespace(namespace)).Sys().Unmount(path) } func (v *VaultProvider) readNamespaced(namespace string, resource string) (*vaultapi.Secret, error) { - return v.client.WithNamespace(namespace).Logical().Read(resource) + return v.client.WithNamespace(v.getNamespace(namespace)).Logical().Read(resource) } func (v *VaultProvider) writeNamespaced(namespace string, resource string, data map[string]interface{}) (*vaultapi.Secret, error) { - return v.client.WithNamespace(namespace).Logical().Write(resource, data) + return v.client.WithNamespace(v.getNamespace(namespace)).Logical().Write(resource, data) } func (v *VaultProvider) deleteNamespaced(namespace string, resource string) (*vaultapi.Secret, error) { - return v.client.WithNamespace(namespace).Logical().Delete(resource) + return v.client.WithNamespace(v.getNamespace(namespace)).Logical().Delete(resource) +} + +func (v *VaultProvider) getNamespace(namespace string) string { + if namespace != "" { + return namespace + } + return v.baseNamespace } // autotidyIssuers sets Vault's auto-tidy to remove expired issuers diff --git a/agent/connect/ca/provider_vault_test.go b/agent/connect/ca/provider_vault_test.go index f7fb8d178a..ece7659d04 100644 --- a/agent/connect/ca/provider_vault_test.go +++ b/agent/connect/ca/provider_vault_test.go @@ -1415,6 +1415,85 @@ func TestVaultCAProvider_ConsulManaged(t *testing.T) { }) } +func TestVaultCAProvider_EnterpriseNamespace(t *testing.T) { + SkipIfVaultNotPresent(t, vaultRequirements{Enterprise: true}) + t.Parallel() + + cases := map[string]struct { + namespaces map[string]string + }{ + "no configured namespaces": {}, + "only base namespace provided": {namespaces: map[string]string{"Namespace": "base-ns"}}, + "only root namespace provided": {namespaces: map[string]string{"RootPKINamespace": "root-pki-ns"}}, + "only intermediate namespace provided": {namespaces: map[string]string{"IntermediatePKINamespace": "int-pki-ns"}}, + "base and root namespace provided": { + namespaces: map[string]string{ + "Namespace": "base-ns", + "RootPKINamespace": "root-pki-ns", + }, + }, + "base and intermediate namespace provided": { + namespaces: map[string]string{ + "Namespace": "base-ns", + "IntermediatePKINamespace": "int-pki-ns", + }, + }, + "root and intermediate namespace provided": { + namespaces: map[string]string{ + "RootPKINamespace": "root-pki-ns", + "IntermediatePKINamespace": "int-pki-ns", + }, + }, + "all namespaces provided": { + namespaces: map[string]string{ + "Namespace": "base-ns", + "RootPKINamespace": "root-pki-ns", + "IntermediatePKINamespace": "int-pki-ns", + }, + }, + } + + for name, c := range cases { + c := c + t.Run(name, func(t *testing.T) { + t.Parallel() + + testVault := NewTestVaultServer(t) + token := "root" + + providerConfig := map[string]any{ + "RootPKIPath": "pki-root/", + "IntermediatePKIPath": "pki-intermediate/", + } + for k, v := range c.namespaces { + providerConfig[k] = v + } + + if len(c.namespaces) > 0 { + // If explicit namespaces are provided, try to create the provider before any of the namespaces + // have been created. Verify that the provider fails to initialize. + provider, err := createVaultProviderE(t, true, testVault.Addr, token, providerConfig) + require.Error(t, err) + require.NotNil(t, provider) + } + + // Create the namespaces + client := testVault.Client() + client.SetToken(token) + + for _, ns := range c.namespaces { + _, err := client.Logical().Write(fmt.Sprintf("/sys/namespaces/%s", ns), map[string]any{}) + require.NoError(t, err) + } + + // Verify that once the namespaces have been created we are able to initialize the provider. + provider, err := createVaultProviderE(t, true, testVault.Addr, token, providerConfig) + require.NoError(t, err) + require.NotNil(t, provider) + }) + } +} + func getIntermediateCertTTL(t *testing.T, caConf *structs.CAConfiguration) time.Duration { t.Helper() @@ -1436,6 +1515,15 @@ func getIntermediateCertTTL(t *testing.T, caConf *structs.CAConfiguration) time. func createVaultProvider(t *testing.T, isPrimary bool, addr, token string, rawConf map[string]any) *VaultProvider { t.Helper() + + provider, err := createVaultProviderE(t, isPrimary, addr, token, rawConf) + require.NoError(t, err) + + return provider +} + +func createVaultProviderE(t *testing.T, isPrimary bool, addr, token string, rawConf map[string]any) (*VaultProvider, error) { + t.Helper() cfg := vaultProviderConfig(t, addr, token, rawConf) provider := NewVaultProvider(hclog.New(nil)) @@ -1446,15 +1534,19 @@ func createVaultProvider(t *testing.T, isPrimary bool, addr, token string, rawCo } t.Cleanup(provider.Stop) - require.NoError(t, provider.Configure(cfg)) + if err := provider.Configure(cfg); err != nil { + return provider, err + } if isPrimary { - _, err := provider.GenerateCAChain() - require.NoError(t, err) - _, err = provider.GenerateLeafSigningCert() - require.NoError(t, err) + if _, err := provider.GenerateCAChain(); err != nil { + return provider, err + } + if _, err := provider.GenerateLeafSigningCert(); err != nil { + return provider, err + } } - return provider + return provider, nil } func vaultProviderConfig(t *testing.T, addr, token string, rawConf map[string]any) ProviderConfig { diff --git a/agent/connect/ca/testing.go b/agent/connect/ca/testing.go index 28d077e34e..7386deea94 100644 --- a/agent/connect/ca/testing.go +++ b/agent/connect/ca/testing.go @@ -61,6 +61,10 @@ type CASigningKeyTypes struct { CSRKeyBits int } +type vaultRequirements struct { + Enterprise bool +} + // CASigningKeyTypeCases returns the cross-product of the important supported CA // key types for generating table tests for CA signing tests (CrossSignCA and // SignIntermediate). @@ -93,7 +97,7 @@ func TestConsulProvider(t testing.T, d ConsulProviderStateDelegate) *ConsulProvi // // These tests may be skipped in CI. They are run as part of a separate // integration test suite. -func SkipIfVaultNotPresent(t testing.T) { +func SkipIfVaultNotPresent(t testing.T, reqs ...vaultRequirements) { // Try to safeguard against tests that will never run in CI. // This substring should match the pattern used by the // test-connect-ca-providers CI job. @@ -110,6 +114,16 @@ func SkipIfVaultNotPresent(t testing.T) { if err != nil || path == "" { t.Skipf("%q not found on $PATH - download and install to run this test", vaultBinaryName) } + + // Check for any additional Vault requirements. + for _, r := range reqs { + if r.Enterprise { + ver := vaultVersion(t, vaultBinaryName) + if !strings.Contains(ver, "+ent") { + t.Skipf("%q is not a Vault Enterprise version", ver) + } + } + } } func NewTestVaultServer(t testing.T) *TestVaultServer { @@ -239,8 +253,8 @@ func requireTrailingNewline(t testing.T, leafPEM string) { if len(leafPEM) == 0 { t.Fatalf("cert is empty") } - if '\n' != rune(leafPEM[len(leafPEM)-1]) { - t.Fatalf("cert do not end with a new line") + if rune(leafPEM[len(leafPEM)-1]) != '\n' { + t.Fatalf("cert does not end with a new line") } } @@ -367,3 +381,10 @@ func createVaultTokenAndPolicy(t testing.T, client *vaultapi.Client, policyName, require.NoError(t, err) return tok.Auth.ClientToken } + +func vaultVersion(t testing.T, vaultBinaryName string) string { + cmd := exec.Command(vaultBinaryName, []string{"version"}...) + output, err := cmd.Output() + require.NoError(t, err) + return string(output[:len(output)-1]) +}