diff --git a/tlsutil/config.go b/tlsutil/config.go index 69b358fed0..6fcdb1a2e9 100644 --- a/tlsutil/config.go +++ b/tlsutil/config.go @@ -394,19 +394,7 @@ func validateConfig(config Config, pool *x509.CertPool, cert *tls.Certificate) e } func (c Config) anyVerifyIncoming() bool { - return c.baseVerifyIncoming() || c.VerifyIncomingRPC || c.VerifyIncomingHTTPS -} - -func (c Config) verifyIncomingRPC() bool { - return c.baseVerifyIncoming() || c.VerifyIncomingRPC -} - -func (c Config) verifyIncomingHTTPS() bool { - return c.baseVerifyIncoming() || c.VerifyIncomingHTTPS -} - -func (c *Config) baseVerifyIncoming() bool { - return c.VerifyIncoming + return c.VerifyIncoming || c.VerifyIncomingRPC || c.VerifyIncomingHTTPS } func loadKeyPair(certFile, keyFile string) (*tls.Certificate, error) { @@ -540,37 +528,26 @@ func (c *Configurator) Cert() *tls.Certificate { return cert } -// This function acquires a read lock because it reads from the config. +// VerifyIncomingRPC returns true if the configuration has enabled either +// VerifyIncoming, or VerifyIncomingRPC func (c *Configurator) VerifyIncomingRPC() bool { c.lock.RLock() defer c.lock.RUnlock() - return c.base.verifyIncomingRPC() + return c.base.VerifyIncoming || c.base.VerifyIncomingRPC } // This function acquires a read lock because it reads from the config. -func (c *Configurator) outgoingRPCTLSDisabled() bool { +func (c *Configurator) outgoingRPCTLSEnabled() bool { c.lock.RLock() defer c.lock.RUnlock() - // if AutoEncrypt enabled, always use TLS - if c.base.AutoTLS { - return false - } - - // if CAs are provided or VerifyOutgoing is set, use TLS - if c.base.VerifyOutgoing { - return false - } - - return true + // use TLS if AutoEncrypt or VerifyOutgoing are enabled. + return c.base.AutoTLS || c.base.VerifyOutgoing } +// MutualTLSCapable returns true if Configurator has a CA and a local TLS +// certificate configured. func (c *Configurator) MutualTLSCapable() bool { - return c.mutualTLSCapable() -} - -// This function acquires a read lock because it reads from the config. -func (c *Configurator) mutualTLSCapable() bool { c.lock.RLock() defer c.lock.RUnlock() return c.caPool != nil && (c.autoTLS.cert != nil || c.manual.cert != nil) @@ -608,27 +585,6 @@ func (c *Configurator) domain() string { return c.base.Domain } -// This function acquires a read lock because it reads from the config. -func (c *Configurator) verifyIncomingRPC() bool { - c.lock.RLock() - defer c.lock.RUnlock() - return c.base.verifyIncomingRPC() -} - -// This function acquires a read lock because it reads from the config. -func (c *Configurator) verifyIncomingHTTPS() bool { - c.lock.RLock() - defer c.lock.RUnlock() - return c.base.verifyIncomingHTTPS() -} - -// This function acquires a read lock because it reads from the config. -func (c *Configurator) enableAgentTLSForChecks() bool { - c.lock.RLock() - defer c.lock.RUnlock() - return c.base.EnableAgentTLSForChecks -} - // This function acquires a read lock because it reads from the config. func (c *Configurator) serverNameOrNodeName() string { c.lock.RLock() @@ -665,7 +621,7 @@ func (c *Configurator) IncomingGRPCConfig() *tls.Config { // IncomingRPCConfig generates a *tls.Config for incoming RPC connections. func (c *Configurator) IncomingRPCConfig() *tls.Config { c.log("IncomingRPCConfig") - config := c.commonTLSConfig(c.verifyIncomingRPC()) + config := c.commonTLSConfig(c.VerifyIncomingRPC()) config.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) { return c.IncomingRPCConfig(), nil } @@ -705,7 +661,12 @@ func (c *Configurator) IncomingInsecureRPCConfig() *tls.Config { // IncomingHTTPSConfig generates a *tls.Config for incoming HTTPS connections. func (c *Configurator) IncomingHTTPSConfig() *tls.Config { c.log("IncomingHTTPSConfig") - config := c.commonTLSConfig(c.verifyIncomingHTTPS()) + + c.lock.RLock() + verifyIncoming := c.base.VerifyIncoming || c.base.VerifyIncomingHTTPS + c.lock.RUnlock() + + config := c.commonTLSConfig(verifyIncoming) config.NextProtos = []string{"h2", "http/1.1"} config.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) { return c.IncomingHTTPSConfig(), nil @@ -720,7 +681,11 @@ func (c *Configurator) IncomingHTTPSConfig() *tls.Config { func (c *Configurator) OutgoingTLSConfigForCheck(skipVerify bool, serverName string) *tls.Config { c.log("OutgoingTLSConfigForCheck") - if !c.enableAgentTLSForChecks() { + c.lock.RLock() + useAgentTLS := c.base.EnableAgentTLSForChecks + c.lock.RUnlock() + + if !useAgentTLS { return &tls.Config{ InsecureSkipVerify: skipVerify, ServerName: serverName, @@ -742,20 +707,20 @@ func (c *Configurator) OutgoingTLSConfigForCheck(skipVerify bool, serverName str // otherwise we assume that no TLS should be used. func (c *Configurator) OutgoingRPCConfig() *tls.Config { c.log("OutgoingRPCConfig") - if c.outgoingRPCTLSDisabled() { + if !c.outgoingRPCTLSEnabled() { return nil } return c.commonTLSConfig(false) } -// OutgoingALPNRPCConfig generates a *tls.Config for outgoing RPC connections +// outgoingALPNRPCConfig generates a *tls.Config for outgoing RPC connections // directly using TLS with ALPN instead of the older byte-prefixed protocol. // If there is a CA or VerifyOutgoing is set, a *tls.Config will be provided, // otherwise we assume that no TLS should be used which completely disables the // ALPN variation. -func (c *Configurator) OutgoingALPNRPCConfig() *tls.Config { - c.log("OutgoingALPNRPCConfig") - if !c.mutualTLSCapable() { +func (c *Configurator) outgoingALPNRPCConfig() *tls.Config { + c.log("outgoingALPNRPCConfig") + if !c.MutualTLSCapable() { return nil // ultimately this will hard-fail as TLS is required } @@ -780,15 +745,17 @@ func (c *Configurator) OutgoingRPCWrapper() DCWrapper { } } +// UseTLS returns true if the outgoing RPC requests have been explicitly configured +// to use TLS (via VerifyOutgoing or AutoTLS, and the target DC supports TLS. func (c *Configurator) UseTLS(dc string) bool { - return !c.outgoingRPCTLSDisabled() && c.getAreaForPeerDatacenterUseTLS(dc) + return c.outgoingRPCTLSEnabled() && c.getAreaForPeerDatacenterUseTLS(dc) } -// OutgoingALPNRPCWrapper wraps the result of OutgoingALPNRPCConfig in an +// OutgoingALPNRPCWrapper wraps the result of outgoingALPNRPCConfig in an // ALPNWrapper. It configures all of the negotiation plumbing. func (c *Configurator) OutgoingALPNRPCWrapper() ALPNWrapper { c.log("OutgoingALPNRPCWrapper") - if !c.mutualTLSCapable() { + if !c.MutualTLSCapable() { return nil } @@ -893,7 +860,7 @@ func (c *Configurator) wrapALPNTLSClient(dc, nodeName, alpnProto string, conn ne return nil, fmt.Errorf("cannot dial using ALPN-RPC without a target alpn protocol") } - config := c.OutgoingALPNRPCConfig() + config := c.outgoingALPNRPCConfig() if config == nil { return nil, fmt.Errorf("cannot dial via a mesh gateway when outgoing TLS is disabled") } diff --git a/tlsutil/config_test.go b/tlsutil/config_test.go index 42116c985c..d0b8b9d2b6 100644 --- a/tlsutil/config_test.go +++ b/tlsutil/config_test.go @@ -741,22 +741,21 @@ func TestConfigurator_OutgoingRPCTLSDisabled(t *testing.T) { expected bool } variants := []variant{ - {false, false, nil, true}, - {true, false, nil, false}, - {false, true, nil, false}, - {true, true, nil, false}, + {false, false, nil, false}, + {true, false, nil, true}, + {false, true, nil, true}, + {true, true, nil, true}, - // {false, false, &x509.CertPool{}, false}, - {true, false, &x509.CertPool{}, false}, - {false, true, &x509.CertPool{}, false}, - {true, true, &x509.CertPool{}, false}, + {true, false, &x509.CertPool{}, true}, + {false, true, &x509.CertPool{}, true}, + {true, true, &x509.CertPool{}, true}, } for i, v := range variants { info := fmt.Sprintf("case %d", i) c.caPool = v.pool c.base.VerifyOutgoing = v.verify c.base.AutoTLS = v.autoEncryptTLS - require.Equal(t, v.expected, c.outgoingRPCTLSDisabled(), info) + require.Equal(t, v.expected, c.outgoingRPCTLSEnabled(), info) } } @@ -768,7 +767,7 @@ func TestConfigurator_MutualTLSCapable(t *testing.T) { c, err := NewConfigurator(config, nil) require.NoError(t, err) - require.False(t, c.mutualTLSCapable()) + require.False(t, c.MutualTLSCapable()) }) t.Run("ca and no keys", func(t *testing.T) { @@ -779,7 +778,7 @@ func TestConfigurator_MutualTLSCapable(t *testing.T) { c, err := NewConfigurator(config, nil) require.NoError(t, err) - require.False(t, c.mutualTLSCapable()) + require.False(t, c.MutualTLSCapable()) }) t.Run("ca and manual key", func(t *testing.T) { @@ -792,7 +791,7 @@ func TestConfigurator_MutualTLSCapable(t *testing.T) { c, err := NewConfigurator(config, nil) require.NoError(t, err) - require.True(t, c.mutualTLSCapable()) + require.True(t, c.MutualTLSCapable()) }) loadFile := func(t *testing.T, path string) string { @@ -811,7 +810,7 @@ func TestConfigurator_MutualTLSCapable(t *testing.T) { caPEM := loadFile(t, "../test/hostname/CertAuth.crt") require.NoError(t, c.UpdateAutoTLSCA([]string{caPEM})) - require.False(t, c.mutualTLSCapable()) + require.False(t, c.MutualTLSCapable()) }) t.Run("autoencrypt ca and autoencrypt key", func(t *testing.T) { @@ -827,7 +826,7 @@ func TestConfigurator_MutualTLSCapable(t *testing.T) { require.NoError(t, c.UpdateAutoTLSCA([]string{caPEM})) require.NoError(t, c.UpdateAutoTLSCert(certPEM, keyPEM)) - require.True(t, c.mutualTLSCapable()) + require.True(t, c.MutualTLSCapable()) }) } @@ -846,26 +845,10 @@ func TestConfigurator_VerifyIncomingRPC(t *testing.T) { c := Configurator{base: &Config{ VerifyIncomingRPC: true, }} - verify := c.verifyIncomingRPC() + verify := c.VerifyIncomingRPC() require.Equal(t, c.base.VerifyIncomingRPC, verify) } -func TestConfigurator_VerifyIncomingHTTPS(t *testing.T) { - c := Configurator{base: &Config{ - VerifyIncomingHTTPS: true, - }} - verify := c.verifyIncomingHTTPS() - require.Equal(t, c.base.VerifyIncomingHTTPS, verify) -} - -func TestConfigurator_EnableAgentTLSForChecks(t *testing.T) { - c := Configurator{base: &Config{ - EnableAgentTLSForChecks: true, - }} - enabled := c.enableAgentTLSForChecks() - require.Equal(t, c.base.EnableAgentTLSForChecks, enabled) -} - func TestConfigurator_IncomingRPCConfig(t *testing.T) { c, err := NewConfigurator(Config{ VerifyIncomingRPC: true, @@ -911,8 +894,52 @@ func TestConfigurator_IncomingALPNRPCConfig(t *testing.T) { } func TestConfigurator_IncomingHTTPSConfig(t *testing.T) { - c := Configurator{base: &Config{}} - require.Equal(t, []string{"h2", "http/1.1"}, c.IncomingHTTPSConfig().NextProtos) + + // compare tls.Config.GetConfigForClient by nil/not-nil, since Go can not compare + // functions any other way. + cmpClientFunc := cmp.Comparer(func(x, y func(*tls.ClientHelloInfo) (*tls.Config, error)) bool { + return (x == nil && y == nil) || (x != nil && y != nil) + }) + + t.Run("default", func(t *testing.T) { + c, err := NewConfigurator(Config{}, nil) + require.NoError(t, err) + + cfg := c.IncomingHTTPSConfig() + + expected := &tls.Config{ + NextProtos: []string{"h2", "http/1.1"}, + InsecureSkipVerify: true, + MinVersion: tls.VersionTLS10, + GetConfigForClient: func(info *tls.ClientHelloInfo) (*tls.Config, error) { + return nil, nil + }, + } + assertDeepEqual(t, expected, cfg, cmpTLSConfig, cmpClientFunc) + }) + + t.Run("verify incoming", func(t *testing.T) { + c := Configurator{base: &Config{VerifyIncoming: true}} + + cfg := c.IncomingHTTPSConfig() + + expected := &tls.Config{ + NextProtos: []string{"h2", "http/1.1"}, + InsecureSkipVerify: true, + MinVersion: tls.VersionTLS10, + GetConfigForClient: func(info *tls.ClientHelloInfo) (*tls.Config, error) { + return nil, nil + }, + ClientAuth: tls.RequireAndVerifyClientCert, + } + assertDeepEqual(t, expected, cfg, cmpTLSConfig, cmpClientFunc) + }) + +} + +var cmpTLSConfig = cmp.Options{ + cmpopts.IgnoreFields(tls.Config{}, "GetCertificate", "GetClientCertificate"), + cmpopts.IgnoreUnexported(tls.Config{}), } func TestConfigurator_OutgoingTLSConfigForCheck(t *testing.T) { @@ -924,11 +951,6 @@ func TestConfigurator_OutgoingTLSConfigForCheck(t *testing.T) { expected *tls.Config } - cmpTLSConfig := cmp.Options{ - cmpopts.IgnoreFields(tls.Config{}, "GetCertificate", "GetClientCertificate"), - cmpopts.IgnoreUnexported(tls.Config{}), - } - run := func(t *testing.T, tc testCase) { configurator, err := tc.conf() require.NoError(t, err) @@ -1068,7 +1090,7 @@ func TestConfigurator_OutgoingRPCConfig(t *testing.T) { func TestConfigurator_OutgoingALPNRPCConfig(t *testing.T) { c := &Configurator{base: &Config{}} - require.Nil(t, c.OutgoingALPNRPCConfig()) + require.Nil(t, c.outgoingALPNRPCConfig()) c, err := NewConfigurator(Config{ VerifyOutgoing: false, // ignored, assumed true @@ -1078,7 +1100,7 @@ func TestConfigurator_OutgoingALPNRPCConfig(t *testing.T) { }, nil) require.NoError(t, err) - tlsConf := c.OutgoingALPNRPCConfig() + tlsConf := c.outgoingALPNRPCConfig() require.NotNil(t, tlsConf) require.Equal(t, tls.RequireAndVerifyClientCert, tlsConf.ClientAuth) require.False(t, tlsConf.InsecureSkipVerify)