// Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: BUSL-1.1 package connect import ( "fmt" "testing" "time" "crypto/x509" "encoding/pem" "github.com/stretchr/testify/require" "github.com/hashicorp/consul/agent/structs" ) type KeyConfig struct { keyType string keyBits int } var goodParams = []KeyConfig{ {keyType: "rsa", keyBits: 2048}, {keyType: "rsa", keyBits: 4096}, {keyType: "ec", keyBits: 224}, {keyType: "ec", keyBits: 256}, {keyType: "ec", keyBits: 384}, {keyType: "ec", keyBits: 521}, } var badParams = []KeyConfig{ {keyType: "rsa", keyBits: 0}, {keyType: "rsa", keyBits: 1024}, {keyType: "rsa", keyBits: 24601}, {keyType: "ec", keyBits: 0}, {keyType: "ec", keyBits: 512}, {keyType: "ec", keyBits: 321}, {keyType: "ecdsa", keyBits: 256}, // test for "ecdsa" instead of "ec" {keyType: "aes", keyBits: 128}, } func makeConfig(kc KeyConfig) structs.CommonCAProviderConfig { return structs.CommonCAProviderConfig{ LeafCertTTL: 3 * 24 * time.Hour, IntermediateCertTTL: 365 * 24 * time.Hour, RootCertTTL: 10 * 365 * 24 * time.Hour, PrivateKeyType: kc.keyType, PrivateKeyBits: kc.keyBits, } } func testGenerateRSAKey(t *testing.T, bits int) { _, rsaBlock, err := GeneratePrivateKeyWithConfig("rsa", bits) require.NoError(t, err) require.Contains(t, rsaBlock, "RSA PRIVATE KEY") rsaBytes, _ := pem.Decode([]byte(rsaBlock)) require.NotNil(t, rsaBytes) rsaKey, err := x509.ParsePKCS1PrivateKey(rsaBytes.Bytes) require.NoError(t, err) require.NoError(t, rsaKey.Validate()) require.Equal(t, bits/8, rsaKey.Size()) // note: returned size is in bytes. 2048/8==256 } func testGenerateECDSAKey(t *testing.T, bits int) { _, pemBlock, err := GeneratePrivateKeyWithConfig("ec", bits) require.NoError(t, err) require.Contains(t, pemBlock, "EC PRIVATE KEY") block, _ := pem.Decode([]byte(pemBlock)) require.NotNil(t, block) pk, err := x509.ParseECPrivateKey(block.Bytes) require.NoError(t, err) require.Equal(t, bits, pk.Curve.Params().BitSize) } // Tests to make sure we are able to generate every type of private key supported by the x509 lib. func TestGenerateKeys(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") } t.Parallel() for _, params := range goodParams { t.Run(fmt.Sprintf("TestGenerateKeys-%s-%d", params.keyType, params.keyBits), func(t *testing.T) { switch params.keyType { case "rsa": testGenerateRSAKey(t, params.keyBits) case "ec": testGenerateECDSAKey(t, params.keyBits) default: t.Fatalf("unknown key type: %s", params.keyType) } }) } } // Tests a variety of valid private key configs to make sure they're accepted. func TestValidateGoodConfigs(t *testing.T) { t.Parallel() for _, params := range goodParams { config := makeConfig(params) t.Run(fmt.Sprintf("TestValidateGoodConfigs-%s-%d", params.keyType, params.keyBits), func(t *testing.T) { require.NoError(t, config.Validate(), "unexpected error: type=%s bits=%d", params.keyType, params.keyBits) }) } } // Tests a variety of invalid private key configs to make sure they're caught. func TestValidateBadConfigs(t *testing.T) { t.Parallel() for _, params := range badParams { config := makeConfig(params) t.Run(fmt.Sprintf("TestValidateBadConfigs-%s-%d", params.keyType, params.keyBits), func(t *testing.T) { require.Error(t, config.Validate(), "expected error: type=%s bits=%d", params.keyType, params.keyBits) }) } } // Tests the ability of a CA to sign a CSR using a different key type. This is // allowed by TLS 1.2 and should succeed in all combinations. func TestSignatureMismatches(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") } t.Parallel() for _, p1 := range goodParams { for _, p2 := range goodParams { if p1 == p2 { continue } t.Run(fmt.Sprintf("TestMismatches-%s%d-%s%d", p1.keyType, p1.keyBits, p2.keyType, p2.keyBits), func(t *testing.T) { ca := TestCAWithKeyType(t, nil, p1.keyType, p1.keyBits) require.Equal(t, p1.keyType, ca.PrivateKeyType) require.Equal(t, p1.keyBits, ca.PrivateKeyBits) certPEM, keyPEM, err := testLeaf(t, "foobar.service.consul", "default", ca, p2.keyType, p2.keyBits) require.NoError(t, err) _, err = ParseCert(certPEM) require.NoError(t, err) _, err = ParseSigner(keyPEM) require.NoError(t, err) }) } } }