diff --git a/.changelog/9512.txt b/.changelog/9512.txt new file mode 100644 index 0000000000..b1a8bdebd5 --- /dev/null +++ b/.changelog/9512.txt @@ -0,0 +1,3 @@ +```release-note:bug +client: properly set GRPC over RPC magic numbers when encryption was not set or partially set in the cluster with streaming enabled +``` diff --git a/agent/grpc/client.go b/agent/grpc/client.go index 7eb80070f3..9fdfa54e62 100644 --- a/agent/grpc/client.go +++ b/agent/grpc/client.go @@ -35,9 +35,10 @@ type TLSWrapper func(dc string, conn net.Conn) (net.Conn, error) type dialer func(context.Context, string) (net.Conn, error) -func NewClientConnPool(servers ServerLocator, tls TLSWrapper) *ClientConnPool { +// NewClientConnPool create new GRPC client pool to connect to servers using GRPC over RPC +func NewClientConnPool(servers ServerLocator, tls TLSWrapper, useTLSForDC func(dc string) bool) *ClientConnPool { return &ClientConnPool{ - dialer: newDialer(servers, tls), + dialer: newDialer(servers, tls, useTLSForDC), servers: servers, conns: make(map[string]*grpc.ClientConn), } @@ -74,7 +75,7 @@ func (c *ClientConnPool) ClientConn(datacenter string) (*grpc.ClientConn, error) // newDialer returns a gRPC dialer function that conditionally wraps the connection // with TLS based on the Server.useTLS value. -func newDialer(servers ServerLocator, wrapper TLSWrapper) func(context.Context, string) (net.Conn, error) { +func newDialer(servers ServerLocator, wrapper TLSWrapper, useTLSForDC func(dc string) bool) func(context.Context, string) (net.Conn, error) { return func(ctx context.Context, addr string) (net.Conn, error) { d := net.Dialer{} conn, err := d.DialContext(ctx, "tcp", addr) @@ -88,7 +89,7 @@ func newDialer(servers ServerLocator, wrapper TLSWrapper) func(context.Context, return nil, err } - if server.UseTLS { + if server.UseTLS && useTLSForDC(server.Datacenter) { if wrapper == nil { conn.Close() return nil, fmt.Errorf("TLS enabled but got nil TLS wrapper") diff --git a/agent/grpc/client_test.go b/agent/grpc/client_test.go index 38ecc40aa7..5028e34fa9 100644 --- a/agent/grpc/client_test.go +++ b/agent/grpc/client_test.go @@ -18,6 +18,11 @@ import ( "github.com/hashicorp/consul/tlsutil" ) +// useTLSForDcAlwaysTrue tell GRPC to always return the TLS is enabled +func useTLSForDcAlwaysTrue(_ string) bool { + return true +} + func TestNewDialer_WithTLSWrapper(t *testing.T) { lis, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) @@ -37,7 +42,7 @@ func TestNewDialer_WithTLSWrapper(t *testing.T) { called = true return conn, nil } - dial := newDialer(builder, wrapper) + dial := newDialer(builder, wrapper, useTLSForDcAlwaysTrue) ctx := context.Background() conn, err := dial(ctx, lis.Addr().String()) require.NoError(t, err) @@ -63,7 +68,7 @@ func TestNewDialer_IntegrationWithTLSEnabledHandler(t *testing.T) { res.AddServer(srv.Metadata()) t.Cleanup(srv.shutdown) - pool := NewClientConnPool(res, TLSWrapper(tlsConf.OutgoingRPCWrapper())) + pool := NewClientConnPool(res, TLSWrapper(tlsConf.OutgoingRPCWrapper()), tlsConf.UseTLS) conn, err := pool.ClientConn("dc1") require.NoError(t, err) @@ -82,7 +87,7 @@ func TestClientConnPool_IntegrationWithGRPCResolver_Failover(t *testing.T) { count := 4 res := resolver.NewServerResolverBuilder(newConfig(t)) registerWithGRPC(t, res) - pool := NewClientConnPool(res, nil) + pool := NewClientConnPool(res, nil, useTLSForDcAlwaysTrue) for i := 0; i < count; i++ { name := fmt.Sprintf("server-%d", i) @@ -119,7 +124,7 @@ func TestClientConnPool_IntegrationWithGRPCResolver_Rebalance(t *testing.T) { count := 5 res := resolver.NewServerResolverBuilder(newConfig(t)) registerWithGRPC(t, res) - pool := NewClientConnPool(res, nil) + pool := NewClientConnPool(res, nil, useTLSForDcAlwaysTrue) for i := 0; i < count; i++ { name := fmt.Sprintf("server-%d", i) @@ -168,7 +173,7 @@ func TestClientConnPool_IntegrationWithGRPCResolver_MultiDC(t *testing.T) { res := resolver.NewServerResolverBuilder(newConfig(t)) registerWithGRPC(t, res) - pool := NewClientConnPool(res, nil) + pool := NewClientConnPool(res, nil, useTLSForDcAlwaysTrue) for _, dc := range dcs { name := "server-0-" + dc diff --git a/agent/setup.go b/agent/setup.go index ccf9c08d14..34cc9ac893 100644 --- a/agent/setup.go +++ b/agent/setup.go @@ -101,7 +101,7 @@ func NewBaseDeps(configLoader ConfigLoader, logOut io.Writer) (BaseDeps, error) builder := resolver.NewServerResolverBuilder(resolver.Config{}) registerWithGRPC(builder) - d.GRPCConnPool = grpc.NewClientConnPool(builder, grpc.TLSWrapper(d.TLSConfigurator.OutgoingRPCWrapper())) + d.GRPCConnPool = grpc.NewClientConnPool(builder, grpc.TLSWrapper(d.TLSConfigurator.OutgoingRPCWrapper()), d.TLSConfigurator.UseTLS) d.Router = router.NewRouter(d.Logger, cfg.Datacenter, fmt.Sprintf("%s.%s", cfg.NodeName, cfg.Datacenter), builder) diff --git a/agent/streaming_test.go b/agent/streaming_test.go new file mode 100644 index 0000000000..0f45ad9ed4 --- /dev/null +++ b/agent/streaming_test.go @@ -0,0 +1,107 @@ +package agent + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/hashicorp/consul/sdk/testutil" + "github.com/hashicorp/consul/testrpc" +) + +func testGRPCStreamingWorking(t *testing.T, config string) { + if testing.Short() { + t.Skip("too slow for testing.Short") + } + + a := NewTestAgent(t, config) + defer a.Shutdown() + + testrpc.WaitForLeader(t, a.RPC, "dc1") + + req, _ := http.NewRequest("GET", "/v1/health/service/consul?index=3", nil) + resp := httptest.NewRecorder() + _, err := a.srv.HealthServiceNodes(resp, req) + if err != nil { + t.Fatalf("err: %v", err) + } + + assertIndex(t, resp) + require.NotEmpty(t, resp.Header().Get("X-Consul-Index")) +} + +func TestGRPCWithTLSConfigs(t *testing.T) { + t.Parallel() + testCases := []struct { + name string + config string + }{ + { + name: "no-tls", + config: "", + }, + { + name: "tls-all-enabled", + config: ` + # tls + ca_file = "../test/hostname/CertAuth.crt" + cert_file = "../test/hostname/Bob.crt" + key_file = "../test/hostname/Bob.key" + verify_incoming = true + verify_outgoing = true + verify_server_hostname = true + `, + }, + { + name: "tls ready no verify incoming", + config: ` + # tls + ca_file = "../test/hostname/CertAuth.crt" + cert_file = "../test/hostname/Bob.crt" + key_file = "../test/hostname/Bob.key" + verify_incoming = false + verify_outgoing = true + verify_server_hostname = false + `, + }, + { + name: "tls ready no verify outgoing and incoming", + config: ` + # tls + ca_file = "../test/hostname/CertAuth.crt" + cert_file = "../test/hostname/Bob.crt" + key_file = "../test/hostname/Bob.key" + verify_incoming = false + verify_outgoing = false + verify_server_hostname = false + `, + }, + { + name: "tls ready, all defaults", + config: ` + # tls + ca_file = "../test/hostname/CertAuth.crt" + cert_file = "../test/hostname/Bob.crt" + key_file = "../test/hostname/Bob.key" + `, + }, + } + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + dataDir := testutil.TempDir(t, "agent") // we manage the data dir + cfg := `data_dir = "` + dataDir + `" + domain = "consul" + node_name = "my-fancy-server" + datacenter = "dc1" + primary_datacenter = "dc1" + rpc { + enable_streaming = true + } + use_streaming_backend = true + ` + tt.config + testGRPCStreamingWorking(t, cfg) + }) + } +}