From 106d781dc99cb5f7f616d1de5bf1f1832942ec48 Mon Sep 17 00:00:00 2001 From: Daniel Nephin Date: Tue, 8 Sep 2020 15:22:35 -0400 Subject: [PATCH 01/11] subscribe: add initial impl from streaming-rpc-final branch Co-authored-by: Paul Banks --- agent/subscribe/subscribe.go | 187 +++++++++++++++++++++++++++++++++++ 1 file changed, 187 insertions(+) create mode 100644 agent/subscribe/subscribe.go diff --git a/agent/subscribe/subscribe.go b/agent/subscribe/subscribe.go new file mode 100644 index 0000000000..d2d7083889 --- /dev/null +++ b/agent/subscribe/subscribe.go @@ -0,0 +1,187 @@ +package subscribe + +import ( + "github.com/hashicorp/consul/agent/consul/stream" + "github.com/hashicorp/consul/proto/pbevent" + "github.com/hashicorp/go-hclog" + "github.com/hashicorp/go-uuid" +) + +// Server implements a StateChangeSubscriptionServer for accepting SubscribeRequests, +// and sending events to the subscription topic. +type Server struct { + srv *Server + logger hclog.Logger +} + +var _ pbevent.StateChangeSubscriptionServer = (*Server)(nil) + +func (h *Server) Subscribe(req *pbevent.SubscribeRequest, serverStream pbevent.StateChangeSubscription_SubscribeServer) error { + // streamID is just used for message correlation in trace logs and not + // populated normally. + var streamID string + var err error + + if h.logger.IsTrace() { + // TODO(banks) it might be nice one day to replace this with OpenTracing ID + // if one is set etc. but probably pointless until we support that properly + // in other places so it's actually propagated properly. For now this just + // makes lifetime of a stream more traceable in our regular server logs for + // debugging/dev. + streamID, err = uuid.GenerateUUID() + if err != nil { + return err + } + } + + // Forward the request to a remote DC if applicable. + if req.Datacenter != "" && req.Datacenter != h.srv.config.Datacenter { + return h.forwardAndProxy(req, serverStream, streamID) + } + + h.srv.logger.Trace("new subscription", + "topic", req.Topic.String(), + "key", req.Key, + "index", req.Index, + "stream_id", streamID, + ) + + var sentCount uint64 + defer h.srv.logger.Trace("subscription closed", "stream_id", streamID) + + // Resolve the token and create the ACL filter. + // TODO: handle token expiry gracefully... + authz, err := h.srv.ResolveToken(req.Token) + if err != nil { + return err + } + aclFilter := newACLFilter(authz, h.srv.logger, h.srv.config.ACLEnforceVersion8) + + state := h.srv.fsm.State() + + // Register a subscription on this topic/key with the FSM. 
+ sub, err := state.Subscribe(serverStream.Context(), req) + if err != nil { + return err + } + defer state.Unsubscribe(req) + + // Deliver the events + for { + events, err := sub.Next() + if err == stream.ErrSubscriptionReload { + event := pbevent.Event{ + Payload: &pbevent.Event_ResetStream{ResetStream: true}, + } + if err := serverStream.Send(&event); err != nil { + return err + } + h.srv.logger.Trace("subscription reloaded", + "stream_id", streamID, + ) + return nil + } + if err != nil { + return err + } + + aclFilter.filterStreamEvents(&events) + + snapshotDone := false + if len(events) == 1 { + if events[0].GetEndOfSnapshot() { + snapshotDone = true + h.srv.logger.Trace("snapshot complete", + "index", events[0].Index, + "sent", sentCount, + "stream_id", streamID, + ) + } else if events[0].GetResumeStream() { + snapshotDone = true + h.srv.logger.Trace("resuming stream", + "index", events[0].Index, + "sent", sentCount, + "stream_id", streamID, + ) + } else if snapshotDone { + // Count this event too in the normal case as "sent" the above cases + // only show the number of events sent _before_ the snapshot ended. + h.srv.logger.Trace("sending events", + "index", events[0].Index, + "sent", sentCount, + "batch_size", 1, + "stream_id", streamID, + ) + } + sentCount++ + if err := serverStream.Send(&events[0]); err != nil { + return err + } + } else if len(events) > 1 { + e := &pbevent.Event{ + Topic: req.Topic, + Key: req.Key, + Index: events[0].Index, + Payload: &pbevent.Event_EventBatch{ + EventBatch: &pbevent.EventBatch{ + Events: pbevent.EventBatchEventsFromEventSlice(events), + }, + }, + } + sentCount += uint64(len(events)) + h.srv.logger.Trace("sending events", + "index", events[0].Index, + "sent", sentCount, + "batch_size", len(events), + "stream_id", streamID, + ) + if err := serverStream.Send(e); err != nil { + return err + } + } + } +} + +func (h *Server) forwardAndProxy( + req *pbevent.SubscribeRequest, + serverStream pbevent.StateChangeSubscription_SubscribeServer, + streamID string) error { + + conn, err := h.srv.grpcClient.GRPCConn(req.Datacenter) + if err != nil { + return err + } + + h.logger.Trace("forwarding to another DC", + "dc", req.Datacenter, + "topic", req.Topic.String(), + "key", req.Key, + "index", req.Index, + "stream_id", streamID, + ) + + defer func() { + h.logger.Trace("forwarded stream closed", + "dc", req.Datacenter, + "stream_id", streamID, + ) + }() + + // Open a Subscribe call to the remote DC. + client := pbevent.NewConsulClient(conn) + streamHandle, err := client.Subscribe(serverStream.Context(), req) + if err != nil { + return err + } + + // Relay the events back to the client. 
+ for { + event, err := streamHandle.Recv() + if err != nil { + return err + } + if err := serverStream.Send(event); err != nil { + return err + } + } +} From f4ea3066fbcb1d3ede0d086a3356ef0681591b38 Mon Sep 17 00:00:00 2001 From: Daniel Nephin Date: Wed, 9 Sep 2020 14:04:33 -0400 Subject: [PATCH 02/11] subscribe: add commented out test cases Co-authored-by: Paul Banks --- agent/subscribe/subscribe_test.go | 1434 +++++++++++++++++++++++++++++ 1 file changed, 1434 insertions(+) create mode 100644 agent/subscribe/subscribe_test.go diff --git a/agent/subscribe/subscribe_test.go b/agent/subscribe/subscribe_test.go new file mode 100644 index 0000000000..3c490bd323 --- /dev/null +++ b/agent/subscribe/subscribe_test.go @@ -0,0 +1,1434 @@ +package subscribe + +/* TODO +func TestStreaming_Subscribe(t *testing.T) { + t.Parallel() + + require := require.New(t) + dir1, server := testServerWithConfig(t, func(c *Config) { + c.Datacenter = "dc1" + c.Bootstrap = true + c.GRPCEnabled = true + }) + defer os.RemoveAll(dir1) + defer server.Shutdown() + codec := rpcClient(t, server) + defer codec.Close() + + dir2, client := testClientWithConfig(t, func(c *Config) { + c.Datacenter = "dc1" + c.NodeName = uniqueNodeName(t.Name()) + c.GRPCEnabled = true + }) + defer os.RemoveAll(dir2) + defer client.Shutdown() + + // Try to join + testrpc.WaitForLeader(t, server.RPC, "dc1") + joinLAN(t, client, server) + testrpc.WaitForTestAgent(t, client.RPC, "dc1") + + // Register a dummy node with a service we don't care about, to make sure + // we don't see updates for it. + { + req := &structs.RegisterRequest{ + Node: "other", + Address: "2.3.4.5", + Datacenter: "dc1", + Service: &structs.NodeService{ + ID: "api1", + Service: "api", + Address: "2.3.4.5", + Port: 9000, + }, + } + var out struct{} + require.NoError(msgpackrpc.CallWithCodec(codec, "Catalog.Register", &req, &out)) + } + + // Register a dummy node with our service on it. + { + req := &structs.RegisterRequest{ + Node: "node1", + Address: "3.4.5.6", + Datacenter: "dc1", + Service: &structs.NodeService{ + ID: "redis1", + Service: "redis", + Address: "3.4.5.6", + Port: 8080, + }, + } + var out struct{} + require.NoError(msgpackrpc.CallWithCodec(codec, "Catalog.Register", &req, &out)) + } + + // Register a test node to be updated later. + req := &structs.RegisterRequest{ + Node: "node2", + Address: "1.2.3.4", + Datacenter: "dc1", + Service: &structs.NodeService{ + ID: "redis1", + Service: "redis", + Address: "1.1.1.1", + Port: 8080, + }, + } + var out struct{} + require.NoError(msgpackrpc.CallWithCodec(codec, "Catalog.Register", &req, &out)) + + // Start a Subscribe call to our streaming endpoint. + conn, err := client.GRPCConn() + require.NoError(err) + + streamClient := pbsubscribe.NewConsulClient(conn) + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + streamHandle, err := streamClient.Subscribe(ctx, &pbsubscribe.SubscribeRequest{ + Topic: pbsubscribe.Topic_ServiceHealth, + Key: "redis", + }) + require.NoError(err) + + // Start a goroutine to read updates off the pbsubscribe. 
+ eventCh := make(chan *pbsubscribe.Event, 0) + go testSendEvents(t, eventCh, streamHandle) + + var snapshotEvents []*pbsubscribe.Event + for i := 0; i < 3; i++ { + select { + case event := <-eventCh: + snapshotEvents = append(snapshotEvents, event) + case <-time.After(3 * time.Second): + t.Fatalf("did not receive events past %d", len(snapshotEvents)) + } + } + + expected := []*pbsubscribe.Event{ + { + Topic: pbsubscribe.Topic_ServiceHealth, + Key: "redis", + Payload: &pbsubscribe.Event_ServiceHealth{ + ServiceHealth: &pbsubscribe.ServiceHealthUpdate{ + Op: pbsubscribe.CatalogOp_Register, + CheckServiceNode: &pbsubscribe.CheckServiceNode{ + Node: &pbsubscribe.Node{ + Node: "node1", + Datacenter: "dc1", + Address: "3.4.5.6", + }, + Service: &pbsubscribe.NodeService{ + ID: "redis1", + Service: "redis", + Address: "3.4.5.6", + Port: 8080, + Weights: &pbsubscribe.Weights{Passing: 1, Warning: 1}, + // Sad empty state + Proxy: pbsubscribe.ConnectProxyConfig{ + MeshGateway: &pbsubscribe.MeshGatewayConfig{}, + Expose: &pbsubscribe.ExposeConfig{}, + }, + EnterpriseMeta: &pbsubscribe.EnterpriseMeta{}, + }, + }, + }, + }, + }, + { + Topic: pbsubscribe.Topic_ServiceHealth, + Key: "redis", + Payload: &pbsubscribe.Event_ServiceHealth{ + ServiceHealth: &pbsubscribe.ServiceHealthUpdate{ + Op: pbsubscribe.CatalogOp_Register, + CheckServiceNode: &pbsubscribe.CheckServiceNode{ + Node: &pbsubscribe.Node{ + Node: "node2", + Datacenter: "dc1", + Address: "1.2.3.4", + }, + Service: &pbsubscribe.NodeService{ + ID: "redis1", + Service: "redis", + Address: "1.1.1.1", + Port: 8080, + Weights: &pbsubscribe.Weights{Passing: 1, Warning: 1}, + // Sad empty state + Proxy: pbsubscribe.ConnectProxyConfig{ + MeshGateway: &pbsubscribe.MeshGatewayConfig{}, + Expose: &pbsubscribe.ExposeConfig{}, + }, + EnterpriseMeta: &pbsubscribe.EnterpriseMeta{}, + }, + }, + }, + }, + }, + { + Topic: pbsubscribe.Topic_ServiceHealth, + Key: "redis", + Payload: &pbsubscribe.Event_EndOfSnapshot{EndOfSnapshot: true}, + }, + } + + require.Len(snapshotEvents, 3) + for i := 0; i < 2; i++ { + // Fix up the index + expected[i].Index = snapshotEvents[i].Index + node := expected[i].GetServiceHealth().CheckServiceNode + node.Node.RaftIndex = snapshotEvents[i].GetServiceHealth().CheckServiceNode.Node.RaftIndex + node.Service.RaftIndex = snapshotEvents[i].GetServiceHealth().CheckServiceNode.Service.RaftIndex + } + // Fix index on snapshot event + expected[2].Index = snapshotEvents[2].Index + + requireEqualProtos(t, expected, snapshotEvents) + + // Update the registration by adding a check. + req.Check = &structs.HealthCheck{ + Node: "node2", + CheckID: types.CheckID("check1"), + ServiceID: "redis1", + ServiceName: "redis", + Name: "check 1", + } + require.NoError(msgpackrpc.CallWithCodec(codec, "Catalog.Register", &req, &out)) + + // Make sure we get the event for the diff. 
+ select { + case event := <-eventCh: + expected := &pbsubscribe.Event{ + Topic: pbsubscribe.Topic_ServiceHealth, + Key: "redis", + Payload: &pbsubscribe.Event_ServiceHealth{ + ServiceHealth: &pbsubscribe.ServiceHealthUpdate{ + Op: pbsubscribe.CatalogOp_Register, + CheckServiceNode: &pbsubscribe.CheckServiceNode{ + Node: &pbsubscribe.Node{ + Node: "node2", + Datacenter: "dc1", + Address: "1.2.3.4", + RaftIndex: pbsubscribe.RaftIndex{CreateIndex: 13, ModifyIndex: 13}, + }, + Service: &pbsubscribe.NodeService{ + ID: "redis1", + Service: "redis", + Address: "1.1.1.1", + Port: 8080, + RaftIndex: pbsubscribe.RaftIndex{CreateIndex: 13, ModifyIndex: 13}, + Weights: &pbsubscribe.Weights{Passing: 1, Warning: 1}, + // Sad empty state + Proxy: pbsubscribe.ConnectProxyConfig{ + MeshGateway: &pbsubscribe.MeshGatewayConfig{}, + Expose: &pbsubscribe.ExposeConfig{}, + }, + EnterpriseMeta: &pbsubscribe.EnterpriseMeta{}, + }, + Checks: []*pbsubscribe.HealthCheck{ + { + CheckID: "check1", + Name: "check 1", + Node: "node2", + Status: "critical", + ServiceID: "redis1", + ServiceName: "redis", + RaftIndex: pbsubscribe.RaftIndex{CreateIndex: 14, ModifyIndex: 14}, + EnterpriseMeta: &pbsubscribe.EnterpriseMeta{}, + }, + }, + }, + }, + }, + } + // Fix up the index + expected.Index = event.Index + node := expected.GetServiceHealth().CheckServiceNode + node.Node.RaftIndex = event.GetServiceHealth().CheckServiceNode.Node.RaftIndex + node.Service.RaftIndex = event.GetServiceHealth().CheckServiceNode.Service.RaftIndex + node.Checks[0].RaftIndex = event.GetServiceHealth().CheckServiceNode.Checks[0].RaftIndex + requireEqualProtos(t, expected, event) + case <-time.After(3 * time.Second): + t.Fatal("never got event") + } + + // Wait and make sure there aren't any more events coming. + select { + case event := <-eventCh: + t.Fatalf("got another event: %v", event) + case <-time.After(500 * time.Millisecond): + } +} + +func TestStreaming_Subscribe_MultiDC(t *testing.T) { + t.Parallel() + + require := require.New(t) + dir1, server1 := testServerWithConfig(t, func(c *Config) { + c.Datacenter = "dc1" + c.Bootstrap = true + c.GRPCEnabled = true + }) + defer os.RemoveAll(dir1) + defer server1.Shutdown() + + dir2, server2 := testServerWithConfig(t, func(c *Config) { + c.Datacenter = "dc2" + c.Bootstrap = true + c.GRPCEnabled = true + }) + defer os.RemoveAll(dir2) + defer server2.Shutdown() + codec := rpcClient(t, server2) + defer codec.Close() + + dir3, client := testClientWithConfig(t, func(c *Config) { + c.Datacenter = "dc1" + c.NodeName = uniqueNodeName(t.Name()) + c.GRPCEnabled = true + }) + defer os.RemoveAll(dir3) + defer client.Shutdown() + + // Join the servers via WAN + joinWAN(t, server2, server1) + testrpc.WaitForLeader(t, server1.RPC, "dc1") + testrpc.WaitForLeader(t, server2.RPC, "dc2") + + joinLAN(t, client, server1) + testrpc.WaitForTestAgent(t, client.RPC, "dc1") + + // Register a dummy node in dc2 with a service we don't care about, + // to make sure we don't see updates for it. + { + req := &structs.RegisterRequest{ + Node: "other", + Address: "2.3.4.5", + Datacenter: "dc2", + Service: &structs.NodeService{ + ID: "api1", + Service: "api", + Address: "2.3.4.5", + Port: 9000, + }, + } + var out struct{} + require.NoError(msgpackrpc.CallWithCodec(codec, "Catalog.Register", &req, &out)) + } + + // Register a dummy node with our service on it, again in dc2. 
+ { + req := &structs.RegisterRequest{ + Node: "node1", + Address: "3.4.5.6", + Datacenter: "dc2", + Service: &structs.NodeService{ + ID: "redis1", + Service: "redis", + Address: "3.4.5.6", + Port: 8080, + }, + } + var out struct{} + require.NoError(msgpackrpc.CallWithCodec(codec, "Catalog.Register", &req, &out)) + } + + // Register a test node in dc2 to be updated later. + req := &structs.RegisterRequest{ + Node: "node2", + Address: "1.2.3.4", + Datacenter: "dc2", + Service: &structs.NodeService{ + ID: "redis1", + Service: "redis", + Address: "1.1.1.1", + Port: 8080, + }, + } + var out struct{} + require.NoError(msgpackrpc.CallWithCodec(codec, "Catalog.Register", &req, &out)) + + // Start a cross-DC Subscribe call to our streaming endpoint, specifying dc2. + conn, err := client.GRPCConn() + require.NoError(err) + + streamClient := pbsubscribe.NewConsulClient(conn) + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + streamHandle, err := streamClient.Subscribe(ctx, &pbsubscribe.SubscribeRequest{ + Topic: pbsubscribe.Topic_ServiceHealth, + Key: "redis", + Datacenter: "dc2", + }) + require.NoError(err) + + // Start a goroutine to read updates off the pbsubscribe. + eventCh := make(chan *pbsubscribe.Event, 0) + go testSendEvents(t, eventCh, streamHandle) + + var snapshotEvents []*pbsubscribe.Event + for i := 0; i < 3; i++ { + select { + case event := <-eventCh: + snapshotEvents = append(snapshotEvents, event) + case <-time.After(3 * time.Second): + t.Fatalf("did not receive events past %d", len(snapshotEvents)) + } + } + + expected := []*pbsubscribe.Event{ + { + Topic: pbsubscribe.Topic_ServiceHealth, + Key: "redis", + Payload: &pbsubscribe.Event_ServiceHealth{ + ServiceHealth: &pbsubscribe.ServiceHealthUpdate{ + Op: pbsubscribe.CatalogOp_Register, + CheckServiceNode: &pbsubscribe.CheckServiceNode{ + Node: &pbsubscribe.Node{ + Node: "node1", + Datacenter: "dc2", + Address: "3.4.5.6", + }, + Service: &pbsubscribe.NodeService{ + ID: "redis1", + Service: "redis", + Address: "3.4.5.6", + Port: 8080, + Weights: &pbsubscribe.Weights{Passing: 1, Warning: 1}, + // Sad empty state + Proxy: pbsubscribe.ConnectProxyConfig{ + MeshGateway: &pbsubscribe.MeshGatewayConfig{}, + Expose: &pbsubscribe.ExposeConfig{}, + }, + EnterpriseMeta: &pbsubscribe.EnterpriseMeta{}, + }, + }, + }, + }, + }, + { + Topic: pbsubscribe.Topic_ServiceHealth, + Key: "redis", + Payload: &pbsubscribe.Event_ServiceHealth{ + ServiceHealth: &pbsubscribe.ServiceHealthUpdate{ + Op: pbsubscribe.CatalogOp_Register, + CheckServiceNode: &pbsubscribe.CheckServiceNode{ + Node: &pbsubscribe.Node{ + Node: "node2", + Datacenter: "dc2", + Address: "1.2.3.4", + }, + Service: &pbsubscribe.NodeService{ + ID: "redis1", + Service: "redis", + Address: "1.1.1.1", + Port: 8080, + Weights: &pbsubscribe.Weights{Passing: 1, Warning: 1}, + // Sad empty state + Proxy: pbsubscribe.ConnectProxyConfig{ + MeshGateway: &pbsubscribe.MeshGatewayConfig{}, + Expose: &pbsubscribe.ExposeConfig{}, + }, + EnterpriseMeta: &pbsubscribe.EnterpriseMeta{}, + }, + }, + }, + }, + }, + { + Topic: pbsubscribe.Topic_ServiceHealth, + Key: "redis", + Payload: &pbsubscribe.Event_EndOfSnapshot{EndOfSnapshot: true}, + }, + } + + require.Len(snapshotEvents, 3) + for i := 0; i < 2; i++ { + // Fix up the index + expected[i].Index = snapshotEvents[i].Index + node := expected[i].GetServiceHealth().CheckServiceNode + node.Node.RaftIndex = snapshotEvents[i].GetServiceHealth().CheckServiceNode.Node.RaftIndex + node.Service.RaftIndex = 
snapshotEvents[i].GetServiceHealth().CheckServiceNode.Service.RaftIndex + } + expected[2].Index = snapshotEvents[2].Index + requireEqualProtos(t, expected, snapshotEvents) + + // Update the registration by adding a check. + req.Check = &structs.HealthCheck{ + Node: "node2", + CheckID: types.CheckID("check1"), + ServiceID: "redis1", + ServiceName: "redis", + Name: "check 1", + } + require.NoError(msgpackrpc.CallWithCodec(codec, "Catalog.Register", &req, &out)) + + // Make sure we get the event for the diff. + select { + case event := <-eventCh: + expected := &pbsubscribe.Event{ + Topic: pbsubscribe.Topic_ServiceHealth, + Key: "redis", + Payload: &pbsubscribe.Event_ServiceHealth{ + ServiceHealth: &pbsubscribe.ServiceHealthUpdate{ + Op: pbsubscribe.CatalogOp_Register, + CheckServiceNode: &pbsubscribe.CheckServiceNode{ + Node: &pbsubscribe.Node{ + Node: "node2", + Datacenter: "dc2", + Address: "1.2.3.4", + RaftIndex: pbsubscribe.RaftIndex{CreateIndex: 13, ModifyIndex: 13}, + }, + Service: &pbsubscribe.NodeService{ + ID: "redis1", + Service: "redis", + Address: "1.1.1.1", + Port: 8080, + RaftIndex: pbsubscribe.RaftIndex{CreateIndex: 13, ModifyIndex: 13}, + Weights: &pbsubscribe.Weights{Passing: 1, Warning: 1}, + // Sad empty state + Proxy: pbsubscribe.ConnectProxyConfig{ + MeshGateway: &pbsubscribe.MeshGatewayConfig{}, + Expose: &pbsubscribe.ExposeConfig{}, + }, + EnterpriseMeta: &pbsubscribe.EnterpriseMeta{}, + }, + Checks: []*pbsubscribe.HealthCheck{ + { + CheckID: "check1", + Name: "check 1", + Node: "node2", + Status: "critical", + ServiceID: "redis1", + ServiceName: "redis", + RaftIndex: pbsubscribe.RaftIndex{CreateIndex: 14, ModifyIndex: 14}, + EnterpriseMeta: &pbsubscribe.EnterpriseMeta{}, + }, + }, + }, + }, + }, + } + // Fix up the index + expected.Index = event.Index + node := expected.GetServiceHealth().CheckServiceNode + node.Node.RaftIndex = event.GetServiceHealth().CheckServiceNode.Node.RaftIndex + node.Service.RaftIndex = event.GetServiceHealth().CheckServiceNode.Service.RaftIndex + node.Checks[0].RaftIndex = event.GetServiceHealth().CheckServiceNode.Checks[0].RaftIndex + requireEqualProtos(t, expected, event) + case <-time.After(3 * time.Second): + t.Fatal("never got event") + } + + // Wait and make sure there aren't any more events coming. + select { + case event := <-eventCh: + t.Fatalf("got another event: %v", event) + case <-time.After(500 * time.Millisecond): + } +} + +func TestStreaming_Subscribe_SkipSnapshot(t *testing.T) { + t.Parallel() + + require := require.New(t) + dir1, server := testServerWithConfig(t, func(c *Config) { + c.Datacenter = "dc1" + c.Bootstrap = true + c.GRPCEnabled = true + }) + defer os.RemoveAll(dir1) + defer server.Shutdown() + codec := rpcClient(t, server) + defer codec.Close() + + dir2, client := testClientWithConfig(t, func(c *Config) { + c.Datacenter = "dc1" + c.NodeName = uniqueNodeName(t.Name()) + c.GRPCEnabled = true + }) + defer os.RemoveAll(dir2) + defer client.Shutdown() + + // Try to join + testrpc.WaitForLeader(t, server.RPC, "dc1") + joinLAN(t, client, server) + testrpc.WaitForTestAgent(t, client.RPC, "dc1") + + // Register a dummy node with our service on it. + { + req := &structs.RegisterRequest{ + Node: "node1", + Address: "3.4.5.6", + Datacenter: "dc1", + Service: &structs.NodeService{ + ID: "redis1", + Service: "redis", + Address: "3.4.5.6", + Port: 8080, + }, + } + var out struct{} + require.NoError(msgpackrpc.CallWithCodec(codec, "Catalog.Register", &req, &out)) + } + + // Start a Subscribe call to our streaming endpoint. 
+ conn, err := client.GRPCConn() + require.NoError(err) + + streamClient := pbsubscribe.NewConsulClient(conn) + + var index uint64 + { + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + streamHandle, err := streamClient.Subscribe(ctx, &pbsubscribe.SubscribeRequest{Topic: pbsubscribe.Topic_ServiceHealth, Key: "redis"}) + require.NoError(err) + + // Start a goroutine to read updates off the pbsubscribe. + eventCh := make(chan *pbsubscribe.Event, 0) + go testSendEvents(t, eventCh, streamHandle) + + var snapshotEvents []*pbsubscribe.Event + for i := 0; i < 2; i++ { + select { + case event := <-eventCh: + snapshotEvents = append(snapshotEvents, event) + case <-time.After(3 * time.Second): + t.Fatalf("did not receive events past %d", len(snapshotEvents)) + } + } + + // Save the index from the event + index = snapshotEvents[0].Index + } + + // Start another Subscribe call passing the index from the last event. + { + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + streamHandle, err := streamClient.Subscribe(ctx, &pbsubscribe.SubscribeRequest{ + Topic: pbsubscribe.Topic_ServiceHealth, + Key: "redis", + Index: index, + }) + require.NoError(err) + + // Start a goroutine to read updates off the pbsubscribe. + eventCh := make(chan *pbsubscribe.Event, 0) + go testSendEvents(t, eventCh, streamHandle) + + // We should get no snapshot and the first event should be "resume stream" + select { + case event := <-eventCh: + require.True(event.GetResumeStream()) + case <-time.After(500 * time.Millisecond): + t.Fatalf("never got event") + } + + // Wait and make sure there aren't any events coming. The server shouldn't send + // a snapshot and we haven't made any updates to the catalog that would trigger + // more events. + select { + case event := <-eventCh: + t.Fatalf("got another event: %v", event) + case <-time.After(500 * time.Millisecond): + } + } +} + +func TestStreaming_Subscribe_FilterACL(t *testing.T) { + t.Parallel() + + require := require.New(t) + dir, _, server, codec := testACLFilterServerWithConfigFn(t, func(c *Config) { + c.ACLDatacenter = "dc1" + c.ACLsEnabled = true + c.ACLMasterToken = "root" + c.ACLDefaultPolicy = "deny" + c.ACLEnforceVersion8 = true + c.GRPCEnabled = true + }) + defer os.RemoveAll(dir) + defer server.Shutdown() + defer codec.Close() + + dir2, client := testClientWithConfig(t, func(c *Config) { + c.Datacenter = "dc1" + c.NodeName = uniqueNodeName(t.Name()) + c.GRPCEnabled = true + }) + defer os.RemoveAll(dir2) + defer client.Shutdown() + + // Try to join + testrpc.WaitForLeader(t, server.RPC, "dc1") + joinLAN(t, client, server) + testrpc.WaitForTestAgent(t, client.RPC, "dc1", testrpc.WithToken("root")) + + // Create a policy for the test token. + policyReq := structs.ACLPolicySetRequest{ + Datacenter: "dc1", + Policy: structs.ACLPolicy{ + Description: "foobar", + Name: "baz", + Rules: fmt.Sprintf(` + service "foo" { + policy = "write" + } + node "%s" { + policy = "write" + } + `, server.config.NodeName), + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + resp := structs.ACLPolicy{} + require.NoError(msgpackrpc.CallWithCodec(codec, "ACL.PolicySet", &policyReq, &resp)) + + // Create a new token that only has access to one node. 
+ var token structs.ACLToken + arg := structs.ACLTokenSetRequest{ + Datacenter: "dc1", + ACLToken: structs.ACLToken{ + Policies: []structs.ACLTokenPolicyLink{ + structs.ACLTokenPolicyLink{ + ID: resp.ID, + }, + }, + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + require.NoError(msgpackrpc.CallWithCodec(codec, "ACL.TokenSet", &arg, &token)) + auth, err := server.ResolveToken(token.SecretID) + require.NoError(err) + require.Equal(auth.NodeRead("denied", nil), acl.Deny) + + // Register another instance of service foo on a fake node the token doesn't have access to. + regArg := structs.RegisterRequest{ + Datacenter: "dc1", + Node: "denied", + Address: "127.0.0.1", + Service: &structs.NodeService{ + ID: "foo", + Service: "foo", + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + require.NoError(msgpackrpc.CallWithCodec(codec, "Catalog.Register", ®Arg, nil)) + + // Set up the gRPC client. + conn, err := client.GRPCConn() + require.NoError(err) + streamClient := pbsubscribe.NewConsulClient(conn) + + // Start a Subscribe call to our streaming endpoint for the service we have access to. + { + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + streamHandle, err := streamClient.Subscribe(ctx, &pbsubscribe.SubscribeRequest{ + Topic: pbsubscribe.Topic_ServiceHealth, + Key: "foo", + Token: token.SecretID, + }) + require.NoError(err) + + // Start a goroutine to read updates off the pbsubscribe. + eventCh := make(chan *pbsubscribe.Event, 0) + go testSendEvents(t, eventCh, streamHandle) + + // Read events off the pbsubscribe. We should not see any events for the filtered node. + var snapshotEvents []*pbsubscribe.Event + for i := 0; i < 2; i++ { + select { + case event := <-eventCh: + snapshotEvents = append(snapshotEvents, event) + case <-time.After(5 * time.Second): + t.Fatalf("did not receive events past %d", len(snapshotEvents)) + } + } + require.Len(snapshotEvents, 2) + require.Equal("foo", snapshotEvents[0].GetServiceHealth().CheckServiceNode.Service.Service) + require.Equal(server.config.NodeName, snapshotEvents[0].GetServiceHealth().CheckServiceNode.Node.Node) + require.True(snapshotEvents[1].GetEndOfSnapshot()) + + // Update the service with a new port to trigger a new event. + regArg := structs.RegisterRequest{ + Datacenter: "dc1", + Node: server.config.NodeName, + Address: "127.0.0.1", + Service: &structs.NodeService{ + ID: "foo", + Service: "foo", + Port: 1234, + }, + Check: &structs.HealthCheck{ + CheckID: "service:foo", + Name: "service:foo", + ServiceID: "foo", + Status: api.HealthPassing, + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + require.NoError(msgpackrpc.CallWithCodec(codec, "Catalog.Register", ®Arg, nil)) + + select { + case event := <-eventCh: + service := event.GetServiceHealth().CheckServiceNode.Service + require.Equal("foo", service.Service) + require.Equal(1234, service.Port) + case <-time.After(5 * time.Second): + t.Fatalf("did not receive events past %d", len(snapshotEvents)) + } + + // Now update the service on the denied node and make sure we don't see an event. 
+ regArg = structs.RegisterRequest{ + Datacenter: "dc1", + Node: "denied", + Address: "127.0.0.1", + Service: &structs.NodeService{ + ID: "foo", + Service: "foo", + Port: 2345, + }, + Check: &structs.HealthCheck{ + CheckID: "service:foo", + Name: "service:foo", + ServiceID: "foo", + Status: api.HealthPassing, + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + require.NoError(msgpackrpc.CallWithCodec(codec, "Catalog.Register", ®Arg, nil)) + + select { + case event := <-eventCh: + t.Fatalf("should not have received event: %v", event) + case <-time.After(500 * time.Millisecond): + } + } + + // Start another subscribe call for bar, which the token shouldn't have access to. + { + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + streamHandle, err := streamClient.Subscribe(ctx, &pbsubscribe.SubscribeRequest{ + Topic: pbsubscribe.Topic_ServiceHealth, + Key: "bar", + Token: token.SecretID, + }) + require.NoError(err) + + // Start a goroutine to read updates off the pbsubscribe. + eventCh := make(chan *pbsubscribe.Event, 0) + go testSendEvents(t, eventCh, streamHandle) + + select { + case event := <-eventCh: + require.True(event.GetEndOfSnapshot()) + case <-time.After(3 * time.Second): + t.Fatal("did not receive event") + } + + // Update the service and make sure we don't get a new event. + regArg := structs.RegisterRequest{ + Datacenter: "dc1", + Node: server.config.NodeName, + Address: "127.0.0.1", + Service: &structs.NodeService{ + ID: "bar", + Service: "bar", + Port: 2345, + }, + Check: &structs.HealthCheck{ + CheckID: "service:bar", + Name: "service:bar", + ServiceID: "bar", + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + require.NoError(msgpackrpc.CallWithCodec(codec, "Catalog.Register", ®Arg, nil)) + + select { + case event := <-eventCh: + t.Fatalf("should not have received event: %v", event) + case <-time.After(500 * time.Millisecond): + } + } +} + +func TestStreaming_Subscribe_ACLUpdate(t *testing.T) { + t.Parallel() + + require := require.New(t) + dir, _, server, codec := testACLFilterServerWithConfigFn(t, func(c *Config) { + c.ACLDatacenter = "dc1" + c.ACLsEnabled = true + c.ACLMasterToken = "root" + c.ACLDefaultPolicy = "deny" + c.ACLEnforceVersion8 = true + c.GRPCEnabled = true + }) + defer os.RemoveAll(dir) + defer server.Shutdown() + defer codec.Close() + + dir2, client := testClientWithConfig(t, func(c *Config) { + c.Datacenter = "dc1" + c.NodeName = uniqueNodeName(t.Name()) + c.GRPCEnabled = true + }) + defer os.RemoveAll(dir2) + defer client.Shutdown() + + // Try to join + testrpc.WaitForLeader(t, server.RPC, "dc1") + joinLAN(t, client, server) + testrpc.WaitForTestAgent(t, client.RPC, "dc1", testrpc.WithToken("root")) + + // Create a new token/policy that only has access to one node. 
+ var token structs.ACLToken + + policy, err := upsertTestPolicyWithRules(codec, "root", "dc1", fmt.Sprintf(` + service "foo" { + policy = "write" + } + node "%s" { + policy = "write" + } + `, server.config.NodeName)) + require.NoError(err) + + arg := structs.ACLTokenSetRequest{ + Datacenter: "dc1", + ACLToken: structs.ACLToken{ + Description: "Service/node token", + Policies: []structs.ACLTokenPolicyLink{ + structs.ACLTokenPolicyLink{ + ID: policy.ID, + }, + }, + Local: false, + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + + require.NoError(msgpackrpc.CallWithCodec(codec, "ACL.TokenSet", &arg, &token)) + auth, err := server.ResolveToken(token.SecretID) + require.NoError(err) + require.Equal(auth.NodeRead("denied", nil), acl.Deny) + + // Set up the gRPC client. + conn, err := client.GRPCConn() + require.NoError(err) + streamClient := pbsubscribe.NewConsulClient(conn) + + // Start a Subscribe call to our streaming endpoint for the service we have access to. + { + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + streamHandle, err := streamClient.Subscribe(ctx, &pbsubscribe.SubscribeRequest{ + Topic: pbsubscribe.Topic_ServiceHealth, + Key: "foo", + Token: token.SecretID, + }) + require.NoError(err) + + // Start a goroutine to read updates off the pbsubscribe. + eventCh := make(chan *pbsubscribe.Event, 0) + go testSendEvents(t, eventCh, streamHandle) + + // Read events off the pbsubscribe. + var snapshotEvents []*pbsubscribe.Event + for i := 0; i < 2; i++ { + select { + case event := <-eventCh: + snapshotEvents = append(snapshotEvents, event) + case <-time.After(5 * time.Second): + t.Fatalf("did not receive events past %d", len(snapshotEvents)) + } + } + require.Len(snapshotEvents, 2) + require.Equal("foo", snapshotEvents[0].GetServiceHealth().CheckServiceNode.Service.Service) + require.Equal(server.config.NodeName, snapshotEvents[0].GetServiceHealth().CheckServiceNode.Node.Node) + require.True(snapshotEvents[1].GetEndOfSnapshot()) + + // Update a different token and make sure we don't see an event. + arg2 := structs.ACLTokenSetRequest{ + Datacenter: "dc1", + ACLToken: structs.ACLToken{ + Description: "Ignored token", + Policies: []structs.ACLTokenPolicyLink{ + structs.ACLTokenPolicyLink{ + ID: policy.ID, + }, + }, + Local: false, + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + var ignoredToken structs.ACLToken + require.NoError(msgpackrpc.CallWithCodec(codec, "ACL.TokenSet", &arg2, &ignoredToken)) + + select { + case event := <-eventCh: + t.Fatalf("should not have received event: %v", event) + case <-time.After(500 * time.Millisecond): + } + + // Update our token to trigger a refresh event. + token.Policies = []structs.ACLTokenPolicyLink{} + arg := structs.ACLTokenSetRequest{ + Datacenter: "dc1", + ACLToken: token, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + require.NoError(msgpackrpc.CallWithCodec(codec, "ACL.TokenSet", &arg, &token)) + + select { + case event := <-eventCh: + require.True(event.GetResetStream()) + // 500 ms was not enough in CI apparently... + case <-time.After(2 * time.Second): + t.Fatalf("did not receive reload event") + } + } +} + +// testSendEvents receives pbsubscribe.Events from a given handle and sends them to the provided +// channel. This is meant to be run in a separate goroutine from the main test. 
+func testSendEvents(t *testing.T, ch chan *pbsubscribe.Event, handle pbsubscribe.StateChangeSubscription_SubscribeClient) { + for { + event, err := handle.Recv() + if err == io.EOF { + break + } + if err != nil { + if strings.Contains(err.Error(), "context deadline exceeded") || + strings.Contains(err.Error(), "context canceled") { + break + } + t.Log(err) + } + ch <- event + } +} + +func TestStreaming_TLSEnabled(t *testing.T) { + t.Parallel() + + require := require.New(t) + dir1, conf1 := testServerConfig(t) + conf1.VerifyIncoming = true + conf1.VerifyOutgoing = true + conf1.GRPCEnabled = true + configureTLS(conf1) + server, err := newServer(conf1) + if err != nil { + t.Fatalf("err: %v", err) + } + defer os.RemoveAll(dir1) + defer server.Shutdown() + + dir2, conf2 := testClientConfig(t) + conf2.VerifyOutgoing = true + conf2.GRPCEnabled = true + configureTLS(conf2) + client, err := NewClient(conf2) + if err != nil { + t.Fatalf("err: %v", err) + } + defer os.RemoveAll(dir2) + defer client.Shutdown() + + // Try to join + testrpc.WaitForLeader(t, server.RPC, "dc1") + joinLAN(t, client, server) + testrpc.WaitForTestAgent(t, client.RPC, "dc1") + + // Register a dummy node with our service on it. + { + req := &structs.RegisterRequest{ + Node: "node1", + Address: "3.4.5.6", + Datacenter: "dc1", + Service: &structs.NodeService{ + ID: "redis1", + Service: "redis", + Address: "3.4.5.6", + Port: 8080, + }, + } + var out struct{} + require.NoError(server.RPC("Catalog.Register", &req, &out)) + } + + // Start a Subscribe call to our streaming endpoint from the client. + { + conn, err := client.GRPCConn() + require.NoError(err) + + streamClient := pbsubscribe.NewConsulClient(conn) + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + streamHandle, err := streamClient.Subscribe(ctx, &pbsubscribe.SubscribeRequest{Topic: pbsubscribe.Topic_ServiceHealth, Key: "redis"}) + require.NoError(err) + + // Start a goroutine to read updates off the pbsubscribe. + eventCh := make(chan *pbsubscribe.Event, 0) + go testSendEvents(t, eventCh, streamHandle) + + var snapshotEvents []*pbsubscribe.Event + for i := 0; i < 2; i++ { + select { + case event := <-eventCh: + snapshotEvents = append(snapshotEvents, event) + case <-time.After(3 * time.Second): + t.Fatalf("did not receive events past %d", len(snapshotEvents)) + } + } + + // Make sure the snapshot events come back with no issues. + require.Len(snapshotEvents, 2) + } + + // Start a Subscribe call to our streaming endpoint from the server's loopback client. + { + conn, err := server.GRPCConn() + require.NoError(err) + + retryFailedConn(t, conn) + + streamClient := pbsubscribe.NewConsulClient(conn) + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + streamHandle, err := streamClient.Subscribe(ctx, &pbsubscribe.SubscribeRequest{Topic: pbsubscribe.Topic_ServiceHealth, Key: "redis"}) + require.NoError(err) + + // Start a goroutine to read updates off the pbsubscribe. + eventCh := make(chan *pbsubscribe.Event, 0) + go testSendEvents(t, eventCh, streamHandle) + + var snapshotEvents []*pbsubscribe.Event + for i := 0; i < 2; i++ { + select { + case event := <-eventCh: + snapshotEvents = append(snapshotEvents, event) + case <-time.After(3 * time.Second): + t.Fatalf("did not receive events past %d", len(snapshotEvents)) + } + } + + // Make sure the snapshot events come back with no issues. 
+ require.Len(snapshotEvents, 2) + } +} + +func TestStreaming_TLSReload(t *testing.T) { + t.Parallel() + + // Set up a server with initially bad certificates. + require := require.New(t) + dir1, conf1 := testServerConfig(t) + conf1.VerifyIncoming = true + conf1.VerifyOutgoing = true + conf1.CAFile = "../../test/ca/root.cer" + conf1.CertFile = "../../test/key/ssl-cert-snakeoil.pem" + conf1.KeyFile = "../../test/key/ssl-cert-snakeoil.key" + conf1.GRPCEnabled = true + + server, err := newServer(conf1) + if err != nil { + t.Fatalf("err: %v", err) + } + defer os.RemoveAll(dir1) + defer server.Shutdown() + + // Set up a client with valid certs and verify_outgoing = true + dir2, conf2 := testClientConfig(t) + conf2.VerifyOutgoing = true + conf2.GRPCEnabled = true + configureTLS(conf2) + client, err := NewClient(conf2) + if err != nil { + t.Fatalf("err: %v", err) + } + defer os.RemoveAll(dir2) + defer client.Shutdown() + + testrpc.WaitForLeader(t, server.RPC, "dc1") + + // Subscribe calls should fail initially + joinLAN(t, client, server) + conn, err := client.GRPCConn() + require.NoError(err) + { + streamClient := pbsubscribe.NewConsulClient(conn) + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + _, err = streamClient.Subscribe(ctx, &pbsubscribe.SubscribeRequest{Topic: pbsubscribe.Topic_ServiceHealth, Key: "redis"}) + require.Error(err, "tls: bad certificate") + } + + // Reload the server with valid certs + newConf := server.config.ToTLSUtilConfig() + newConf.CertFile = "../../test/key/ourdomain.cer" + newConf.KeyFile = "../../test/key/ourdomain.key" + server.tlsConfigurator.Update(newConf) + + // Try the subscribe call again + { + retryFailedConn(t, conn) + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + streamClient := pbsubscribe.NewConsulClient(conn) + _, err = streamClient.Subscribe(ctx, &pbsubscribe.SubscribeRequest{Topic: pbsubscribe.Topic_ServiceHealth, Key: "redis"}) + require.NoError(err) + } +} + +// retryFailedConn forces the ClientConn to reset its backoff timer and retry the connection, +// to simulate the client eventually retrying after the initial failure. This is used both to simulate +// retrying after an expected failure as well as to avoid flakiness when running many tests in parallel. +func retryFailedConn(t *testing.T, conn *grpc.ClientConn) { + state := conn.GetState() + if state.String() != "TRANSIENT_FAILURE" { + return + } + + // If the connection has failed, retry and wait for a state change. + conn.ResetConnectBackoff() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + require.True(t, conn.WaitForStateChange(ctx, state)) +} + +func TestStreaming_DeliversAllMessages(t *testing.T) { + // This is a fuzz/probabilistic test to try to provoke streaming into dropping + // messages. There is a bug in the initial implementation that should make + // this fail. While we can't be certain a pass means it's correct, it is + // useful for finding bugs in our concurrency design. + + // The issue is that when updates are coming in fast such that updates occur + // in between us making the snapshot and beginning the stream updates, we + // shouldn't miss anything. + + // To test this, we will run a background goroutine that will write updates as + // fast as possible while we then try to stream the results and ensure that we + // see every change. We'll make the updates monotonically increasing so we can + // easily tell if we missed one. 
+ + require := require.New(t) + dir1, server := testServerWithConfig(t, func(c *Config) { + c.Datacenter = "dc1" + c.Bootstrap = true + c.GRPCEnabled = true + }) + defer os.RemoveAll(dir1) + defer server.Shutdown() + codec := rpcClient(t, server) + defer codec.Close() + + dir2, client := testClientWithConfig(t, func(c *Config) { + c.Datacenter = "dc1" + c.NodeName = uniqueNodeName(t.Name()) + c.GRPCEnabled = true + }) + defer os.RemoveAll(dir2) + defer client.Shutdown() + + // Try to join + testrpc.WaitForLeader(t, server.RPC, "dc1") + joinLAN(t, client, server) + testrpc.WaitForTestAgent(t, client.RPC, "dc1") + + // Register a whole bunch of service instances so that the initial snapshot on + // subscribe is big enough to take a bit of time to load giving more + // opportunity for missed updates if there is a bug. + for i := 0; i < 1000; i++ { + req := &structs.RegisterRequest{ + Node: fmt.Sprintf("node-redis-%03d", i), + Address: "3.4.5.6", + Datacenter: "dc1", + Service: &structs.NodeService{ + ID: fmt.Sprintf("redis-%03d", i), + Service: "redis", + Port: 11211, + }, + } + var out struct{} + require.NoError(server.RPC("Catalog.Register", &req, &out)) + } + + // Start background writer + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + go func() { + // Update the registration with a monotonically increasing port as fast as + // we can. + req := &structs.RegisterRequest{ + Node: "node1", + Address: "3.4.5.6", + Datacenter: "dc1", + Service: &structs.NodeService{ + ID: "redis-canary", + Service: "redis", + Port: 0, + }, + } + for { + if ctx.Err() != nil { + return + } + var out struct{} + require.NoError(server.RPC("Catalog.Register", &req, &out)) + req.Service.Port++ + if req.Service.Port > 100 { + return + } + time.Sleep(1 * time.Millisecond) + } + }() + + // Now start a whole bunch of streamers in parallel to maximise chance of + // catching a race. + conn, err := client.GRPCConn() + require.NoError(err) + + streamClient := pbsubscribe.NewConsulClient(conn) + + n := 5 + var wg sync.WaitGroup + var updateCount uint64 + // Buffered error chan so that workers can exit and terminate wg without + // blocking on send. We collect errors this way since t isn't thread safe. + errCh := make(chan error, n) + for i := 0; i < n; i++ { + wg.Add(1) + go verifyMonotonicStreamUpdates(ctx, t, streamClient, &wg, i, &updateCount, errCh) + } + + // Wait until all subscribers have verified the first bunch of updates all got + // delivered. + wg.Wait() + + close(errCh) + + // Require that none of them errored. Since we closed the chan above this loop + // should terminate immediately if no errors were buffered. + for err := range errCh { + require.NoError(err) + } + + // Sanity check that at least some non-snapshot messages were delivered. We + // can't know exactly how many because it's timing dependent based on when + // each subscribers snapshot occurs. 
+ require.True(atomic.LoadUint64(&updateCount) > 0, + "at least some of the subscribers should have received non-snapshot updates") +} + +type testLogger interface { + Logf(format string, args ...interface{}) +} + +func verifyMonotonicStreamUpdates(ctx context.Context, logger testLogger, client pbsubscribe.StateChangeSubscriptionClient, wg *sync.WaitGroup, i int, updateCount *uint64, errCh chan<- error) { + defer wg.Done() + streamHandle, err := client.Subscribe(ctx, &pbsubscribe.SubscribeRequest{Topic: pbsubscribe.Topic_ServiceHealth, Key: "redis"}) + if err != nil { + if strings.Contains(err.Error(), "context deadline exceeded") || + strings.Contains(err.Error(), "context canceled") { + logger.Logf("subscriber %05d: context cancelled before loop") + return + } + errCh <- err + return + } + + snapshotDone := false + expectPort := 0 + for { + event, err := streamHandle.Recv() + if err == io.EOF { + break + } + if err != nil { + if strings.Contains(err.Error(), "context deadline exceeded") || + strings.Contains(err.Error(), "context canceled") { + break + } + errCh <- err + return + } + + // Ignore snapshot message + if event.GetEndOfSnapshot() || event.GetResumeStream() { + snapshotDone = true + logger.Logf("subscriber %05d: snapshot done, expect next port to be %d", i, expectPort) + } else if snapshotDone { + // Verify we get all updates in order + svc, err := svcOrErr(event) + if err != nil { + errCh <- err + return + } + if expectPort != svc.Port { + errCh <- fmt.Errorf("subscriber %05d: missed %d update(s)!", i, svc.Port-expectPort) + return + } + atomic.AddUint64(updateCount, 1) + logger.Logf("subscriber %05d: got event with correct port=%d", i, expectPort) + expectPort++ + } else { + // This is a snapshot update. Check if it's an update for the canary + // instance that got applied before our snapshot was sent (likely) + svc, err := svcOrErr(event) + if err != nil { + errCh <- err + return + } + if svc.ID == "redis-canary" { + // Update the expected port we see in the next update to be one more + // than the port in the snapshot. + expectPort = svc.Port + 1 + logger.Logf("subscriber %05d: saw canary in snapshot with port %d", i, svc.Port) + } + } + if expectPort > 100 { + return + } + } +} + +func svcOrErr(event *pbsubscribe.Event) (*pbservice.NodeService, error) { + health := event.GetServiceHealth() + if health == nil { + return nil, fmt.Errorf("not a health event: %#v", event) + } + csn := health.CheckServiceNode + if csn == nil { + return nil, fmt.Errorf("nil CSN: %#v", event) + } + if csn.Service == nil { + return nil, fmt.Errorf("nil service: %#v", event) + } + return csn.Service, nil +} + +// requireEqualProtos is a helper that runs arrays or structures containing +// proto buf messages through JSON encoding before comparing/diffing them. This +// is necessary because require.Equal doesn't compare them equal and generates +// really unhelpful output in this case for some reason. +func requireEqualProtos(t *testing.T, want, got interface{}) { + t.Helper() + gotJSON, err := json.Marshal(got) + require.NoError(t, err) + expectJSON, err := json.Marshal(want) + require.NoError(t, err) + require.JSONEq(t, string(expectJSON), string(gotJSON)) +} +*/ From d0256a0c079225d01b64651425c784198bf679b1 Mon Sep 17 00:00:00 2001 From: Daniel Nephin Date: Tue, 8 Sep 2020 17:31:47 -0400 Subject: [PATCH 03/11] subscribe: add a stateless subscribe service for the gRPC server With a Backend that provides access to the necessary dependencies. 
--- agent/consul/options.go | 2 + agent/consul/server.go | 19 +- agent/consul/stream/subscription.go | 2 +- agent/consul/subscribe_backend.go | 42 ++++ agent/grpc/handler.go | 7 +- agent/grpc/server_test.go | 7 +- agent/grpc/stats_test.go | 5 +- agent/setup.go | 3 + agent/subscribe/auth.go | 40 ++++ agent/subscribe/subscribe.go | 294 +++++++++++++++++----------- 10 files changed, 291 insertions(+), 130 deletions(-) create mode 100644 agent/consul/subscribe_backend.go create mode 100644 agent/subscribe/auth.go diff --git a/agent/consul/options.go b/agent/consul/options.go index 242da3d35f..12507cb855 100644 --- a/agent/consul/options.go +++ b/agent/consul/options.go @@ -1,6 +1,7 @@ package consul import ( + "github.com/hashicorp/consul/agent/grpc" "github.com/hashicorp/consul/agent/pool" "github.com/hashicorp/consul/agent/router" "github.com/hashicorp/consul/agent/token" @@ -14,4 +15,5 @@ type Deps struct { Tokens *token.Store Router *router.Router ConnPool *pool.ConnPool + GRPCConnPool *grpc.ClientConnPool } diff --git a/agent/consul/server.go b/agent/consul/server.go index 57f799472c..a8bbfde217 100644 --- a/agent/consul/server.go +++ b/agent/consul/server.go @@ -26,14 +26,16 @@ import ( "github.com/hashicorp/consul/agent/consul/fsm" "github.com/hashicorp/consul/agent/consul/state" "github.com/hashicorp/consul/agent/consul/usagemetrics" - "github.com/hashicorp/consul/agent/grpc" + agentgrpc "github.com/hashicorp/consul/agent/grpc" "github.com/hashicorp/consul/agent/metadata" "github.com/hashicorp/consul/agent/pool" "github.com/hashicorp/consul/agent/router" "github.com/hashicorp/consul/agent/structs" + "github.com/hashicorp/consul/agent/subscribe" "github.com/hashicorp/consul/agent/token" "github.com/hashicorp/consul/lib" "github.com/hashicorp/consul/logging" + "github.com/hashicorp/consul/proto/pbsubscribe" "github.com/hashicorp/consul/tlsutil" "github.com/hashicorp/consul/types" connlimit "github.com/hashicorp/go-connlimit" @@ -44,6 +46,7 @@ import ( raftboltdb "github.com/hashicorp/raft-boltdb" "github.com/hashicorp/serf/serf" "golang.org/x/time/rate" + "google.golang.org/grpc" ) // These are the protocol versions that Consul can _understand_. These are @@ -577,7 +580,7 @@ func NewServer(config *Config, flat Deps) (*Server, error) { } go reporter.Run(&lib.StopChannelContext{StopCh: s.shutdownCh}) - s.grpcHandler = newGRPCHandlerFromConfig(logger, config) + s.grpcHandler = newGRPCHandlerFromConfig(flat, config, s) // Initialize Autopilot. This must happen before starting leadership monitoring // as establishing leadership could attempt to use autopilot and cause a panic. 
@@ -606,12 +609,18 @@ func NewServer(config *Config, flat Deps) (*Server, error) { return s, nil } -func newGRPCHandlerFromConfig(logger hclog.Logger, config *Config) connHandler { +func newGRPCHandlerFromConfig(deps Deps, config *Config, s *Server) connHandler { if !config.EnableGRPCServer { - return grpc.NoOpHandler{Logger: logger} + return agentgrpc.NoOpHandler{Logger: deps.Logger} } - return grpc.NewHandler(config.RPCAddr) + register := func(srv *grpc.Server) { + pbsubscribe.RegisterStateChangeSubscriptionServer(srv, &subscribe.Server{ + Backend: &subscribeBackend{srv: s, connPool: deps.GRPCConnPool}, + Logger: deps.Logger.Named("grpc-api.subscription"), + }) + } + return agentgrpc.NewHandler(config.RPCAddr, register) } func (s *Server) connectCARootsMonitor(ctx context.Context) { diff --git a/agent/consul/stream/subscription.go b/agent/consul/stream/subscription.go index 5b86d48133..aa71d3f612 100644 --- a/agent/consul/stream/subscription.go +++ b/agent/consul/stream/subscription.go @@ -19,7 +19,7 @@ const ( // ErrSubscriptionClosed is a error signalling the subscription has been // closed. The client should Unsubscribe, then re-Subscribe. -var ErrSubscriptionClosed = errors.New("subscription closed by server, client should resubscribe") +var ErrSubscriptionClosed = errors.New("subscription closed by server, client must reset state and resubscribe") // Subscription provides events on a Topic. Events may be filtered by Key. // Events are returned by Next(), and may start with a Snapshot of events. diff --git a/agent/consul/subscribe_backend.go b/agent/consul/subscribe_backend.go new file mode 100644 index 0000000000..e7c1ca90a8 --- /dev/null +++ b/agent/consul/subscribe_backend.go @@ -0,0 +1,42 @@ +package consul + +import ( + "github.com/hashicorp/consul/acl" + "github.com/hashicorp/consul/agent/consul/stream" + agentgrpc "github.com/hashicorp/consul/agent/grpc" + "github.com/hashicorp/consul/agent/subscribe" + "google.golang.org/grpc" +) + +type subscribeBackend struct { + srv *Server + connPool *agentgrpc.ClientConnPool +} + +// TODO: refactor Resolve methods to an ACLBackend that can be used by all +// the endpoints. +func (s subscribeBackend) ResolveToken(token string) (acl.Authorizer, error) { + return s.srv.ResolveToken(token) +} + +var _ subscribe.Backend = (*subscribeBackend)(nil) + +// Forward requests to a remote datacenter by calling f if the target dc does not +// match the config. Does nothing but return handled=false if dc is not specified, +// or if it matches the Datacenter in config. +// +// TODO: extract this so that it can be used with other grpc services. +func (s subscribeBackend) Forward(dc string, f func(*grpc.ClientConn) error) (handled bool, err error) { + if dc == "" || dc == s.srv.config.Datacenter { + return false, nil + } + conn, err := s.connPool.ClientConn(dc) + if err != nil { + return false, err + } + return true, f(conn) +} + +func (s subscribeBackend) Subscribe(req *stream.SubscribeRequest) (*stream.Subscription, error) { + return s.srv.fsm.State().EventPublisher().Subscribe(req) +} diff --git a/agent/grpc/handler.go b/agent/grpc/handler.go index c43c1ba1e2..d70fd2b10c 100644 --- a/agent/grpc/handler.go +++ b/agent/grpc/handler.go @@ -11,15 +11,16 @@ import ( ) // NewHandler returns a gRPC server that accepts connections from Handle(conn). -func NewHandler(addr net.Addr) *Handler { +// The register function will be called with the grpc.Server to register +// gRPC services with the server. 
+func NewHandler(addr net.Addr, register func(server *grpc.Server)) *Handler { // We don't need to pass tls.Config to the server since it's multiplexed // behind the RPC listener, which already has TLS configured. srv := grpc.NewServer( grpc.StatsHandler(newStatsHandler()), grpc.StreamInterceptor((&activeStreamCounter{}).Intercept), ) - - // TODO(streaming): add gRPC services to srv here + register(srv) lis := &chanListener{addr: addr, conns: make(chan net.Conn)} return &Handler{srv: srv, listener: lis} diff --git a/agent/grpc/server_test.go b/agent/grpc/server_test.go index b4cb9c7834..68417354bc 100644 --- a/agent/grpc/server_test.go +++ b/agent/grpc/server_test.go @@ -13,6 +13,7 @@ import ( "github.com/hashicorp/consul/agent/pool" "github.com/stretchr/testify/require" "golang.org/x/sync/errgroup" + "google.golang.org/grpc" ) type testServer struct { @@ -28,9 +29,9 @@ func (s testServer) Metadata() *metadata.Server { func newTestServer(t *testing.T, name string, dc string) testServer { addr := &net.IPAddr{IP: net.ParseIP("127.0.0.1")} - handler := NewHandler(addr) - - testservice.RegisterSimpleServer(handler.srv, &simple{name: name, dc: dc}) + handler := NewHandler(addr, func(server *grpc.Server) { + testservice.RegisterSimpleServer(server, &simple{name: name, dc: dc}) + }) lis, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) diff --git a/agent/grpc/stats_test.go b/agent/grpc/stats_test.go index 05a9f30365..e913181239 100644 --- a/agent/grpc/stats_test.go +++ b/agent/grpc/stats_test.go @@ -14,11 +14,14 @@ import ( "google.golang.org/grpc" ) +func noopRegister(*grpc.Server) {} + func TestHandler_EmitsStats(t *testing.T) { sink := patchGlobalMetrics(t) addr := &net.IPAddr{IP: net.ParseIP("127.0.0.1")} - handler := NewHandler(addr) + handler := NewHandler(addr, noopRegister) + testservice.RegisterSimpleServer(handler.srv, &simple{}) lis, err := net.Listen("tcp", "127.0.0.1:0") diff --git a/agent/setup.go b/agent/setup.go index 454bfa510d..213ef304ea 100644 --- a/agent/setup.go +++ b/agent/setup.go @@ -11,6 +11,7 @@ import ( "github.com/hashicorp/consul/agent/cache" "github.com/hashicorp/consul/agent/config" "github.com/hashicorp/consul/agent/consul" + "github.com/hashicorp/consul/agent/grpc" "github.com/hashicorp/consul/agent/grpc/resolver" "github.com/hashicorp/consul/agent/pool" "github.com/hashicorp/consul/agent/router" @@ -86,6 +87,8 @@ func NewBaseDeps(configLoader ConfigLoader, logOut io.Writer) (BaseDeps, error) // TODO(streaming): setConfig.Scheme name for tests builder := resolver.NewServerResolverBuilder(resolver.Config{}) resolver.RegisterWithGRPC(builder) + d.GRPCConnPool = grpc.NewClientConnPool(builder, grpc.TLSWrapper(d.TLSConfigurator.OutgoingRPCWrapper())) + d.Router = router.NewRouter(d.Logger, cfg.Datacenter, fmt.Sprintf("%s.%s", cfg.NodeName, cfg.Datacenter), builder) acConf := autoconf.Config{ diff --git a/agent/subscribe/auth.go b/agent/subscribe/auth.go new file mode 100644 index 0000000000..81ee3e934e --- /dev/null +++ b/agent/subscribe/auth.go @@ -0,0 +1,40 @@ +package subscribe + +import ( + "github.com/hashicorp/consul/acl" + "github.com/hashicorp/consul/agent/consul/state" + "github.com/hashicorp/consul/agent/consul/stream" +) + +// EnforceACL takes an acl.Authorizer and returns the decision for whether the +// event is allowed to be sent to this client or not. 
+func enforceACL(authz acl.Authorizer, e stream.Event) acl.EnforcementDecision { + switch { + case e.IsEndOfSnapshot(), e.IsEndOfEmptySnapshot(): + return acl.Allow + } + + switch p := e.Payload.(type) { + case state.EventPayloadCheckServiceNode: + csn := p.Value + if csn.Node == nil || csn.Service == nil || csn.Node.Node == "" || csn.Service.Service == "" { + return acl.Deny + } + + // TODO: what about acl.Default? + // TODO(streaming): we need the AuthorizerContext for ent + if dec := authz.NodeRead(csn.Node.Node, nil); dec != acl.Allow { + return acl.Deny + } + + // TODO(streaming): we need the AuthorizerContext for ent + // Enterprise support for streaming events - they don't have enough data to + // populate it yet. + if dec := authz.ServiceRead(csn.Service.Service, nil); dec != acl.Allow { + return acl.Deny + } + return acl.Allow + } + + return acl.Deny +} diff --git a/agent/subscribe/subscribe.go b/agent/subscribe/subscribe.go index d2d7083889..3fb2ae2672 100644 --- a/agent/subscribe/subscribe.go +++ b/agent/subscribe/subscribe.go @@ -1,45 +1,66 @@ package subscribe import ( - "github.com/hashicorp/consul/agent/consul/stream" - "github.com/hashicorp/consul/proto/pbevent" - "github.com/hashicorp/go-hclog" + "errors" + "fmt" + + "github.com/hashicorp/consul/agent/consul/state" + "github.com/hashicorp/consul/proto/pbservice" "github.com/hashicorp/go-uuid" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + "github.com/hashicorp/consul/acl" + "github.com/hashicorp/consul/agent/consul/stream" + "github.com/hashicorp/consul/proto/pbsubscribe" ) // Server implements a StateChangeSubscriptionServer for accepting SubscribeRequests, // and sending events to the subscription topic. type Server struct { - srv *Server - logger hclog.Logger + Backend Backend + Logger Logger } -var _ pbevent.StateChangeSubscriptionServer = (*Server)(nil) +type Logger interface { + IsTrace() bool + Trace(msg string, args ...interface{}) +} -func (h *Server) Subscribe(req *pbevent.SubscribeRequest, serverStream pbevent.StateChangeSubscription_SubscribeServer) error { +var _ pbsubscribe.StateChangeSubscriptionServer = (*Server)(nil) + +type Backend interface { + ResolveToken(token string) (acl.Authorizer, error) + Forward(dc string, f func(*grpc.ClientConn) error) (handled bool, err error) + Subscribe(req *stream.SubscribeRequest) (*stream.Subscription, error) +} + +func (h *Server) Subscribe(req *pbsubscribe.SubscribeRequest, serverStream pbsubscribe.StateChangeSubscription_SubscribeServer) error { // streamID is just used for message correlation in trace logs and not // populated normally. var streamID string - var err error - if h.logger.IsTrace() { + if h.Logger.IsTrace() { // TODO(banks) it might be nice one day to replace this with OpenTracing ID // if one is set etc. but probably pointless until we support that properly // in other places so it's actually propagated properly. For now this just // makes lifetime of a stream more traceable in our regular server logs for // debugging/dev. + var err error streamID, err = uuid.GenerateUUID() if err != nil { return err } } - // Forward the request to a remote DC if applicable. 
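// Illustrative wiring sketch (assumed names only): any hclog.Logger satisfies
// the Logger interface above, and the Backend seam is what lets this handler
// be registered without importing the consul server package, mirroring the
// registration done in newGRPCHandlerFromConfig:
//
//	register := func(srv *grpc.Server) {
//		pbsubscribe.RegisterStateChangeSubscriptionServer(srv, &Server{
//			Backend: backend, // e.g. the subscribeBackend added in agent/consul
//			Logger:  logger.Named("grpc-api.subscription"),
//		})
//	}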
- if req.Datacenter != "" && req.Datacenter != h.srv.config.Datacenter { - return h.forwardAndProxy(req, serverStream, streamID) + // TODO: add fields to logger and pass logger around instead of streamID + handled, err := h.Backend.Forward(req.Datacenter, h.forwardToDC(req, serverStream, streamID)) + if handled || err != nil { + return err } - h.srv.logger.Trace("new subscription", + h.Logger.Trace("new subscription", "topic", req.Topic.String(), "key", req.Key, "index", req.Index, @@ -47,141 +68,180 @@ func (h *Server) Subscribe(req *pbevent.SubscribeRequest, serverStream pbevent.S ) var sentCount uint64 - defer h.srv.logger.Trace("subscription closed", "stream_id", streamID) + defer h.Logger.Trace("subscription closed", "stream_id", streamID) // Resolve the token and create the ACL filter. // TODO: handle token expiry gracefully... - authz, err := h.srv.ResolveToken(req.Token) + authz, err := h.Backend.ResolveToken(req.Token) if err != nil { return err } - aclFilter := newACLFilter(authz, h.srv.logger, h.srv.config.ACLEnforceVersion8) - state := h.srv.fsm.State() - - // Register a subscription on this topic/key with the FSM. - sub, err := state.Subscribe(serverStream.Context(), req) + sub, err := h.Backend.Subscribe(toStreamSubscribeRequest(req)) if err != nil { return err } - defer state.Unsubscribe(req) + defer sub.Unsubscribe() - // Deliver the events + ctx := serverStream.Context() + snapshotDone := false for { - events, err := sub.Next() - if err == stream.ErrSubscriptionReload { - event := pbevent.Event{ - Payload: &pbevent.Event_ResetStream{ResetStream: true}, - } - if err := serverStream.Send(&event); err != nil { - return err - } - h.srv.logger.Trace("subscription reloaded", - "stream_id", streamID, - ) - return nil - } - if err != nil { + events, err := sub.Next(ctx) + switch { + // TODO: test case + case errors.Is(err, stream.ErrSubscriptionClosed): + h.Logger.Trace("subscription reset by server", "stream_id", streamID) + return status.Error(codes.Aborted, err.Error()) + case err != nil: return err } - aclFilter.filterStreamEvents(&events) + events = filterStreamEvents(authz, events) + if len(events) == 0 { + continue + } - snapshotDone := false - if len(events) == 1 { - if events[0].GetEndOfSnapshot() { - snapshotDone = true - h.srv.logger.Trace("snapshot complete", - "index", events[0].Index, - "sent", sentCount, - "stream_id", streamID, - ) - } else if events[0].GetResumeStream() { - snapshotDone = true - h.srv.logger.Trace("resuming stream", - "index", events[0].Index, - "sent", sentCount, - "stream_id", streamID, - ) - } else if snapshotDone { - // Count this event too in the normal case as "sent" the above cases - // only show the number of events sent _before_ the snapshot ended. 
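// Hedged client-side sketch (consumer names such as resetViewAndResubscribe
// are hypothetical): stream.ErrSubscriptionClosed is surfaced to the client as
// codes.Aborted, which per the ErrSubscriptionClosed message means "reset
// state and resubscribe" rather than retrying the same stream.
//
//	event, err := handle.Recv()
//	if status.Code(err) == codes.Aborted {
//		return resetViewAndResubscribe(ctx) // drop local state, Subscribe again from index 0
//	}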
- h.srv.logger.Trace("sending events", - "index", events[0].Index, - "sent", sentCount, - "batch_size", 1, - "stream_id", streamID, - ) - } - sentCount++ - if err := serverStream.Send(&events[0]); err != nil { - return err - } - } else if len(events) > 1 { - e := &pbevent.Event{ - Topic: req.Topic, - Key: req.Key, - Index: events[0].Index, - Payload: &pbevent.Event_EventBatch{ - EventBatch: &pbevent.EventBatch{ - Events: pbevent.EventBatchEventsFromEventSlice(events), - }, - }, - } - sentCount += uint64(len(events)) - h.srv.logger.Trace("sending events", - "index", events[0].Index, + first := events[0] + switch { + case first.IsEndOfSnapshot() || first.IsEndOfEmptySnapshot(): + snapshotDone = true + h.Logger.Trace("snapshot complete", + "index", first.Index, "sent", sentCount, "stream_id", streamID) + case snapshotDone: + h.Logger.Trace("sending events", + "index", first.Index, "sent", sentCount, "batch_size", len(events), "stream_id", streamID, ) - if err := serverStream.Send(e); err != nil { + } + + sentCount += uint64(len(events)) + e := newEventFromStreamEvents(req, events) + if err := serverStream.Send(e); err != nil { + return err + } + } +} + +// TODO: can be replaced by mog conversion +func toStreamSubscribeRequest(req *pbsubscribe.SubscribeRequest) *stream.SubscribeRequest { + return &stream.SubscribeRequest{ + Topic: req.Topic, + Key: req.Key, + Token: req.Token, + Index: req.Index, + } +} + +func (h *Server) forwardToDC( + req *pbsubscribe.SubscribeRequest, + serverStream pbsubscribe.StateChangeSubscription_SubscribeServer, + streamID string, +) func(conn *grpc.ClientConn) error { + return func(conn *grpc.ClientConn) error { + h.Logger.Trace("forwarding to another DC", + "dc", req.Datacenter, + "topic", req.Topic.String(), + "key", req.Key, + "index", req.Index, + "stream_id", streamID, + ) + + defer func() { + h.Logger.Trace("forwarded stream closed", + "dc", req.Datacenter, + "stream_id", streamID, + ) + }() + + client := pbsubscribe.NewStateChangeSubscriptionClient(conn) + streamHandle, err := client.Subscribe(serverStream.Context(), req) + if err != nil { + return err + } + + for { + event, err := streamHandle.Recv() + if err != nil { + return err + } + if err := serverStream.Send(event); err != nil { return err } } } } -func (h *Server) forwardAndProxy( - req *pbevent.SubscribeRequest, - serverStream pbevent.StateChangeSubscription_SubscribeServer, - streamID string) error { - - conn, err := h.srv.grpcClient.GRPCConn(req.Datacenter) - if err != nil { - return err +// filterStreamEvents to only those allowed by the acl token. +func filterStreamEvents(authz acl.Authorizer, events []stream.Event) []stream.Event { + // TODO: when is authz nil? + if authz == nil || len(events) == 0 { + return events } - h.logger.Trace("forwarding to another DC", - "dc", req.Datacenter, - "topic", req.Topic.String(), - "key", req.Key, - "index", req.Index, - "stream_id", streamID, - ) - - defer func() { - h.logger.Trace("forwarded stream closed", - "dc", req.Datacenter, - "stream_id", streamID, - ) - }() - - // Open a Subscribe call to the remote DC. - client := pbevent.NewConsulClient(conn) - streamHandle, err := client.Subscribe(serverStream.Context(), req) - if err != nil { - return err + // Fast path for the common case of only 1 event since we can avoid slice + // allocation in the hot path of every single update event delivered in vast + // majority of cases with this. 
Note that this is called _per event/item_ when + // sending snapshots which is a lot worse than being called once on regular + // result. + if len(events) == 1 { + if enforceACL(authz, events[0]) == acl.Allow { + return events + } + return nil } - // Relay the events back to the client. - for { - event, err := streamHandle.Recv() - if err != nil { - return err + var filtered []stream.Event + for idx := range events { + event := events[idx] + if enforceACL(authz, event) == acl.Allow { + filtered = append(filtered, event) } - if err := serverStream.Send(event); err != nil { - return err + } + return filtered +} + +func newEventFromStreamEvents(req *pbsubscribe.SubscribeRequest, events []stream.Event) *pbsubscribe.Event { + e := &pbsubscribe.Event{ + Topic: req.Topic, + Key: req.Key, + Index: events[0].Index, + } + if len(events) == 1 { + setPayload(e, events[0].Payload) + return e + } + + e.Payload = &pbsubscribe.Event_EventBatch{ + EventBatch: &pbsubscribe.EventBatch{ + Events: batchEventsFromEventSlice(events), + }, + } + return e +} + +func setPayload(e *pbsubscribe.Event, payload interface{}) { + switch p := payload.(type) { + case state.EventPayloadCheckServiceNode: + e.Payload = &pbsubscribe.Event_ServiceHealth{ + ServiceHealth: &pbsubscribe.ServiceHealthUpdate{ + Op: p.Op, + // TODO: this could be cached + CheckServiceNode: pbservice.NewCheckServiceNodeFromStructs(p.Value), + }, } + default: + panic(fmt.Sprintf("unexpected payload: %T: %#v", p, p)) } } + +func batchEventsFromEventSlice(events []stream.Event) []*pbsubscribe.Event { + result := make([]*pbsubscribe.Event, len(events)) + for i := range events { + event := events[i] + result[i] = &pbsubscribe.Event{Key: event.Key, Index: event.Index} + setPayload(result[i], event.Payload) + } + return result +} From 013ababda497cb58ab84373f78ebae040cea8d72 Mon Sep 17 00:00:00 2001 From: Daniel Nephin Date: Fri, 25 Sep 2020 19:40:10 -0400 Subject: [PATCH 04/11] subscribe: add first integration test for Server --- agent/subscribe/subscribe.go | 19 +- agent/subscribe/subscribe_test.go | 437 ++++++++++++++++++------------ 2 files changed, 275 insertions(+), 181 deletions(-) diff --git a/agent/subscribe/subscribe.go b/agent/subscribe/subscribe.go index 3fb2ae2672..7908a410b3 100644 --- a/agent/subscribe/subscribe.go +++ b/agent/subscribe/subscribe.go @@ -4,15 +4,15 @@ import ( "errors" "fmt" - "github.com/hashicorp/consul/agent/consul/state" - "github.com/hashicorp/consul/proto/pbservice" "github.com/hashicorp/go-uuid" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "github.com/hashicorp/consul/acl" + "github.com/hashicorp/consul/agent/consul/state" "github.com/hashicorp/consul/agent/consul/stream" + "github.com/hashicorp/consul/proto/pbservice" "github.com/hashicorp/consul/proto/pbsubscribe" ) @@ -107,6 +107,7 @@ func (h *Server) Subscribe(req *pbsubscribe.SubscribeRequest, serverStream pbsub snapshotDone = true h.Logger.Trace("snapshot complete", "index", first.Index, "sent", sentCount, "stream_id", streamID) + case snapshotDone: h.Logger.Trace("sending events", "index", first.Index, @@ -208,8 +209,20 @@ func newEventFromStreamEvents(req *pbsubscribe.SubscribeRequest, events []stream Key: req.Key, Index: events[0].Index, } + if len(events) == 1 { - setPayload(e, events[0].Payload) + event := events[0] + // TODO: refactor so these are only checked once, instead of 3 times. 
+ switch { + case event.IsEndOfSnapshot(): + e.Payload = &pbsubscribe.Event_EndOfSnapshot{EndOfSnapshot: true} + return e + case event.IsEndOfEmptySnapshot(): + e.Payload = &pbsubscribe.Event_EndOfEmptySnapshot{EndOfEmptySnapshot: true} + return e + } + + setPayload(e, event.Payload) return e } diff --git a/agent/subscribe/subscribe_test.go b/agent/subscribe/subscribe_test.go index 3c490bd323..bac2ff2303 100644 --- a/agent/subscribe/subscribe_test.go +++ b/agent/subscribe/subscribe_test.go @@ -1,35 +1,42 @@ package subscribe -/* TODO -func TestStreaming_Subscribe(t *testing.T) { - t.Parallel() +import ( + "context" + "errors" + "io" + "net" + "testing" + "time" - require := require.New(t) - dir1, server := testServerWithConfig(t, func(c *Config) { - c.Datacenter = "dc1" - c.Bootstrap = true - c.GRPCEnabled = true - }) - defer os.RemoveAll(dir1) - defer server.Shutdown() - codec := rpcClient(t, server) - defer codec.Close() + "github.com/google/go-cmp/cmp" + "github.com/hashicorp/go-hclog" + "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" + gogrpc "google.golang.org/grpc" - dir2, client := testClientWithConfig(t, func(c *Config) { - c.Datacenter = "dc1" - c.NodeName = uniqueNodeName(t.Name()) - c.GRPCEnabled = true - }) - defer os.RemoveAll(dir2) - defer client.Shutdown() + "github.com/hashicorp/consul/acl" + "github.com/hashicorp/consul/agent/consul/state" + "github.com/hashicorp/consul/agent/consul/stream" + "github.com/hashicorp/consul/agent/grpc" + "github.com/hashicorp/consul/agent/structs" + "github.com/hashicorp/consul/proto/pbcommon" + "github.com/hashicorp/consul/proto/pbservice" + "github.com/hashicorp/consul/proto/pbsubscribe" + "github.com/hashicorp/consul/types" +) - // Try to join - testrpc.WaitForLeader(t, server.RPC, "dc1") - joinLAN(t, client, server) - testrpc.WaitForTestAgent(t, client.RPC, "dc1") +func TestServer_Subscribe_IntegrationWithBackend(t *testing.T) { + if testing.Short() { + t.Skip("too slow for -short run") + } + + backend, err := newTestBackend() + require.NoError(t, err) + srv := &Server{Backend: backend, Logger: hclog.New(nil)} + + addr := newTestServer(t, srv) + ids := newCounter() - // Register a dummy node with a service we don't care about, to make sure - // we don't see updates for it. { req := &structs.RegisterRequest{ Node: "other", @@ -42,11 +49,8 @@ func TestStreaming_Subscribe(t *testing.T) { Port: 9000, }, } - var out struct{} - require.NoError(msgpackrpc.CallWithCodec(codec, "Catalog.Register", &req, &out)) + require.NoError(t, backend.store.EnsureRegistration(ids.Next("other"), req)) } - - // Register a dummy node with our service on it. { req := &structs.RegisterRequest{ Node: "node1", @@ -59,11 +63,8 @@ func TestStreaming_Subscribe(t *testing.T) { Port: 8080, }, } - var out struct{} - require.NoError(msgpackrpc.CallWithCodec(codec, "Catalog.Register", &req, &out)) + require.NoError(t, backend.store.EnsureRegistration(ids.Next("reg2"), req)) } - - // Register a test node to be updated later. req := &structs.RegisterRequest{ Node: "node2", Address: "1.2.3.4", @@ -75,62 +76,58 @@ func TestStreaming_Subscribe(t *testing.T) { Port: 8080, }, } - var out struct{} - require.NoError(msgpackrpc.CallWithCodec(codec, "Catalog.Register", &req, &out)) + require.NoError(t, backend.store.EnsureRegistration(ids.Next("reg3"), req)) - // Start a Subscribe call to our streaming endpoint. 
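// Note on the ids counter used in the registrations above (helper defined
// later in this file): each EnsureRegistration call is given the next raft
// index and that index is recorded under a label, so the expected events can
// be built with raftIndex(ids, "reg2", "reg2") instead of hard-coded index
// literals that would drift as the test setup changes.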
- conn, err := client.GRPCConn() - require.NoError(err) - - streamClient := pbsubscribe.NewConsulClient(conn) ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() + conn, err := gogrpc.DialContext(ctx, addr.String(), gogrpc.WithInsecure()) + require.NoError(t, err) + t.Cleanup(logError(t, conn.Close)) + + streamClient := pbsubscribe.NewStateChangeSubscriptionClient(conn) streamHandle, err := streamClient.Subscribe(ctx, &pbsubscribe.SubscribeRequest{ Topic: pbsubscribe.Topic_ServiceHealth, Key: "redis", }) - require.NoError(err) + require.NoError(t, err) - // Start a goroutine to read updates off the pbsubscribe. - eventCh := make(chan *pbsubscribe.Event, 0) - go testSendEvents(t, eventCh, streamHandle) + chEvents := make(chan eventOrError, 0) + go recvEvents(chEvents, streamHandle) var snapshotEvents []*pbsubscribe.Event for i := 0; i < 3; i++ { - select { - case event := <-eventCh: - snapshotEvents = append(snapshotEvents, event) - case <-time.After(3 * time.Second): - t.Fatalf("did not receive events past %d", len(snapshotEvents)) - } + snapshotEvents = append(snapshotEvents, getEvent(t, chEvents)) } expected := []*pbsubscribe.Event{ { Topic: pbsubscribe.Topic_ServiceHealth, Key: "redis", + Index: ids.Last(), Payload: &pbsubscribe.Event_ServiceHealth{ ServiceHealth: &pbsubscribe.ServiceHealthUpdate{ Op: pbsubscribe.CatalogOp_Register, - CheckServiceNode: &pbsubscribe.CheckServiceNode{ - Node: &pbsubscribe.Node{ + CheckServiceNode: &pbservice.CheckServiceNode{ + Node: &pbservice.Node{ Node: "node1", Datacenter: "dc1", Address: "3.4.5.6", + RaftIndex: raftIndex(ids, "reg2", "reg2"), }, - Service: &pbsubscribe.NodeService{ + Service: &pbservice.NodeService{ ID: "redis1", Service: "redis", Address: "3.4.5.6", Port: 8080, - Weights: &pbsubscribe.Weights{Passing: 1, Warning: 1}, + Weights: &pbservice.Weights{Passing: 1, Warning: 1}, // Sad empty state - Proxy: pbsubscribe.ConnectProxyConfig{ - MeshGateway: &pbsubscribe.MeshGatewayConfig{}, - Expose: &pbsubscribe.ExposeConfig{}, + Proxy: pbservice.ConnectProxyConfig{ + MeshGateway: pbservice.MeshGatewayConfig{}, + Expose: pbservice.ExposeConfig{}, }, - EnterpriseMeta: &pbsubscribe.EnterpriseMeta{}, + RaftIndex: raftIndex(ids, "reg2", "reg2"), + EnterpriseMeta: pbcommon.EnterpriseMeta{}, }, }, }, @@ -139,27 +136,30 @@ func TestStreaming_Subscribe(t *testing.T) { { Topic: pbsubscribe.Topic_ServiceHealth, Key: "redis", + Index: ids.Last(), Payload: &pbsubscribe.Event_ServiceHealth{ ServiceHealth: &pbsubscribe.ServiceHealthUpdate{ Op: pbsubscribe.CatalogOp_Register, - CheckServiceNode: &pbsubscribe.CheckServiceNode{ - Node: &pbsubscribe.Node{ + CheckServiceNode: &pbservice.CheckServiceNode{ + Node: &pbservice.Node{ Node: "node2", Datacenter: "dc1", Address: "1.2.3.4", + RaftIndex: raftIndex(ids, "reg3", "reg3"), }, - Service: &pbsubscribe.NodeService{ + Service: &pbservice.NodeService{ ID: "redis1", Service: "redis", Address: "1.1.1.1", Port: 8080, - Weights: &pbsubscribe.Weights{Passing: 1, Warning: 1}, + Weights: &pbservice.Weights{Passing: 1, Warning: 1}, // Sad empty state - Proxy: pbsubscribe.ConnectProxyConfig{ - MeshGateway: &pbsubscribe.MeshGatewayConfig{}, - Expose: &pbsubscribe.ExposeConfig{}, + Proxy: pbservice.ConnectProxyConfig{ + MeshGateway: pbservice.MeshGatewayConfig{}, + Expose: pbservice.ExposeConfig{}, }, - EnterpriseMeta: &pbsubscribe.EnterpriseMeta{}, + RaftIndex: raftIndex(ids, "reg3", "reg3"), + EnterpriseMeta: pbcommon.EnterpriseMeta{}, }, }, }, @@ -168,22 +168,11 @@ func 
TestStreaming_Subscribe(t *testing.T) { { Topic: pbsubscribe.Topic_ServiceHealth, Key: "redis", + Index: ids.Last(), Payload: &pbsubscribe.Event_EndOfSnapshot{EndOfSnapshot: true}, }, } - - require.Len(snapshotEvents, 3) - for i := 0; i < 2; i++ { - // Fix up the index - expected[i].Index = snapshotEvents[i].Index - node := expected[i].GetServiceHealth().CheckServiceNode - node.Node.RaftIndex = snapshotEvents[i].GetServiceHealth().CheckServiceNode.Node.RaftIndex - node.Service.RaftIndex = snapshotEvents[i].GetServiceHealth().CheckServiceNode.Service.RaftIndex - } - // Fix index on snapshot event - expected[2].Index = snapshotEvents[2].Index - - requireEqualProtos(t, expected, snapshotEvents) + assertDeepEqual(t, expected, snapshotEvents) // Update the registration by adding a check. req.Check = &structs.HealthCheck{ @@ -193,73 +182,189 @@ func TestStreaming_Subscribe(t *testing.T) { ServiceName: "redis", Name: "check 1", } - require.NoError(msgpackrpc.CallWithCodec(codec, "Catalog.Register", &req, &out)) + require.NoError(t, backend.store.EnsureRegistration(ids.Next("update"), req)) - // Make sure we get the event for the diff. - select { - case event := <-eventCh: - expected := &pbsubscribe.Event{ - Topic: pbsubscribe.Topic_ServiceHealth, - Key: "redis", - Payload: &pbsubscribe.Event_ServiceHealth{ - ServiceHealth: &pbsubscribe.ServiceHealthUpdate{ - Op: pbsubscribe.CatalogOp_Register, - CheckServiceNode: &pbsubscribe.CheckServiceNode{ - Node: &pbsubscribe.Node{ - Node: "node2", - Datacenter: "dc1", - Address: "1.2.3.4", - RaftIndex: pbsubscribe.RaftIndex{CreateIndex: 13, ModifyIndex: 13}, + event := getEvent(t, chEvents) + expectedEvent := &pbsubscribe.Event{ + Topic: pbsubscribe.Topic_ServiceHealth, + Key: "redis", + Index: ids.Last(), + Payload: &pbsubscribe.Event_ServiceHealth{ + ServiceHealth: &pbsubscribe.ServiceHealthUpdate{ + Op: pbsubscribe.CatalogOp_Register, + CheckServiceNode: &pbservice.CheckServiceNode{ + Node: &pbservice.Node{ + Node: "node2", + Datacenter: "dc1", + Address: "1.2.3.4", + RaftIndex: raftIndex(ids, "reg3", "reg3"), + }, + Service: &pbservice.NodeService{ + ID: "redis1", + Service: "redis", + Address: "1.1.1.1", + Port: 8080, + Weights: &pbservice.Weights{Passing: 1, Warning: 1}, + // Sad empty state + Proxy: pbservice.ConnectProxyConfig{ + MeshGateway: pbservice.MeshGatewayConfig{}, + Expose: pbservice.ExposeConfig{}, }, - Service: &pbsubscribe.NodeService{ - ID: "redis1", - Service: "redis", - Address: "1.1.1.1", - Port: 8080, - RaftIndex: pbsubscribe.RaftIndex{CreateIndex: 13, ModifyIndex: 13}, - Weights: &pbsubscribe.Weights{Passing: 1, Warning: 1}, - // Sad empty state - Proxy: pbsubscribe.ConnectProxyConfig{ - MeshGateway: &pbsubscribe.MeshGatewayConfig{}, - Expose: &pbsubscribe.ExposeConfig{}, - }, - EnterpriseMeta: &pbsubscribe.EnterpriseMeta{}, - }, - Checks: []*pbsubscribe.HealthCheck{ - { - CheckID: "check1", - Name: "check 1", - Node: "node2", - Status: "critical", - ServiceID: "redis1", - ServiceName: "redis", - RaftIndex: pbsubscribe.RaftIndex{CreateIndex: 14, ModifyIndex: 14}, - EnterpriseMeta: &pbsubscribe.EnterpriseMeta{}, - }, + RaftIndex: raftIndex(ids, "reg3", "reg3"), + EnterpriseMeta: pbcommon.EnterpriseMeta{}, + }, + Checks: []*pbservice.HealthCheck{ + { + CheckID: "check1", + Name: "check 1", + Node: "node2", + Status: "critical", + ServiceID: "redis1", + ServiceName: "redis", + RaftIndex: raftIndex(ids, "update", "update"), + EnterpriseMeta: pbcommon.EnterpriseMeta{}, }, }, }, }, - } - // Fix up the index - expected.Index = event.Index 
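// Hedged sketch of how a subscriber might apply the events asserted above
// (the view map is hypothetical, not part of this change):
//
//	switch p := event.Payload.(type) {
//	case *pbsubscribe.Event_ServiceHealth:
//		csn := p.ServiceHealth.CheckServiceNode
//		view[csn.Node.Node+"/"+csn.Service.ID] = csn // Op is Register in these cases
//	case *pbsubscribe.Event_EndOfSnapshot:
//		// snapshot complete; later events are incremental updates
//	}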
- node := expected.GetServiceHealth().CheckServiceNode - node.Node.RaftIndex = event.GetServiceHealth().CheckServiceNode.Node.RaftIndex - node.Service.RaftIndex = event.GetServiceHealth().CheckServiceNode.Service.RaftIndex - node.Checks[0].RaftIndex = event.GetServiceHealth().CheckServiceNode.Checks[0].RaftIndex - requireEqualProtos(t, expected, event) - case <-time.After(3 * time.Second): - t.Fatal("never got event") + }, } + assertDeepEqual(t, expectedEvent, event) +} - // Wait and make sure there aren't any more events coming. - select { - case event := <-eventCh: - t.Fatalf("got another event: %v", event) - case <-time.After(500 * time.Millisecond): +type eventOrError struct { + event *pbsubscribe.Event + err error +} + +// recvEvents from handle and sends them to the provided channel. +func recvEvents(ch chan eventOrError, handle pbsubscribe.StateChangeSubscription_SubscribeClient) { + defer close(ch) + for { + event, err := handle.Recv() + switch { + case errors.Is(err, io.EOF): + return + case errors.Is(err, context.Canceled), errors.Is(err, context.DeadlineExceeded): + return + case err != nil: + ch <- eventOrError{err: err} + return + default: + ch <- eventOrError{event: event} + } } } +func getEvent(t *testing.T, ch chan eventOrError) *pbsubscribe.Event { + t.Helper() + select { + case item := <-ch: + require.NoError(t, item.err) + return item.event + case <-time.After(10 * time.Second): + t.Fatalf("timeout waiting on event from server") + } + return nil +} + +func assertDeepEqual(t *testing.T, x, y interface{}) { + t.Helper() + if diff := cmp.Diff(x, y); diff != "" { + t.Fatalf("assertion failed: values are not equal\n--- expected\n+++ actual\n%v", diff) + } +} + +type testBackend struct { + store *state.Store + authorizer acl.Authorizer +} + +func (b testBackend) ResolveToken(_ string) (acl.Authorizer, error) { + return b.authorizer, nil +} + +func (b testBackend) Forward(_ string, _ func(*gogrpc.ClientConn) error) (handled bool, err error) { + return false, nil +} + +func (b testBackend) Subscribe(req *stream.SubscribeRequest) (*stream.Subscription, error) { + return b.store.EventPublisher().Subscribe(req) +} + +func newTestBackend() (*testBackend, error) { + gc, err := state.NewTombstoneGC(time.Second, time.Millisecond) + if err != nil { + return nil, err + } + store, err := state.NewStateStore(gc) + if err != nil { + return nil, err + } + return &testBackend{store: store, authorizer: acl.AllowAll()}, nil +} + +var _ Backend = (*testBackend)(nil) + +func newTestServer(t *testing.T, server *Server) net.Addr { + addr := &net.IPAddr{IP: net.ParseIP("127.0.0.1")} + var grpcServer *gogrpc.Server + handler := grpc.NewHandler(addr, func(srv *gogrpc.Server) { + grpcServer = srv + pbsubscribe.RegisterStateChangeSubscriptionServer(srv, server) + }) + + lis, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + t.Cleanup(logError(t, lis.Close)) + + go grpcServer.Serve(lis) + g := new(errgroup.Group) + g.Go(func() error { + return grpcServer.Serve(lis) + }) + t.Cleanup(func() { + if err := handler.Shutdown(); err != nil { + t.Logf("grpc server shutdown: %v", err) + } + if err := g.Wait(); err != nil { + t.Logf("grpc server error: %v", err) + } + }) + return lis.Addr() +} + +type counter struct { + value uint64 + labels map[string]uint64 +} + +func (c *counter) Next(label string) uint64 { + c.value++ + c.labels[label] = c.value + return c.value +} + +func (c *counter) For(label string) uint64 { + return c.labels[label] +} + +func (c *counter) Last() uint64 { + return c.value 
+} + +func newCounter() *counter { + return &counter{labels: make(map[string]uint64)} +} + +func raftIndex(ids *counter, created, modified string) pbcommon.RaftIndex { + return pbcommon.RaftIndex{ + CreateIndex: ids.For(created), + ModifyIndex: ids.For(modified), + } +} + +/* TODO func TestStreaming_Subscribe_MultiDC(t *testing.T) { t.Parallel() @@ -365,7 +470,7 @@ func TestStreaming_Subscribe_MultiDC(t *testing.T) { // Start a goroutine to read updates off the pbsubscribe. eventCh := make(chan *pbsubscribe.Event, 0) - go testSendEvents(t, eventCh, streamHandle) + go recvEvents(t, eventCh, streamHandle) var snapshotEvents []*pbsubscribe.Event for i := 0; i < 3; i++ { @@ -452,7 +557,7 @@ func TestStreaming_Subscribe_MultiDC(t *testing.T) { node.Service.RaftIndex = snapshotEvents[i].GetServiceHealth().CheckServiceNode.Service.RaftIndex } expected[2].Index = snapshotEvents[2].Index - requireEqualProtos(t, expected, snapshotEvents) + assertDeepEqual(t, expected, snapshotEvents) // Update the registration by adding a check. req.Check = &structs.HealthCheck{ @@ -516,7 +621,7 @@ func TestStreaming_Subscribe_MultiDC(t *testing.T) { node.Node.RaftIndex = event.GetServiceHealth().CheckServiceNode.Node.RaftIndex node.Service.RaftIndex = event.GetServiceHealth().CheckServiceNode.Service.RaftIndex node.Checks[0].RaftIndex = event.GetServiceHealth().CheckServiceNode.Checks[0].RaftIndex - requireEqualProtos(t, expected, event) + assertDeepEqual(t, expected, event) case <-time.After(3 * time.Second): t.Fatal("never got event") } @@ -588,7 +693,7 @@ func TestStreaming_Subscribe_SkipSnapshot(t *testing.T) { // Start a goroutine to read updates off the pbsubscribe. eventCh := make(chan *pbsubscribe.Event, 0) - go testSendEvents(t, eventCh, streamHandle) + go recvEvents(t, eventCh, streamHandle) var snapshotEvents []*pbsubscribe.Event for i := 0; i < 2; i++ { @@ -617,7 +722,7 @@ func TestStreaming_Subscribe_SkipSnapshot(t *testing.T) { // Start a goroutine to read updates off the pbsubscribe. eventCh := make(chan *pbsubscribe.Event, 0) - go testSendEvents(t, eventCh, streamHandle) + go recvEvents(t, eventCh, streamHandle) // We should get no snapshot and the first event should be "resume stream" select { @@ -737,7 +842,7 @@ func TestStreaming_Subscribe_FilterACL(t *testing.T) { // Start a goroutine to read updates off the pbsubscribe. eventCh := make(chan *pbsubscribe.Event, 0) - go testSendEvents(t, eventCh, streamHandle) + go recvEvents(t, eventCh, streamHandle) // Read events off the pbsubscribe. We should not see any events for the filtered node. var snapshotEvents []*pbsubscribe.Event @@ -823,7 +928,7 @@ func TestStreaming_Subscribe_FilterACL(t *testing.T) { // Start a goroutine to read updates off the pbsubscribe. eventCh := make(chan *pbsubscribe.Event, 0) - go testSendEvents(t, eventCh, streamHandle) + go recvEvents(t, eventCh, streamHandle) select { case event := <-eventCh: @@ -939,7 +1044,7 @@ func TestStreaming_Subscribe_ACLUpdate(t *testing.T) { // Start a goroutine to read updates off the pbsubscribe. eventCh := make(chan *pbsubscribe.Event, 0) - go testSendEvents(t, eventCh, streamHandle) + go recvEvents(t, eventCh, streamHandle) // Read events off the pbsubscribe. var snapshotEvents []*pbsubscribe.Event @@ -998,25 +1103,6 @@ func TestStreaming_Subscribe_ACLUpdate(t *testing.T) { } } -// testSendEvents receives pbsubscribe.Events from a given handle and sends them to the provided -// channel. This is meant to be run in a separate goroutine from the main test. 
-func testSendEvents(t *testing.T, ch chan *pbsubscribe.Event, handle pbsubscribe.StateChangeSubscription_SubscribeClient) { - for { - event, err := handle.Recv() - if err == io.EOF { - break - } - if err != nil { - if strings.Contains(err.Error(), "context deadline exceeded") || - strings.Contains(err.Error(), "context canceled") { - break - } - t.Log(err) - } - ch <- event - } -} - func TestStreaming_TLSEnabled(t *testing.T) { t.Parallel() @@ -1079,7 +1165,7 @@ func TestStreaming_TLSEnabled(t *testing.T) { // Start a goroutine to read updates off the pbsubscribe. eventCh := make(chan *pbsubscribe.Event, 0) - go testSendEvents(t, eventCh, streamHandle) + go recvEvents(t, eventCh, streamHandle) var snapshotEvents []*pbsubscribe.Event for i := 0; i < 2; i++ { @@ -1110,7 +1196,7 @@ func TestStreaming_TLSEnabled(t *testing.T) { // Start a goroutine to read updates off the pbsubscribe. eventCh := make(chan *pbsubscribe.Event, 0) - go testSendEvents(t, eventCh, streamHandle) + go recvEvents(t, eventCh, streamHandle) var snapshotEvents []*pbsubscribe.Event for i := 0; i < 2; i++ { @@ -1418,17 +1504,12 @@ func svcOrErr(event *pbsubscribe.Event) (*pbservice.NodeService, error) { } return csn.Service, nil } - -// requireEqualProtos is a helper that runs arrays or structures containing -// proto buf messages through JSON encoding before comparing/diffing them. This -// is necessary because require.Equal doesn't compare them equal and generates -// really unhelpful output in this case for some reason. -func requireEqualProtos(t *testing.T, want, got interface{}) { - t.Helper() - gotJSON, err := json.Marshal(got) - require.NoError(t, err) - expectJSON, err := json.Marshal(want) - require.NoError(t, err) - require.JSONEq(t, string(expectJSON), string(gotJSON)) -} */ + +func logError(t *testing.T, f func() error) func() { + return func() { + if err := f(); err != nil { + t.Logf(err.Error()) + } + } +} From 083f4e8f5714c812545b617ed221f241f48c75dd Mon Sep 17 00:00:00 2001 From: Daniel Nephin Date: Mon, 28 Sep 2020 15:43:29 -0400 Subject: [PATCH 05/11] subscribe: Add an integration test for forward to DC --- agent/subscribe/subscribe_test.go | 256 ++++++++++++------------------ 1 file changed, 100 insertions(+), 156 deletions(-) diff --git a/agent/subscribe/subscribe_test.go b/agent/subscribe/subscribe_test.go index bac2ff2303..4b660c4022 100644 --- a/agent/subscribe/subscribe_test.go +++ b/agent/subscribe/subscribe_test.go @@ -26,10 +26,6 @@ import ( ) func TestServer_Subscribe_IntegrationWithBackend(t *testing.T) { - if testing.Short() { - t.Skip("too slow for -short run") - } - backend, err := newTestBackend() require.NoError(t, err) srv := &Server{Backend: backend, Logger: hclog.New(nil)} @@ -78,8 +74,8 @@ func TestServer_Subscribe_IntegrationWithBackend(t *testing.T) { } require.NoError(t, backend.store.EnsureRegistration(ids.Next("reg3"), req)) - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - defer cancel() + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + t.Cleanup(cancel) conn, err := gogrpc.DialContext(ctx, addr.String(), gogrpc.WithInsecure()) require.NoError(t, err) @@ -276,15 +272,19 @@ func assertDeepEqual(t *testing.T, x, y interface{}) { } type testBackend struct { - store *state.Store - authorizer acl.Authorizer + store *state.Store + authorizer acl.Authorizer + forwardConn *gogrpc.ClientConn } func (b testBackend) ResolveToken(_ string) (acl.Authorizer, error) { return b.authorizer, nil } -func (b testBackend) Forward(_ string, _ 
func(*gogrpc.ClientConn) error) (handled bool, err error) { +func (b testBackend) Forward(_ string, fn func(*gogrpc.ClientConn) error) (handled bool, err error) { + if b.forwardConn != nil { + return true, fn(b.forwardConn) + } return false, nil } @@ -364,47 +364,25 @@ func raftIndex(ids *counter, created, modified string) pbcommon.RaftIndex { } } -/* TODO -func TestStreaming_Subscribe_MultiDC(t *testing.T) { - t.Parallel() +func TestServer_Subscribe_IntegrationWithBackend_ForwardToDC(t *testing.T) { + backendLocal, err := newTestBackend() + require.NoError(t, err) + addrLocal := newTestServer(t, &Server{Backend: backendLocal, Logger: hclog.New(nil)}) - require := require.New(t) - dir1, server1 := testServerWithConfig(t, func(c *Config) { - c.Datacenter = "dc1" - c.Bootstrap = true - c.GRPCEnabled = true - }) - defer os.RemoveAll(dir1) - defer server1.Shutdown() + backendRemoteDC, err := newTestBackend() + require.NoError(t, err) + srvRemoteDC := &Server{Backend: backendRemoteDC, Logger: hclog.New(nil)} + addrRemoteDC := newTestServer(t, srvRemoteDC) - dir2, server2 := testServerWithConfig(t, func(c *Config) { - c.Datacenter = "dc2" - c.Bootstrap = true - c.GRPCEnabled = true - }) - defer os.RemoveAll(dir2) - defer server2.Shutdown() - codec := rpcClient(t, server2) - defer codec.Close() + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + t.Cleanup(cancel) - dir3, client := testClientWithConfig(t, func(c *Config) { - c.Datacenter = "dc1" - c.NodeName = uniqueNodeName(t.Name()) - c.GRPCEnabled = true - }) - defer os.RemoveAll(dir3) - defer client.Shutdown() + connRemoteDC, err := gogrpc.DialContext(ctx, addrRemoteDC.String(), gogrpc.WithInsecure()) + require.NoError(t, err) + t.Cleanup(logError(t, connRemoteDC.Close)) + backendLocal.forwardConn = connRemoteDC - // Join the servers via WAN - joinWAN(t, server2, server1) - testrpc.WaitForLeader(t, server1.RPC, "dc1") - testrpc.WaitForLeader(t, server2.RPC, "dc2") - - joinLAN(t, client, server1) - testrpc.WaitForTestAgent(t, client.RPC, "dc1") - - // Register a dummy node in dc2 with a service we don't care about, - // to make sure we don't see updates for it. + ids := newCounter() { req := &structs.RegisterRequest{ Node: "other", @@ -417,11 +395,8 @@ func TestStreaming_Subscribe_MultiDC(t *testing.T) { Port: 9000, }, } - var out struct{} - require.NoError(msgpackrpc.CallWithCodec(codec, "Catalog.Register", &req, &out)) + require.NoError(t, backendRemoteDC.store.EnsureRegistration(ids.Next("req1"), req)) } - - // Register a dummy node with our service on it, again in dc2. { req := &structs.RegisterRequest{ Node: "node1", @@ -434,11 +409,9 @@ func TestStreaming_Subscribe_MultiDC(t *testing.T) { Port: 8080, }, } - var out struct{} - require.NoError(msgpackrpc.CallWithCodec(codec, "Catalog.Register", &req, &out)) + require.NoError(t, backendRemoteDC.store.EnsureRegistration(ids.Next("reg2"), req)) } - // Register a test node in dc2 to be updated later. req := &structs.RegisterRequest{ Node: "node2", Address: "1.2.3.4", @@ -450,63 +423,56 @@ func TestStreaming_Subscribe_MultiDC(t *testing.T) { Port: 8080, }, } - var out struct{} - require.NoError(msgpackrpc.CallWithCodec(codec, "Catalog.Register", &req, &out)) + require.NoError(t, backendRemoteDC.store.EnsureRegistration(ids.Next("reg3"), req)) - // Start a cross-DC Subscribe call to our streaming endpoint, specifying dc2. 
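// Minimal sketch of the path this test exercises, with assumed names: because
// backendLocal.forwardConn points at the dc2 server, Backend.Forward reports
// handled == true and runs the callback built by (*Server).forwardToDC, which
// opens a Subscribe stream on the remote connection and copies every event it
// receives back onto the original serverStream.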
- conn, err := client.GRPCConn() - require.NoError(err) - - streamClient := pbsubscribe.NewConsulClient(conn) - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - defer cancel() + connLocal, err := gogrpc.DialContext(ctx, addrLocal.String(), gogrpc.WithInsecure()) + require.NoError(t, err) + t.Cleanup(logError(t, connLocal.Close)) + streamClient := pbsubscribe.NewStateChangeSubscriptionClient(connLocal) streamHandle, err := streamClient.Subscribe(ctx, &pbsubscribe.SubscribeRequest{ Topic: pbsubscribe.Topic_ServiceHealth, Key: "redis", Datacenter: "dc2", }) - require.NoError(err) + require.NoError(t, err) - // Start a goroutine to read updates off the pbsubscribe. - eventCh := make(chan *pbsubscribe.Event, 0) - go recvEvents(t, eventCh, streamHandle) + chEvents := make(chan eventOrError, 0) + go recvEvents(chEvents, streamHandle) var snapshotEvents []*pbsubscribe.Event for i := 0; i < 3; i++ { - select { - case event := <-eventCh: - snapshotEvents = append(snapshotEvents, event) - case <-time.After(3 * time.Second): - t.Fatalf("did not receive events past %d", len(snapshotEvents)) - } + snapshotEvents = append(snapshotEvents, getEvent(t, chEvents)) } expected := []*pbsubscribe.Event{ { Topic: pbsubscribe.Topic_ServiceHealth, Key: "redis", + Index: ids.Last(), Payload: &pbsubscribe.Event_ServiceHealth{ ServiceHealth: &pbsubscribe.ServiceHealthUpdate{ Op: pbsubscribe.CatalogOp_Register, - CheckServiceNode: &pbsubscribe.CheckServiceNode{ - Node: &pbsubscribe.Node{ + CheckServiceNode: &pbservice.CheckServiceNode{ + Node: &pbservice.Node{ Node: "node1", Datacenter: "dc2", Address: "3.4.5.6", + RaftIndex: raftIndex(ids, "reg2", "reg2"), }, - Service: &pbsubscribe.NodeService{ + Service: &pbservice.NodeService{ ID: "redis1", Service: "redis", Address: "3.4.5.6", Port: 8080, - Weights: &pbsubscribe.Weights{Passing: 1, Warning: 1}, + Weights: &pbservice.Weights{Passing: 1, Warning: 1}, // Sad empty state - Proxy: pbsubscribe.ConnectProxyConfig{ - MeshGateway: &pbsubscribe.MeshGatewayConfig{}, - Expose: &pbsubscribe.ExposeConfig{}, + Proxy: pbservice.ConnectProxyConfig{ + MeshGateway: pbservice.MeshGatewayConfig{}, + Expose: pbservice.ExposeConfig{}, }, - EnterpriseMeta: &pbsubscribe.EnterpriseMeta{}, + EnterpriseMeta: pbcommon.EnterpriseMeta{}, + RaftIndex: raftIndex(ids, "reg2", "reg2"), }, }, }, @@ -515,27 +481,30 @@ func TestStreaming_Subscribe_MultiDC(t *testing.T) { { Topic: pbsubscribe.Topic_ServiceHealth, Key: "redis", + Index: ids.Last(), Payload: &pbsubscribe.Event_ServiceHealth{ ServiceHealth: &pbsubscribe.ServiceHealthUpdate{ Op: pbsubscribe.CatalogOp_Register, - CheckServiceNode: &pbsubscribe.CheckServiceNode{ - Node: &pbsubscribe.Node{ + CheckServiceNode: &pbservice.CheckServiceNode{ + Node: &pbservice.Node{ Node: "node2", Datacenter: "dc2", Address: "1.2.3.4", + RaftIndex: raftIndex(ids, "reg3", "reg3"), }, - Service: &pbsubscribe.NodeService{ + Service: &pbservice.NodeService{ ID: "redis1", Service: "redis", Address: "1.1.1.1", Port: 8080, - Weights: &pbsubscribe.Weights{Passing: 1, Warning: 1}, + Weights: &pbservice.Weights{Passing: 1, Warning: 1}, // Sad empty state - Proxy: pbsubscribe.ConnectProxyConfig{ - MeshGateway: &pbsubscribe.MeshGatewayConfig{}, - Expose: &pbsubscribe.ExposeConfig{}, + Proxy: pbservice.ConnectProxyConfig{ + MeshGateway: pbservice.MeshGatewayConfig{}, + Expose: pbservice.ExposeConfig{}, }, - EnterpriseMeta: &pbsubscribe.EnterpriseMeta{}, + EnterpriseMeta: pbcommon.EnterpriseMeta{}, + RaftIndex: raftIndex(ids, "reg3", "reg3"), }, }, }, 
@@ -544,19 +513,10 @@ func TestStreaming_Subscribe_MultiDC(t *testing.T) { { Topic: pbsubscribe.Topic_ServiceHealth, Key: "redis", + Index: ids.Last(), Payload: &pbsubscribe.Event_EndOfSnapshot{EndOfSnapshot: true}, }, } - - require.Len(snapshotEvents, 3) - for i := 0; i < 2; i++ { - // Fix up the index - expected[i].Index = snapshotEvents[i].Index - node := expected[i].GetServiceHealth().CheckServiceNode - node.Node.RaftIndex = snapshotEvents[i].GetServiceHealth().CheckServiceNode.Node.RaftIndex - node.Service.RaftIndex = snapshotEvents[i].GetServiceHealth().CheckServiceNode.Service.RaftIndex - } - expected[2].Index = snapshotEvents[2].Index assertDeepEqual(t, expected, snapshotEvents) // Update the registration by adding a check. @@ -567,73 +527,57 @@ func TestStreaming_Subscribe_MultiDC(t *testing.T) { ServiceName: "redis", Name: "check 1", } - require.NoError(msgpackrpc.CallWithCodec(codec, "Catalog.Register", &req, &out)) + require.NoError(t, backendRemoteDC.store.EnsureRegistration(ids.Next("update"), req)) - // Make sure we get the event for the diff. - select { - case event := <-eventCh: - expected := &pbsubscribe.Event{ - Topic: pbsubscribe.Topic_ServiceHealth, - Key: "redis", - Payload: &pbsubscribe.Event_ServiceHealth{ - ServiceHealth: &pbsubscribe.ServiceHealthUpdate{ - Op: pbsubscribe.CatalogOp_Register, - CheckServiceNode: &pbsubscribe.CheckServiceNode{ - Node: &pbsubscribe.Node{ - Node: "node2", - Datacenter: "dc2", - Address: "1.2.3.4", - RaftIndex: pbsubscribe.RaftIndex{CreateIndex: 13, ModifyIndex: 13}, + event := getEvent(t, chEvents) + expectedEvent := &pbsubscribe.Event{ + Topic: pbsubscribe.Topic_ServiceHealth, + Key: "redis", + Index: ids.Last(), + Payload: &pbsubscribe.Event_ServiceHealth{ + ServiceHealth: &pbsubscribe.ServiceHealthUpdate{ + Op: pbsubscribe.CatalogOp_Register, + CheckServiceNode: &pbservice.CheckServiceNode{ + Node: &pbservice.Node{ + Node: "node2", + Datacenter: "dc2", + Address: "1.2.3.4", + RaftIndex: raftIndex(ids, "reg3", "reg3"), + }, + Service: &pbservice.NodeService{ + ID: "redis1", + Service: "redis", + Address: "1.1.1.1", + Port: 8080, + RaftIndex: raftIndex(ids, "reg3", "reg3"), + Weights: &pbservice.Weights{Passing: 1, Warning: 1}, + // Sad empty state + Proxy: pbservice.ConnectProxyConfig{ + MeshGateway: pbservice.MeshGatewayConfig{}, + Expose: pbservice.ExposeConfig{}, }, - Service: &pbsubscribe.NodeService{ - ID: "redis1", - Service: "redis", - Address: "1.1.1.1", - Port: 8080, - RaftIndex: pbsubscribe.RaftIndex{CreateIndex: 13, ModifyIndex: 13}, - Weights: &pbsubscribe.Weights{Passing: 1, Warning: 1}, - // Sad empty state - Proxy: pbsubscribe.ConnectProxyConfig{ - MeshGateway: &pbsubscribe.MeshGatewayConfig{}, - Expose: &pbsubscribe.ExposeConfig{}, - }, - EnterpriseMeta: &pbsubscribe.EnterpriseMeta{}, - }, - Checks: []*pbsubscribe.HealthCheck{ - { - CheckID: "check1", - Name: "check 1", - Node: "node2", - Status: "critical", - ServiceID: "redis1", - ServiceName: "redis", - RaftIndex: pbsubscribe.RaftIndex{CreateIndex: 14, ModifyIndex: 14}, - EnterpriseMeta: &pbsubscribe.EnterpriseMeta{}, - }, + EnterpriseMeta: pbcommon.EnterpriseMeta{}, + }, + Checks: []*pbservice.HealthCheck{ + { + CheckID: "check1", + Name: "check 1", + Node: "node2", + Status: "critical", + ServiceID: "redis1", + ServiceName: "redis", + RaftIndex: raftIndex(ids, "update", "update"), + EnterpriseMeta: pbcommon.EnterpriseMeta{}, }, }, }, }, - } - // Fix up the index - expected.Index = event.Index - node := expected.GetServiceHealth().CheckServiceNode - 
node.Node.RaftIndex = event.GetServiceHealth().CheckServiceNode.Node.RaftIndex - node.Service.RaftIndex = event.GetServiceHealth().CheckServiceNode.Service.RaftIndex - node.Checks[0].RaftIndex = event.GetServiceHealth().CheckServiceNode.Checks[0].RaftIndex - assertDeepEqual(t, expected, event) - case <-time.After(3 * time.Second): - t.Fatal("never got event") - } - - // Wait and make sure there aren't any more events coming. - select { - case event := <-eventCh: - t.Fatalf("got another event: %v", event) - case <-time.After(500 * time.Millisecond): + }, } + assertDeepEqual(t, expectedEvent, event) } +/* TODO func TestStreaming_Subscribe_SkipSnapshot(t *testing.T) { t.Parallel() From 39beed0af6df34e6445bea745c968d6194ec8e05 Mon Sep 17 00:00:00 2001 From: Daniel Nephin Date: Mon, 28 Sep 2020 17:11:51 -0400 Subject: [PATCH 06/11] subscribe: add integration test for filtering events by acl --- agent/subscribe/subscribe_test.go | 369 +++++++++++------------------- 1 file changed, 128 insertions(+), 241 deletions(-) diff --git a/agent/subscribe/subscribe_test.go b/agent/subscribe/subscribe_test.go index 4b660c4022..bacd94253f 100644 --- a/agent/subscribe/subscribe_test.go +++ b/agent/subscribe/subscribe_test.go @@ -19,6 +19,7 @@ import ( "github.com/hashicorp/consul/agent/consul/stream" "github.com/hashicorp/consul/agent/grpc" "github.com/hashicorp/consul/agent/structs" + "github.com/hashicorp/consul/api" "github.com/hashicorp/consul/proto/pbcommon" "github.com/hashicorp/consul/proto/pbservice" "github.com/hashicorp/consul/proto/pbsubscribe" @@ -29,7 +30,6 @@ func TestServer_Subscribe_IntegrationWithBackend(t *testing.T) { backend, err := newTestBackend() require.NoError(t, err) srv := &Server{Backend: backend, Logger: hclog.New(nil)} - addr := newTestServer(t, srv) ids := newCounter() @@ -273,12 +273,12 @@ func assertDeepEqual(t *testing.T, x, y interface{}) { type testBackend struct { store *state.Store - authorizer acl.Authorizer + authorizer func(token string) acl.Authorizer forwardConn *gogrpc.ClientConn } -func (b testBackend) ResolveToken(_ string) (acl.Authorizer, error) { - return b.authorizer, nil +func (b testBackend) ResolveToken(token string) (acl.Authorizer, error) { + return b.authorizer(token), nil } func (b testBackend) Forward(_ string, fn func(*gogrpc.ClientConn) error) (handled bool, err error) { @@ -301,7 +301,10 @@ func newTestBackend() (*testBackend, error) { if err != nil { return nil, err } - return &testBackend{store: store, authorizer: acl.AllowAll()}, nil + allowAll := func(_ string) acl.Authorizer { + return acl.AllowAll() + } + return &testBackend{store: store, authorizer: allowAll}, nil } var _ Backend = (*testBackend)(nil) @@ -395,7 +398,7 @@ func TestServer_Subscribe_IntegrationWithBackend_ForwardToDC(t *testing.T) { Port: 9000, }, } - require.NoError(t, backendRemoteDC.store.EnsureRegistration(ids.Next("req1"), req)) + require.NoError(t, backendRemoteDC.store.EnsureRegistration(ids.Next("reg1"), req)) } { req := &structs.RegisterRequest{ @@ -577,236 +580,128 @@ func TestServer_Subscribe_IntegrationWithBackend_ForwardToDC(t *testing.T) { assertDeepEqual(t, expectedEvent, event) } -/* TODO -func TestStreaming_Subscribe_SkipSnapshot(t *testing.T) { - t.Parallel() +// TODO: test case for converting stream.Events to pbsubscribe.Events, including framing events - require := require.New(t) - dir1, server := testServerWithConfig(t, func(c *Config) { - c.Datacenter = "dc1" - c.Bootstrap = true - c.GRPCEnabled = true - }) - defer os.RemoveAll(dir1) - defer 
server.Shutdown() - codec := rpcClient(t, server) - defer codec.Close() - - dir2, client := testClientWithConfig(t, func(c *Config) { - c.Datacenter = "dc1" - c.NodeName = uniqueNodeName(t.Name()) - c.GRPCEnabled = true - }) - defer os.RemoveAll(dir2) - defer client.Shutdown() - - // Try to join - testrpc.WaitForLeader(t, server.RPC, "dc1") - joinLAN(t, client, server) - testrpc.WaitForTestAgent(t, client.RPC, "dc1") - - // Register a dummy node with our service on it. - { - req := &structs.RegisterRequest{ - Node: "node1", - Address: "3.4.5.6", - Datacenter: "dc1", - Service: &structs.NodeService{ - ID: "redis1", - Service: "redis", - Address: "3.4.5.6", - Port: 8080, - }, - } - var out struct{} - require.NoError(msgpackrpc.CallWithCodec(codec, "Catalog.Register", &req, &out)) +func TestServer_Subscribe_IntegrationWithBackend_FilterEventsByACLToken(t *testing.T) { + if testing.Short() { + t.Skip("too slow for -short run") } - // Start a Subscribe call to our streaming endpoint. - conn, err := client.GRPCConn() - require.NoError(err) - - streamClient := pbsubscribe.NewConsulClient(conn) - - var index uint64 - { - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - defer cancel() - streamHandle, err := streamClient.Subscribe(ctx, &pbsubscribe.SubscribeRequest{Topic: pbsubscribe.Topic_ServiceHealth, Key: "redis"}) - require.NoError(err) - - // Start a goroutine to read updates off the pbsubscribe. - eventCh := make(chan *pbsubscribe.Event, 0) - go recvEvents(t, eventCh, streamHandle) - - var snapshotEvents []*pbsubscribe.Event - for i := 0; i < 2; i++ { - select { - case event := <-eventCh: - snapshotEvents = append(snapshotEvents, event) - case <-time.After(3 * time.Second): - t.Fatalf("did not receive events past %d", len(snapshotEvents)) - } - } - - // Save the index from the event - index = snapshotEvents[0].Index - } - - // Start another Subscribe call passing the index from the last event. - { - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - defer cancel() - streamHandle, err := streamClient.Subscribe(ctx, &pbsubscribe.SubscribeRequest{ - Topic: pbsubscribe.Topic_ServiceHealth, - Key: "redis", - Index: index, - }) - require.NoError(err) - - // Start a goroutine to read updates off the pbsubscribe. - eventCh := make(chan *pbsubscribe.Event, 0) - go recvEvents(t, eventCh, streamHandle) - - // We should get no snapshot and the first event should be "resume stream" - select { - case event := <-eventCh: - require.True(event.GetResumeStream()) - case <-time.After(500 * time.Millisecond): - t.Fatalf("never got event") - } - - // Wait and make sure there aren't any events coming. The server shouldn't send - // a snapshot and we haven't made any updates to the catalog that would trigger - // more events. 
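// Client-side sketch of the resume behaviour described above (lastIndex is a
// hypothetical value saved from a previous session): passing the last index
// seen lets the server skip the snapshot and deliver only newer events.
//
//	handle, err := client.Subscribe(ctx, &pbsubscribe.SubscribeRequest{
//		Topic: pbsubscribe.Topic_ServiceHealth,
//		Key:   "redis",
//		Index: lastIndex,
//	})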
- select { - case event := <-eventCh: - t.Fatalf("got another event: %v", event) - case <-time.After(500 * time.Millisecond): - } - } -} - -func TestStreaming_Subscribe_FilterACL(t *testing.T) { - t.Parallel() - - require := require.New(t) - dir, _, server, codec := testACLFilterServerWithConfigFn(t, func(c *Config) { - c.ACLDatacenter = "dc1" - c.ACLsEnabled = true - c.ACLMasterToken = "root" - c.ACLDefaultPolicy = "deny" - c.ACLEnforceVersion8 = true - c.GRPCEnabled = true - }) - defer os.RemoveAll(dir) - defer server.Shutdown() - defer codec.Close() - - dir2, client := testClientWithConfig(t, func(c *Config) { - c.Datacenter = "dc1" - c.NodeName = uniqueNodeName(t.Name()) - c.GRPCEnabled = true - }) - defer os.RemoveAll(dir2) - defer client.Shutdown() - - // Try to join - testrpc.WaitForLeader(t, server.RPC, "dc1") - joinLAN(t, client, server) - testrpc.WaitForTestAgent(t, client.RPC, "dc1", testrpc.WithToken("root")) + backend, err := newTestBackend() + require.NoError(t, err) + srv := &Server{Backend: backend, Logger: hclog.New(nil)} + addr := newTestServer(t, srv) // Create a policy for the test token. - policyReq := structs.ACLPolicySetRequest{ - Datacenter: "dc1", - Policy: structs.ACLPolicy{ - Description: "foobar", - Name: "baz", - Rules: fmt.Sprintf(` - service "foo" { - policy = "write" - } - node "%s" { - policy = "write" - } - `, server.config.NodeName), - }, - WriteRequest: structs.WriteRequest{Token: "root"}, - } - resp := structs.ACLPolicy{} - require.NoError(msgpackrpc.CallWithCodec(codec, "ACL.PolicySet", &policyReq, &resp)) + rules := ` +service "foo" { + policy = "write" +} +node "node1" { + policy = "write" +} +` + authorizer, err := acl.NewAuthorizerFromRules( + "1", 0, rules, acl.SyntaxCurrent, + &acl.Config{WildcardName: structs.WildcardSpecifier}, + nil) + require.NoError(t, err) + authorizer = acl.NewChainedAuthorizer([]acl.Authorizer{authorizer, acl.DenyAll()}) + require.Equal(t, acl.Deny, authorizer.NodeRead("denied", nil)) - // Create a new token that only has access to one node. - var token structs.ACLToken - arg := structs.ACLTokenSetRequest{ - Datacenter: "dc1", - ACLToken: structs.ACLToken{ - Policies: []structs.ACLTokenPolicyLink{ - structs.ACLTokenPolicyLink{ - ID: resp.ID, - }, + // TODO: is there any easy way to do this with the acl package? + token := "this-token-is-good" + backend.authorizer = func(tok string) acl.Authorizer { + if tok == token { + return authorizer + } + return acl.DenyAll() + } + + ids := newCounter() + { + req := &structs.RegisterRequest{ + Datacenter: "dc1", + Node: "node1", + Address: "127.0.0.1", + Service: &structs.NodeService{ + ID: "foo", + Service: "foo", }, - }, - WriteRequest: structs.WriteRequest{Token: "root"}, - } - require.NoError(msgpackrpc.CallWithCodec(codec, "ACL.TokenSet", &arg, &token)) - auth, err := server.ResolveToken(token.SecretID) - require.NoError(err) - require.Equal(auth.NodeRead("denied", nil), acl.Deny) + Check: &structs.HealthCheck{ + CheckID: "service:foo", + Name: "service:foo", + Node: "node1", + ServiceID: "foo", + Status: api.HealthPassing, + }, + } + require.NoError(t, backend.store.EnsureRegistration(ids.Next("reg1"), req)) - // Register another instance of service foo on a fake node the token doesn't have access to. 
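// Given the rules and the chained DenyAll authorizer above, enforceACL should
// resolve the registrations in this test as follows (illustrative summary):
//
//	NodeRead("node1") == Allow && ServiceRead("foo") == Allow -> "foo" on node1 is streamed
//	ServiceRead("bar") == Deny                                -> "bar" events are filtered out
//	NodeRead("denied") == Deny                                -> events on node "denied" are filtered out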
- regArg := structs.RegisterRequest{ - Datacenter: "dc1", - Node: "denied", - Address: "127.0.0.1", - Service: &structs.NodeService{ - ID: "foo", - Service: "foo", - }, - WriteRequest: structs.WriteRequest{Token: "root"}, - } - require.NoError(msgpackrpc.CallWithCodec(codec, "Catalog.Register", ®Arg, nil)) + // Register a service which should be denied + req = &structs.RegisterRequest{ + Datacenter: "dc1", + Node: "node1", + Address: "127.0.0.1", + Service: &structs.NodeService{ + ID: "bar", + Service: "bar", + }, + Check: &structs.HealthCheck{ + CheckID: "service:bar", + Name: "service:bar", + Node: "node1", + ServiceID: "bar", + }, + } + require.NoError(t, backend.store.EnsureRegistration(ids.Next("reg2"), req)) - // Set up the gRPC client. - conn, err := client.GRPCConn() - require.NoError(err) - streamClient := pbsubscribe.NewConsulClient(conn) + req = &structs.RegisterRequest{ + Datacenter: "dc1", + Node: "denied", + Address: "127.0.0.1", + Service: &structs.NodeService{ + ID: "foo", + Service: "foo", + }, + } + require.NoError(t, backend.store.EnsureRegistration(ids.Next("reg3"), req)) + } + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + t.Cleanup(cancel) + + conn, err := gogrpc.DialContext(ctx, addr.String(), gogrpc.WithInsecure()) + require.NoError(t, err) + t.Cleanup(logError(t, conn.Close)) + streamClient := pbsubscribe.NewStateChangeSubscriptionClient(conn) // Start a Subscribe call to our streaming endpoint for the service we have access to. { - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - defer cancel() - streamHandle, err := streamClient.Subscribe(ctx, &pbsubscribe.SubscribeRequest{ Topic: pbsubscribe.Topic_ServiceHealth, Key: "foo", - Token: token.SecretID, + Token: token, }) - require.NoError(err) + require.NoError(t, err) - // Start a goroutine to read updates off the pbsubscribe. - eventCh := make(chan *pbsubscribe.Event, 0) - go recvEvents(t, eventCh, streamHandle) + chEvents := make(chan eventOrError, 0) + go recvEvents(chEvents, streamHandle) - // Read events off the pbsubscribe. We should not see any events for the filtered node. var snapshotEvents []*pbsubscribe.Event for i := 0; i < 2; i++ { - select { - case event := <-eventCh: - snapshotEvents = append(snapshotEvents, event) - case <-time.After(5 * time.Second): - t.Fatalf("did not receive events past %d", len(snapshotEvents)) - } + snapshotEvents = append(snapshotEvents, getEvent(t, chEvents)) } - require.Len(snapshotEvents, 2) - require.Equal("foo", snapshotEvents[0].GetServiceHealth().CheckServiceNode.Service.Service) - require.Equal(server.config.NodeName, snapshotEvents[0].GetServiceHealth().CheckServiceNode.Node.Node) - require.True(snapshotEvents[1].GetEndOfSnapshot()) + + require.Len(t, snapshotEvents, 2) + require.Equal(t, "foo", snapshotEvents[0].GetServiceHealth().CheckServiceNode.Service.Service) + require.Equal(t, "node1", snapshotEvents[0].GetServiceHealth().CheckServiceNode.Node.Node) + require.True(t, snapshotEvents[1].GetEndOfSnapshot()) // Update the service with a new port to trigger a new event. 
- regArg := structs.RegisterRequest{ + req := &structs.RegisterRequest{ Datacenter: "dc1", - Node: server.config.NodeName, + Node: "node1", Address: "127.0.0.1", Service: &structs.NodeService{ ID: "foo", @@ -818,22 +713,19 @@ func TestStreaming_Subscribe_FilterACL(t *testing.T) { Name: "service:foo", ServiceID: "foo", Status: api.HealthPassing, + Node: "node1", }, WriteRequest: structs.WriteRequest{Token: "root"}, } - require.NoError(msgpackrpc.CallWithCodec(codec, "Catalog.Register", ®Arg, nil)) + require.NoError(t, backend.store.EnsureRegistration(ids.Next("reg4"), req)) - select { - case event := <-eventCh: - service := event.GetServiceHealth().CheckServiceNode.Service - require.Equal("foo", service.Service) - require.Equal(1234, service.Port) - case <-time.After(5 * time.Second): - t.Fatalf("did not receive events past %d", len(snapshotEvents)) - } + event := getEvent(t, chEvents) + service := event.GetServiceHealth().CheckServiceNode.Service + require.Equal(t, "foo", service.Service) + require.Equal(t, int32(1234), service.Port) // Now update the service on the denied node and make sure we don't see an event. - regArg = structs.RegisterRequest{ + req = &structs.RegisterRequest{ Datacenter: "dc1", Node: "denied", Address: "127.0.0.1", @@ -847,13 +739,14 @@ func TestStreaming_Subscribe_FilterACL(t *testing.T) { Name: "service:foo", ServiceID: "foo", Status: api.HealthPassing, + Node: "denied", }, WriteRequest: structs.WriteRequest{Token: "root"}, } - require.NoError(msgpackrpc.CallWithCodec(codec, "Catalog.Register", ®Arg, nil)) + require.NoError(t, backend.store.EnsureRegistration(ids.Next("reg5"), req)) select { - case event := <-eventCh: + case event := <-chEvents: t.Fatalf("should not have received event: %v", event) case <-time.After(500 * time.Millisecond): } @@ -861,30 +754,22 @@ func TestStreaming_Subscribe_FilterACL(t *testing.T) { // Start another subscribe call for bar, which the token shouldn't have access to. { - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - defer cancel() streamHandle, err := streamClient.Subscribe(ctx, &pbsubscribe.SubscribeRequest{ Topic: pbsubscribe.Topic_ServiceHealth, Key: "bar", - Token: token.SecretID, + Token: token, }) - require.NoError(err) + require.NoError(t, err) - // Start a goroutine to read updates off the pbsubscribe. - eventCh := make(chan *pbsubscribe.Event, 0) - go recvEvents(t, eventCh, streamHandle) + chEvents := make(chan eventOrError, 0) + go recvEvents(chEvents, streamHandle) - select { - case event := <-eventCh: - require.True(event.GetEndOfSnapshot()) - case <-time.After(3 * time.Second): - t.Fatal("did not receive event") - } + require.True(t, getEvent(t, chEvents).GetEndOfSnapshot()) // Update the service and make sure we don't get a new event. 
- regArg := structs.RegisterRequest{ + req := &structs.RegisterRequest{ Datacenter: "dc1", - Node: server.config.NodeName, + Node: "node1", Address: "127.0.0.1", Service: &structs.NodeService{ ID: "bar", @@ -895,19 +780,21 @@ func TestStreaming_Subscribe_FilterACL(t *testing.T) { CheckID: "service:bar", Name: "service:bar", ServiceID: "bar", + Node: "node1", }, WriteRequest: structs.WriteRequest{Token: "root"}, } - require.NoError(msgpackrpc.CallWithCodec(codec, "Catalog.Register", ®Arg, nil)) + require.NoError(t, backend.store.EnsureRegistration(ids.Next("reg6"), req)) select { - case event := <-eventCh: + case event := <-chEvents: t.Fatalf("should not have received event: %v", event) case <-time.After(500 * time.Millisecond): } } } +/* func TestStreaming_Subscribe_ACLUpdate(t *testing.T) { t.Parallel() From 9e4ebacb058aee6587aa0001ae8c686932cf1ef4 Mon Sep 17 00:00:00 2001 From: Daniel Nephin Date: Mon, 28 Sep 2020 18:17:57 -0400 Subject: [PATCH 07/11] subscribe: add integration test for acl token updates --- agent/consul/state/acl.go | 3 +- agent/subscribe/subscribe.go | 1 - agent/subscribe/subscribe_test.go | 611 ++++-------------------------- 3 files changed, 72 insertions(+), 543 deletions(-) diff --git a/agent/consul/state/acl.go b/agent/consul/state/acl.go index a6e516111b..1613e753aa 100644 --- a/agent/consul/state/acl.go +++ b/agent/consul/state/acl.go @@ -5,9 +5,10 @@ import ( "fmt" "time" + memdb "github.com/hashicorp/go-memdb" + "github.com/hashicorp/consul/agent/structs" pbacl "github.com/hashicorp/consul/proto/pbacl" - memdb "github.com/hashicorp/go-memdb" ) type TokenPoliciesIndex struct { diff --git a/agent/subscribe/subscribe.go b/agent/subscribe/subscribe.go index 7908a410b3..0b1d0a6a94 100644 --- a/agent/subscribe/subscribe.go +++ b/agent/subscribe/subscribe.go @@ -88,7 +88,6 @@ func (h *Server) Subscribe(req *pbsubscribe.SubscribeRequest, serverStream pbsub for { events, err := sub.Next(ctx) switch { - // TODO: test case case errors.Is(err, stream.ErrSubscriptionClosed): h.Logger.Trace("subscription reset by server", "stream_id", streamID) return status.Error(codes.Aborted, err.Error()) diff --git a/agent/subscribe/subscribe_test.go b/agent/subscribe/subscribe_test.go index bacd94253f..a005f6eee1 100644 --- a/agent/subscribe/subscribe_test.go +++ b/agent/subscribe/subscribe_test.go @@ -10,9 +10,12 @@ import ( "github.com/google/go-cmp/cmp" "github.com/hashicorp/go-hclog" + "github.com/hashicorp/go-uuid" "github.com/stretchr/testify/require" "golang.org/x/sync/errgroup" gogrpc "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" "github.com/hashicorp/consul/acl" "github.com/hashicorp/consul/agent/consul/state" @@ -688,15 +691,11 @@ node "node1" { chEvents := make(chan eventOrError, 0) go recvEvents(chEvents, streamHandle) - var snapshotEvents []*pbsubscribe.Event - for i := 0; i < 2; i++ { - snapshotEvents = append(snapshotEvents, getEvent(t, chEvents)) - } + event := getEvent(t, chEvents) + require.Equal(t, "foo", event.GetServiceHealth().CheckServiceNode.Service.Service) + require.Equal(t, "node1", event.GetServiceHealth().CheckServiceNode.Node.Node) - require.Len(t, snapshotEvents, 2) - require.Equal(t, "foo", snapshotEvents[0].GetServiceHealth().CheckServiceNode.Service.Service) - require.Equal(t, "node1", snapshotEvents[0].GetServiceHealth().CheckServiceNode.Node.Node) - require.True(t, snapshotEvents[1].GetEndOfSnapshot()) + require.True(t, getEvent(t, chEvents).GetEndOfSnapshot()) // Update the service with a new port to 
trigger a new event. req := &structs.RegisterRequest{ @@ -719,7 +718,7 @@ node "node1" { } require.NoError(t, backend.store.EnsureRegistration(ids.Next("reg4"), req)) - event := getEvent(t, chEvents) + event = getEvent(t, chEvents) service := event.GetServiceHealth().CheckServiceNode.Service require.Equal(t, "foo", service.Service) require.Equal(t, int32(1234), service.Port) @@ -794,549 +793,79 @@ node "node1" { } } -/* -func TestStreaming_Subscribe_ACLUpdate(t *testing.T) { - t.Parallel() +func TestServer_Subscribe_IntegrationWithBackend_ACLUpdate(t *testing.T) { + backend, err := newTestBackend() + require.NoError(t, err) + srv := &Server{Backend: backend, Logger: hclog.New(nil)} + addr := newTestServer(t, srv) - require := require.New(t) - dir, _, server, codec := testACLFilterServerWithConfigFn(t, func(c *Config) { - c.ACLDatacenter = "dc1" - c.ACLsEnabled = true - c.ACLMasterToken = "root" - c.ACLDefaultPolicy = "deny" - c.ACLEnforceVersion8 = true - c.GRPCEnabled = true + rules := ` +service "foo" { + policy = "write" +} +node "node1" { + policy = "write" +} +` + authorizer, err := acl.NewAuthorizerFromRules( + "1", 0, rules, acl.SyntaxCurrent, + &acl.Config{WildcardName: structs.WildcardSpecifier}, + nil) + require.NoError(t, err) + authorizer = acl.NewChainedAuthorizer([]acl.Authorizer{authorizer, acl.DenyAll()}) + require.Equal(t, acl.Deny, authorizer.NodeRead("denied", nil)) + + // TODO: is there any easy way to do this with the acl package? + token := "this-token-is-good" + backend.authorizer = func(tok string) acl.Authorizer { + if tok == token { + return authorizer + } + return acl.DenyAll() + } + + ids := newCounter() + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + t.Cleanup(cancel) + + conn, err := gogrpc.DialContext(ctx, addr.String(), gogrpc.WithInsecure()) + require.NoError(t, err) + t.Cleanup(logError(t, conn.Close)) + streamClient := pbsubscribe.NewStateChangeSubscriptionClient(conn) + + streamHandle, err := streamClient.Subscribe(ctx, &pbsubscribe.SubscribeRequest{ + Topic: pbsubscribe.Topic_ServiceHealth, + Key: "foo", + Token: token, }) - defer os.RemoveAll(dir) - defer server.Shutdown() - defer codec.Close() + require.NoError(t, err) - dir2, client := testClientWithConfig(t, func(c *Config) { - c.Datacenter = "dc1" - c.NodeName = uniqueNodeName(t.Name()) - c.GRPCEnabled = true - }) - defer os.RemoveAll(dir2) - defer client.Shutdown() + chEvents := make(chan eventOrError, 0) + go recvEvents(chEvents, streamHandle) - // Try to join - testrpc.WaitForLeader(t, server.RPC, "dc1") - joinLAN(t, client, server) - testrpc.WaitForTestAgent(t, client.RPC, "dc1", testrpc.WithToken("root")) + require.True(t, getEvent(t, chEvents).GetEndOfSnapshot()) - // Create a new token/policy that only has access to one node. 
- var token structs.ACLToken + tokenID, err := uuid.GenerateUUID() + require.NoError(t, err) - policy, err := upsertTestPolicyWithRules(codec, "root", "dc1", fmt.Sprintf(` - service "foo" { - policy = "write" - } - node "%s" { - policy = "write" - } - `, server.config.NodeName)) - require.NoError(err) - - arg := structs.ACLTokenSetRequest{ - Datacenter: "dc1", - ACLToken: structs.ACLToken{ - Description: "Service/node token", - Policies: []structs.ACLTokenPolicyLink{ - structs.ACLTokenPolicyLink{ - ID: policy.ID, - }, - }, - Local: false, - }, - WriteRequest: structs.WriteRequest{Token: "root"}, + aclToken := &structs.ACLToken{ + AccessorID: tokenID, + SecretID: token, + Rules: "", } + require.NoError(t, backend.store.ACLTokenSet(ids.Next("update"), aclToken, false)) - require.NoError(msgpackrpc.CallWithCodec(codec, "ACL.TokenSet", &arg, &token)) - auth, err := server.ResolveToken(token.SecretID) - require.NoError(err) - require.Equal(auth.NodeRead("denied", nil), acl.Deny) - - // Set up the gRPC client. - conn, err := client.GRPCConn() - require.NoError(err) - streamClient := pbsubscribe.NewConsulClient(conn) - - // Start a Subscribe call to our streaming endpoint for the service we have access to. - { - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - defer cancel() - - streamHandle, err := streamClient.Subscribe(ctx, &pbsubscribe.SubscribeRequest{ - Topic: pbsubscribe.Topic_ServiceHealth, - Key: "foo", - Token: token.SecretID, - }) - require.NoError(err) - - // Start a goroutine to read updates off the pbsubscribe. - eventCh := make(chan *pbsubscribe.Event, 0) - go recvEvents(t, eventCh, streamHandle) - - // Read events off the pbsubscribe. - var snapshotEvents []*pbsubscribe.Event - for i := 0; i < 2; i++ { - select { - case event := <-eventCh: - snapshotEvents = append(snapshotEvents, event) - case <-time.After(5 * time.Second): - t.Fatalf("did not receive events past %d", len(snapshotEvents)) - } - } - require.Len(snapshotEvents, 2) - require.Equal("foo", snapshotEvents[0].GetServiceHealth().CheckServiceNode.Service.Service) - require.Equal(server.config.NodeName, snapshotEvents[0].GetServiceHealth().CheckServiceNode.Node.Node) - require.True(snapshotEvents[1].GetEndOfSnapshot()) - - // Update a different token and make sure we don't see an event. - arg2 := structs.ACLTokenSetRequest{ - Datacenter: "dc1", - ACLToken: structs.ACLToken{ - Description: "Ignored token", - Policies: []structs.ACLTokenPolicyLink{ - structs.ACLTokenPolicyLink{ - ID: policy.ID, - }, - }, - Local: false, - }, - WriteRequest: structs.WriteRequest{Token: "root"}, - } - var ignoredToken structs.ACLToken - require.NoError(msgpackrpc.CallWithCodec(codec, "ACL.TokenSet", &arg2, &ignoredToken)) - - select { - case event := <-eventCh: - t.Fatalf("should not have received event: %v", event) - case <-time.After(500 * time.Millisecond): - } - - // Update our token to trigger a refresh event. - token.Policies = []structs.ACLTokenPolicyLink{} - arg := structs.ACLTokenSetRequest{ - Datacenter: "dc1", - ACLToken: token, - WriteRequest: structs.WriteRequest{Token: "root"}, - } - require.NoError(msgpackrpc.CallWithCodec(codec, "ACL.TokenSet", &arg, &token)) - - select { - case event := <-eventCh: - require.True(event.GetResetStream()) - // 500 ms was not enough in CI apparently... 
- case <-time.After(2 * time.Second): - t.Fatalf("did not receive reload event") - } + select { + case item := <-chEvents: + require.Error(t, item.err, "got event: %v", item.event) + s, _ := status.FromError(item.err) + require.Equal(t, codes.Aborted, s.Code()) + case <-time.After(2 * time.Second): + t.Fatalf("timeout waiting for aborted error") } } -func TestStreaming_TLSEnabled(t *testing.T) { - t.Parallel() - - require := require.New(t) - dir1, conf1 := testServerConfig(t) - conf1.VerifyIncoming = true - conf1.VerifyOutgoing = true - conf1.GRPCEnabled = true - configureTLS(conf1) - server, err := newServer(conf1) - if err != nil { - t.Fatalf("err: %v", err) - } - defer os.RemoveAll(dir1) - defer server.Shutdown() - - dir2, conf2 := testClientConfig(t) - conf2.VerifyOutgoing = true - conf2.GRPCEnabled = true - configureTLS(conf2) - client, err := NewClient(conf2) - if err != nil { - t.Fatalf("err: %v", err) - } - defer os.RemoveAll(dir2) - defer client.Shutdown() - - // Try to join - testrpc.WaitForLeader(t, server.RPC, "dc1") - joinLAN(t, client, server) - testrpc.WaitForTestAgent(t, client.RPC, "dc1") - - // Register a dummy node with our service on it. - { - req := &structs.RegisterRequest{ - Node: "node1", - Address: "3.4.5.6", - Datacenter: "dc1", - Service: &structs.NodeService{ - ID: "redis1", - Service: "redis", - Address: "3.4.5.6", - Port: 8080, - }, - } - var out struct{} - require.NoError(server.RPC("Catalog.Register", &req, &out)) - } - - // Start a Subscribe call to our streaming endpoint from the client. - { - conn, err := client.GRPCConn() - require.NoError(err) - - streamClient := pbsubscribe.NewConsulClient(conn) - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - defer cancel() - streamHandle, err := streamClient.Subscribe(ctx, &pbsubscribe.SubscribeRequest{Topic: pbsubscribe.Topic_ServiceHealth, Key: "redis"}) - require.NoError(err) - - // Start a goroutine to read updates off the pbsubscribe. - eventCh := make(chan *pbsubscribe.Event, 0) - go recvEvents(t, eventCh, streamHandle) - - var snapshotEvents []*pbsubscribe.Event - for i := 0; i < 2; i++ { - select { - case event := <-eventCh: - snapshotEvents = append(snapshotEvents, event) - case <-time.After(3 * time.Second): - t.Fatalf("did not receive events past %d", len(snapshotEvents)) - } - } - - // Make sure the snapshot events come back with no issues. - require.Len(snapshotEvents, 2) - } - - // Start a Subscribe call to our streaming endpoint from the server's loopback client. - { - conn, err := server.GRPCConn() - require.NoError(err) - - retryFailedConn(t, conn) - - streamClient := pbsubscribe.NewConsulClient(conn) - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - defer cancel() - streamHandle, err := streamClient.Subscribe(ctx, &pbsubscribe.SubscribeRequest{Topic: pbsubscribe.Topic_ServiceHealth, Key: "redis"}) - require.NoError(err) - - // Start a goroutine to read updates off the pbsubscribe. - eventCh := make(chan *pbsubscribe.Event, 0) - go recvEvents(t, eventCh, streamHandle) - - var snapshotEvents []*pbsubscribe.Event - for i := 0; i < 2; i++ { - select { - case event := <-eventCh: - snapshotEvents = append(snapshotEvents, event) - case <-time.After(3 * time.Second): - t.Fatalf("did not receive events past %d", len(snapshotEvents)) - } - } - - // Make sure the snapshot events come back with no issues. 
- require.Len(snapshotEvents, 2) - } -} - -func TestStreaming_TLSReload(t *testing.T) { - t.Parallel() - - // Set up a server with initially bad certificates. - require := require.New(t) - dir1, conf1 := testServerConfig(t) - conf1.VerifyIncoming = true - conf1.VerifyOutgoing = true - conf1.CAFile = "../../test/ca/root.cer" - conf1.CertFile = "../../test/key/ssl-cert-snakeoil.pem" - conf1.KeyFile = "../../test/key/ssl-cert-snakeoil.key" - conf1.GRPCEnabled = true - - server, err := newServer(conf1) - if err != nil { - t.Fatalf("err: %v", err) - } - defer os.RemoveAll(dir1) - defer server.Shutdown() - - // Set up a client with valid certs and verify_outgoing = true - dir2, conf2 := testClientConfig(t) - conf2.VerifyOutgoing = true - conf2.GRPCEnabled = true - configureTLS(conf2) - client, err := NewClient(conf2) - if err != nil { - t.Fatalf("err: %v", err) - } - defer os.RemoveAll(dir2) - defer client.Shutdown() - - testrpc.WaitForLeader(t, server.RPC, "dc1") - - // Subscribe calls should fail initially - joinLAN(t, client, server) - conn, err := client.GRPCConn() - require.NoError(err) - { - streamClient := pbsubscribe.NewConsulClient(conn) - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - defer cancel() - _, err = streamClient.Subscribe(ctx, &pbsubscribe.SubscribeRequest{Topic: pbsubscribe.Topic_ServiceHealth, Key: "redis"}) - require.Error(err, "tls: bad certificate") - } - - // Reload the server with valid certs - newConf := server.config.ToTLSUtilConfig() - newConf.CertFile = "../../test/key/ourdomain.cer" - newConf.KeyFile = "../../test/key/ourdomain.key" - server.tlsConfigurator.Update(newConf) - - // Try the subscribe call again - { - retryFailedConn(t, conn) - - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - streamClient := pbsubscribe.NewConsulClient(conn) - _, err = streamClient.Subscribe(ctx, &pbsubscribe.SubscribeRequest{Topic: pbsubscribe.Topic_ServiceHealth, Key: "redis"}) - require.NoError(err) - } -} - -// retryFailedConn forces the ClientConn to reset its backoff timer and retry the connection, -// to simulate the client eventually retrying after the initial failure. This is used both to simulate -// retrying after an expected failure as well as to avoid flakiness when running many tests in parallel. -func retryFailedConn(t *testing.T, conn *grpc.ClientConn) { - state := conn.GetState() - if state.String() != "TRANSIENT_FAILURE" { - return - } - - // If the connection has failed, retry and wait for a state change. - conn.ResetConnectBackoff() - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - require.True(t, conn.WaitForStateChange(ctx, state)) -} - -func TestStreaming_DeliversAllMessages(t *testing.T) { - // This is a fuzz/probabilistic test to try to provoke streaming into dropping - // messages. There is a bug in the initial implementation that should make - // this fail. While we can't be certain a pass means it's correct, it is - // useful for finding bugs in our concurrency design. - - // The issue is that when updates are coming in fast such that updates occur - // in between us making the snapshot and beginning the stream updates, we - // shouldn't miss anything. - - // To test this, we will run a background goroutine that will write updates as - // fast as possible while we then try to stream the results and ensure that we - // see every change. We'll make the updates monotonically increasing so we can - // easily tell if we missed one. 
- - require := require.New(t) - dir1, server := testServerWithConfig(t, func(c *Config) { - c.Datacenter = "dc1" - c.Bootstrap = true - c.GRPCEnabled = true - }) - defer os.RemoveAll(dir1) - defer server.Shutdown() - codec := rpcClient(t, server) - defer codec.Close() - - dir2, client := testClientWithConfig(t, func(c *Config) { - c.Datacenter = "dc1" - c.NodeName = uniqueNodeName(t.Name()) - c.GRPCEnabled = true - }) - defer os.RemoveAll(dir2) - defer client.Shutdown() - - // Try to join - testrpc.WaitForLeader(t, server.RPC, "dc1") - joinLAN(t, client, server) - testrpc.WaitForTestAgent(t, client.RPC, "dc1") - - // Register a whole bunch of service instances so that the initial snapshot on - // subscribe is big enough to take a bit of time to load giving more - // opportunity for missed updates if there is a bug. - for i := 0; i < 1000; i++ { - req := &structs.RegisterRequest{ - Node: fmt.Sprintf("node-redis-%03d", i), - Address: "3.4.5.6", - Datacenter: "dc1", - Service: &structs.NodeService{ - ID: fmt.Sprintf("redis-%03d", i), - Service: "redis", - Port: 11211, - }, - } - var out struct{} - require.NoError(server.RPC("Catalog.Register", &req, &out)) - } - - // Start background writer - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - - go func() { - // Update the registration with a monotonically increasing port as fast as - // we can. - req := &structs.RegisterRequest{ - Node: "node1", - Address: "3.4.5.6", - Datacenter: "dc1", - Service: &structs.NodeService{ - ID: "redis-canary", - Service: "redis", - Port: 0, - }, - } - for { - if ctx.Err() != nil { - return - } - var out struct{} - require.NoError(server.RPC("Catalog.Register", &req, &out)) - req.Service.Port++ - if req.Service.Port > 100 { - return - } - time.Sleep(1 * time.Millisecond) - } - }() - - // Now start a whole bunch of streamers in parallel to maximise chance of - // catching a race. - conn, err := client.GRPCConn() - require.NoError(err) - - streamClient := pbsubscribe.NewConsulClient(conn) - - n := 5 - var wg sync.WaitGroup - var updateCount uint64 - // Buffered error chan so that workers can exit and terminate wg without - // blocking on send. We collect errors this way since t isn't thread safe. - errCh := make(chan error, n) - for i := 0; i < n; i++ { - wg.Add(1) - go verifyMonotonicStreamUpdates(ctx, t, streamClient, &wg, i, &updateCount, errCh) - } - - // Wait until all subscribers have verified the first bunch of updates all got - // delivered. - wg.Wait() - - close(errCh) - - // Require that none of them errored. Since we closed the chan above this loop - // should terminate immediately if no errors were buffered. - for err := range errCh { - require.NoError(err) - } - - // Sanity check that at least some non-snapshot messages were delivered. We - // can't know exactly how many because it's timing dependent based on when - // each subscribers snapshot occurs. 
- require.True(atomic.LoadUint64(&updateCount) > 0, - "at least some of the subscribers should have received non-snapshot updates") -} - -type testLogger interface { - Logf(format string, args ...interface{}) -} - -func verifyMonotonicStreamUpdates(ctx context.Context, logger testLogger, client pbsubscribe.StateChangeSubscriptionClient, wg *sync.WaitGroup, i int, updateCount *uint64, errCh chan<- error) { - defer wg.Done() - streamHandle, err := client.Subscribe(ctx, &pbsubscribe.SubscribeRequest{Topic: pbsubscribe.Topic_ServiceHealth, Key: "redis"}) - if err != nil { - if strings.Contains(err.Error(), "context deadline exceeded") || - strings.Contains(err.Error(), "context canceled") { - logger.Logf("subscriber %05d: context cancelled before loop") - return - } - errCh <- err - return - } - - snapshotDone := false - expectPort := 0 - for { - event, err := streamHandle.Recv() - if err == io.EOF { - break - } - if err != nil { - if strings.Contains(err.Error(), "context deadline exceeded") || - strings.Contains(err.Error(), "context canceled") { - break - } - errCh <- err - return - } - - // Ignore snapshot message - if event.GetEndOfSnapshot() || event.GetResumeStream() { - snapshotDone = true - logger.Logf("subscriber %05d: snapshot done, expect next port to be %d", i, expectPort) - } else if snapshotDone { - // Verify we get all updates in order - svc, err := svcOrErr(event) - if err != nil { - errCh <- err - return - } - if expectPort != svc.Port { - errCh <- fmt.Errorf("subscriber %05d: missed %d update(s)!", i, svc.Port-expectPort) - return - } - atomic.AddUint64(updateCount, 1) - logger.Logf("subscriber %05d: got event with correct port=%d", i, expectPort) - expectPort++ - } else { - // This is a snapshot update. Check if it's an update for the canary - // instance that got applied before our snapshot was sent (likely) - svc, err := svcOrErr(event) - if err != nil { - errCh <- err - return - } - if svc.ID == "redis-canary" { - // Update the expected port we see in the next update to be one more - // than the port in the snapshot. - expectPort = svc.Port + 1 - logger.Logf("subscriber %05d: saw canary in snapshot with port %d", i, svc.Port) - } - } - if expectPort > 100 { - return - } - } -} - -func svcOrErr(event *pbsubscribe.Event) (*pbservice.NodeService, error) { - health := event.GetServiceHealth() - if health == nil { - return nil, fmt.Errorf("not a health event: %#v", event) - } - csn := health.CheckServiceNode - if csn == nil { - return nil, fmt.Errorf("nil CSN: %#v", event) - } - if csn.Service == nil { - return nil, fmt.Errorf("nil service: %#v", event) - } - return csn.Service, nil -} -*/ - func logError(t *testing.T, f func() error) func() { return func() { if err := f(); err != nil { From dbb8bd679f1209baec2711f0a17544e41b9bde94 Mon Sep 17 00:00:00 2001 From: Daniel Nephin Date: Mon, 28 Sep 2020 18:52:31 -0400 Subject: [PATCH 08/11] subscirbe: extract streamID and logging from Subscribe By extracting all of the tracing logic the core logic of the Subscribe endpoint is much easier to read. 
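The extracted streamID is generated lazily on the first String call, so the UUID is never created unless trace logging actually emits a line, which replaces the earlier explicit IsTrace check. A minimal standalone sketch of that lazy-Stringer pattern follows; lazyID and main are illustrative names, not part of the patch:

package main

import (
	"fmt"
	"sync"
	"time"

	"github.com/hashicorp/go-uuid"
)

// lazyID mirrors the streamID type introduced in this patch: the identifier is
// generated on the first String call, so a logger that never emits a trace
// line never pays for UUID generation.
type lazyID struct {
	once sync.Once
	id   string
}

func (l *lazyID) String() string {
	l.once.Do(func() {
		var err error
		l.id, err = uuid.GenerateUUID()
		if err != nil {
			// Fall back to a timestamp, as the patch does.
			l.id = time.Now().Format(time.RFC3339Nano)
		}
	})
	return l.id
}

func main() {
	id := &lazyID{}
	// Nothing has been generated yet; fmt calls String, which runs the Once.
	fmt.Println("stream_id:", id)
	// Every subsequent use returns the same value.
	fmt.Println("stream_id:", id)
}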
--- agent/consul/server.go | 26 ++++----- agent/subscribe/logger.go | 71 +++++++++++++++++++++++++ agent/subscribe/subscribe.go | 87 +++++++------------------------ agent/subscribe/subscribe_test.go | 10 ++-- 4 files changed, 109 insertions(+), 85 deletions(-) create mode 100644 agent/subscribe/logger.go diff --git a/agent/consul/server.go b/agent/consul/server.go index a8bbfde217..dbfa5b4614 100644 --- a/agent/consul/server.go +++ b/agent/consul/server.go @@ -18,6 +18,16 @@ import ( "time" metrics "github.com/armon/go-metrics" + connlimit "github.com/hashicorp/go-connlimit" + "github.com/hashicorp/go-hclog" + "github.com/hashicorp/go-memdb" + "github.com/hashicorp/memberlist" + "github.com/hashicorp/raft" + raftboltdb "github.com/hashicorp/raft-boltdb" + "github.com/hashicorp/serf/serf" + "golang.org/x/time/rate" + "google.golang.org/grpc" + "github.com/hashicorp/consul/acl" ca "github.com/hashicorp/consul/agent/connect/ca" "github.com/hashicorp/consul/agent/consul/authmethod" @@ -38,15 +48,6 @@ import ( "github.com/hashicorp/consul/proto/pbsubscribe" "github.com/hashicorp/consul/tlsutil" "github.com/hashicorp/consul/types" - connlimit "github.com/hashicorp/go-connlimit" - "github.com/hashicorp/go-hclog" - "github.com/hashicorp/go-memdb" - "github.com/hashicorp/memberlist" - "github.com/hashicorp/raft" - raftboltdb "github.com/hashicorp/raft-boltdb" - "github.com/hashicorp/serf/serf" - "golang.org/x/time/rate" - "google.golang.org/grpc" ) // These are the protocol versions that Consul can _understand_. These are @@ -615,10 +616,9 @@ func newGRPCHandlerFromConfig(deps Deps, config *Config, s *Server) connHandler } register := func(srv *grpc.Server) { - pbsubscribe.RegisterStateChangeSubscriptionServer(srv, &subscribe.Server{ - Backend: &subscribeBackend{srv: s, connPool: deps.GRPCConnPool}, - Logger: deps.Logger.Named("grpc-api.subscription"), - }) + pbsubscribe.RegisterStateChangeSubscriptionServer(srv, subscribe.NewServer( + &subscribeBackend{srv: s, connPool: deps.GRPCConnPool}, + deps.Logger.Named("grpc-api.subscription"))) } return agentgrpc.NewHandler(config.RPCAddr, register) } diff --git a/agent/subscribe/logger.go b/agent/subscribe/logger.go new file mode 100644 index 0000000000..b1a32a6cdb --- /dev/null +++ b/agent/subscribe/logger.go @@ -0,0 +1,71 @@ +package subscribe + +import ( + "sync" + "time" + + "github.com/hashicorp/go-uuid" + + "github.com/hashicorp/consul/agent/consul/stream" + "github.com/hashicorp/consul/proto/pbsubscribe" +) + +// streamID is used in logs as a unique identifier for a subscription. The value +// is created lazily on the first call to String() so that we do not call it +// if trace logging is disabled. +// If a random UUID can not be created, defaults to the current time formatted +// as RFC3339Nano. +// +// TODO(banks) it might be nice one day to replace this with OpenTracing ID +// if one is set etc. but probably pointless until we support that properly +// in other places so it's actually propagated properly. For now this just +// makes lifetime of a stream more traceable in our regular server logs for +// debugging/dev. 
+type streamID struct { + once sync.Once + id string +} + +func (s *streamID) String() string { + s.once.Do(func() { + var err error + s.id, err = uuid.GenerateUUID() + if err != nil { + s.id = time.Now().Format(time.RFC3339Nano) + } + }) + return s.id +} + +func (h *Server) newLoggerForRequest(req *pbsubscribe.SubscribeRequest) Logger { + return h.Logger.With( + "topic", req.Topic.String(), + "dc", req.Datacenter, + "key", req.Key, + "index", req.Index, + "stream_id", &streamID{}) +} + +type eventLogger struct { + logger Logger + snapshotDone bool + count uint64 +} + +func (l *eventLogger) Trace(e []stream.Event) { + if len(e) == 0 { + return + } + + first := e[0] + switch { + case first.IsEndOfSnapshot() || first.IsEndOfEmptySnapshot(): + l.snapshotDone = true + l.logger.Trace("snapshot complete", "index", first.Index, "sent", l.count) + + case l.snapshotDone: + l.logger.Trace("sending events", "index", first.Index, "sent", l.count, "batch_size", len(e)) + } + + l.count += uint64(len(e)) +} diff --git a/agent/subscribe/subscribe.go b/agent/subscribe/subscribe.go index 0b1d0a6a94..191638f29e 100644 --- a/agent/subscribe/subscribe.go +++ b/agent/subscribe/subscribe.go @@ -4,7 +4,7 @@ import ( "errors" "fmt" - "github.com/hashicorp/go-uuid" + "github.com/hashicorp/go-hclog" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -23,9 +23,13 @@ type Server struct { Logger Logger } +func NewServer(backend Backend, logger Logger) *Server { + return &Server{Backend: backend, Logger: logger} +} + type Logger interface { - IsTrace() bool Trace(msg string, args ...interface{}) + With(args ...interface{}) hclog.Logger } var _ pbsubscribe.StateChangeSubscriptionServer = (*Server)(nil) @@ -37,41 +41,17 @@ type Backend interface { } func (h *Server) Subscribe(req *pbsubscribe.SubscribeRequest, serverStream pbsubscribe.StateChangeSubscription_SubscribeServer) error { - // streamID is just used for message correlation in trace logs and not - // populated normally. - var streamID string - - if h.Logger.IsTrace() { - // TODO(banks) it might be nice one day to replace this with OpenTracing ID - // if one is set etc. but probably pointless until we support that properly - // in other places so it's actually propagated properly. For now this just - // makes lifetime of a stream more traceable in our regular server logs for - // debugging/dev. - var err error - streamID, err = uuid.GenerateUUID() - if err != nil { - return err - } - } - - // TODO: add fields to logger and pass logger around instead of streamID - handled, err := h.Backend.Forward(req.Datacenter, h.forwardToDC(req, serverStream, streamID)) + logger := h.newLoggerForRequest(req) + handled, err := h.Backend.Forward(req.Datacenter, forwardToDC(req, serverStream, logger)) if handled || err != nil { return err } - h.Logger.Trace("new subscription", - "topic", req.Topic.String(), - "key", req.Key, - "index", req.Index, - "stream_id", streamID, - ) - - var sentCount uint64 - defer h.Logger.Trace("subscription closed", "stream_id", streamID) + logger.Trace("new subscription") + defer logger.Trace("subscription closed") // Resolve the token and create the ACL filter. - // TODO: handle token expiry gracefully... + // TODO(streaming): handle token expiry gracefully... 
authz, err := h.Backend.ResolveToken(req.Token) if err != nil { return err @@ -84,12 +64,13 @@ func (h *Server) Subscribe(req *pbsubscribe.SubscribeRequest, serverStream pbsub defer sub.Unsubscribe() ctx := serverStream.Context() - snapshotDone := false + + elog := &eventLogger{logger: logger} for { events, err := sub.Next(ctx) switch { case errors.Is(err, stream.ErrSubscriptionClosed): - h.Logger.Trace("subscription reset by server", "stream_id", streamID) + logger.Trace("subscription reset by server") return status.Error(codes.Aborted, err.Error()) case err != nil: return err @@ -100,23 +81,7 @@ func (h *Server) Subscribe(req *pbsubscribe.SubscribeRequest, serverStream pbsub continue } - first := events[0] - switch { - case first.IsEndOfSnapshot() || first.IsEndOfEmptySnapshot(): - snapshotDone = true - h.Logger.Trace("snapshot complete", - "index", first.Index, "sent", sentCount, "stream_id", streamID) - - case snapshotDone: - h.Logger.Trace("sending events", - "index", first.Index, - "sent", sentCount, - "batch_size", len(events), - "stream_id", streamID, - ) - } - - sentCount += uint64(len(events)) + elog.Trace(events) e := newEventFromStreamEvents(req, events) if err := serverStream.Send(e); err != nil { return err @@ -134,26 +99,14 @@ func toStreamSubscribeRequest(req *pbsubscribe.SubscribeRequest) *stream.Subscri } } -func (h *Server) forwardToDC( +func forwardToDC( req *pbsubscribe.SubscribeRequest, serverStream pbsubscribe.StateChangeSubscription_SubscribeServer, - streamID string, + logger Logger, ) func(conn *grpc.ClientConn) error { return func(conn *grpc.ClientConn) error { - h.Logger.Trace("forwarding to another DC", - "dc", req.Datacenter, - "topic", req.Topic.String(), - "key", req.Key, - "index", req.Index, - "stream_id", streamID, - ) - - defer func() { - h.Logger.Trace("forwarded stream closed", - "dc", req.Datacenter, - "stream_id", streamID, - ) - }() + logger.Trace("forwarding to another DC") + defer logger.Trace("forwarded stream closed") client := pbsubscribe.NewStateChangeSubscriptionClient(conn) streamHandle, err := client.Subscribe(serverStream.Context(), req) @@ -175,7 +128,7 @@ func (h *Server) forwardToDC( // filterStreamEvents to only those allowed by the acl token. func filterStreamEvents(authz acl.Authorizer, events []stream.Event) []stream.Event { - // TODO: when is authz nil? 
+ // authz will be nil when ACLs are disabled if authz == nil || len(events) == 0 { return events } diff --git a/agent/subscribe/subscribe_test.go b/agent/subscribe/subscribe_test.go index a005f6eee1..82f1ea6f2b 100644 --- a/agent/subscribe/subscribe_test.go +++ b/agent/subscribe/subscribe_test.go @@ -32,7 +32,7 @@ import ( func TestServer_Subscribe_IntegrationWithBackend(t *testing.T) { backend, err := newTestBackend() require.NoError(t, err) - srv := &Server{Backend: backend, Logger: hclog.New(nil)} + srv := NewServer(backend, hclog.New(nil)) addr := newTestServer(t, srv) ids := newCounter() @@ -373,11 +373,11 @@ func raftIndex(ids *counter, created, modified string) pbcommon.RaftIndex { func TestServer_Subscribe_IntegrationWithBackend_ForwardToDC(t *testing.T) { backendLocal, err := newTestBackend() require.NoError(t, err) - addrLocal := newTestServer(t, &Server{Backend: backendLocal, Logger: hclog.New(nil)}) + addrLocal := newTestServer(t, NewServer(backendLocal, hclog.New(nil))) backendRemoteDC, err := newTestBackend() require.NoError(t, err) - srvRemoteDC := &Server{Backend: backendRemoteDC, Logger: hclog.New(nil)} + srvRemoteDC := NewServer(backendRemoteDC, hclog.New(nil)) addrRemoteDC := newTestServer(t, srvRemoteDC) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) @@ -592,7 +592,7 @@ func TestServer_Subscribe_IntegrationWithBackend_FilterEventsByACLToken(t *testi backend, err := newTestBackend() require.NoError(t, err) - srv := &Server{Backend: backend, Logger: hclog.New(nil)} + srv := NewServer(backend, hclog.New(nil)) addr := newTestServer(t, srv) // Create a policy for the test token. @@ -796,7 +796,7 @@ node "node1" { func TestServer_Subscribe_IntegrationWithBackend_ACLUpdate(t *testing.T) { backend, err := newTestBackend() require.NoError(t, err) - srv := &Server{Backend: backend, Logger: hclog.New(nil)} + srv := NewServer(backend, hclog.New(nil)) addr := newTestServer(t, srv) rules := ` From e3290f59713742743b5775139dc021d02d488e97 Mon Sep 17 00:00:00 2001 From: Daniel Nephin Date: Fri, 2 Oct 2020 11:58:18 -0400 Subject: [PATCH 09/11] Move agent/subscribe -> agent/rpc/subscribe --- agent/consul/server.go | 2 +- agent/consul/subscribe_backend.go | 5 +++-- agent/{ => rpc}/subscribe/auth.go | 0 agent/{ => rpc}/subscribe/logger.go | 0 agent/{ => rpc}/subscribe/subscribe.go | 0 agent/{ => rpc}/subscribe/subscribe_test.go | 0 6 files changed, 4 insertions(+), 3 deletions(-) rename agent/{ => rpc}/subscribe/auth.go (100%) rename agent/{ => rpc}/subscribe/logger.go (100%) rename agent/{ => rpc}/subscribe/subscribe.go (100%) rename agent/{ => rpc}/subscribe/subscribe_test.go (100%) diff --git a/agent/consul/server.go b/agent/consul/server.go index dbfa5b4614..ec3cf4ff22 100644 --- a/agent/consul/server.go +++ b/agent/consul/server.go @@ -40,8 +40,8 @@ import ( "github.com/hashicorp/consul/agent/metadata" "github.com/hashicorp/consul/agent/pool" "github.com/hashicorp/consul/agent/router" + "github.com/hashicorp/consul/agent/rpc/subscribe" "github.com/hashicorp/consul/agent/structs" - "github.com/hashicorp/consul/agent/subscribe" "github.com/hashicorp/consul/agent/token" "github.com/hashicorp/consul/lib" "github.com/hashicorp/consul/logging" diff --git a/agent/consul/subscribe_backend.go b/agent/consul/subscribe_backend.go index e7c1ca90a8..56f2bac01a 100644 --- a/agent/consul/subscribe_backend.go +++ b/agent/consul/subscribe_backend.go @@ -1,11 +1,12 @@ package consul import ( + "google.golang.org/grpc" + "github.com/hashicorp/consul/acl" 
"github.com/hashicorp/consul/agent/consul/stream" agentgrpc "github.com/hashicorp/consul/agent/grpc" - "github.com/hashicorp/consul/agent/subscribe" - "google.golang.org/grpc" + "github.com/hashicorp/consul/agent/rpc/subscribe" ) type subscribeBackend struct { diff --git a/agent/subscribe/auth.go b/agent/rpc/subscribe/auth.go similarity index 100% rename from agent/subscribe/auth.go rename to agent/rpc/subscribe/auth.go diff --git a/agent/subscribe/logger.go b/agent/rpc/subscribe/logger.go similarity index 100% rename from agent/subscribe/logger.go rename to agent/rpc/subscribe/logger.go diff --git a/agent/subscribe/subscribe.go b/agent/rpc/subscribe/subscribe.go similarity index 100% rename from agent/subscribe/subscribe.go rename to agent/rpc/subscribe/subscribe.go diff --git a/agent/subscribe/subscribe_test.go b/agent/rpc/subscribe/subscribe_test.go similarity index 100% rename from agent/subscribe/subscribe_test.go rename to agent/rpc/subscribe/subscribe_test.go From f5d11562f2c871a25e30d9442e4d07bdf8db5826 Mon Sep 17 00:00:00 2001 From: Daniel Nephin Date: Fri, 2 Oct 2020 13:55:41 -0400 Subject: [PATCH 10/11] subscribe: update to use NewSnapshotToFollow event --- agent/rpc/subscribe/auth.go | 2 +- agent/rpc/subscribe/logger.go | 5 +++-- agent/rpc/subscribe/subscribe.go | 4 ++-- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/agent/rpc/subscribe/auth.go b/agent/rpc/subscribe/auth.go index 81ee3e934e..094ed4e3bf 100644 --- a/agent/rpc/subscribe/auth.go +++ b/agent/rpc/subscribe/auth.go @@ -10,7 +10,7 @@ import ( // event is allowed to be sent to this client or not. func enforceACL(authz acl.Authorizer, e stream.Event) acl.EnforcementDecision { switch { - case e.IsEndOfSnapshot(), e.IsEndOfEmptySnapshot(): + case e.IsEndOfSnapshot(), e.IsNewSnapshotToFollow(): return acl.Allow } diff --git a/agent/rpc/subscribe/logger.go b/agent/rpc/subscribe/logger.go index b1a32a6cdb..9aadf6a40e 100644 --- a/agent/rpc/subscribe/logger.go +++ b/agent/rpc/subscribe/logger.go @@ -59,10 +59,11 @@ func (l *eventLogger) Trace(e []stream.Event) { first := e[0] switch { - case first.IsEndOfSnapshot() || first.IsEndOfEmptySnapshot(): + case first.IsEndOfSnapshot(): l.snapshotDone = true l.logger.Trace("snapshot complete", "index", first.Index, "sent", l.count) - + case first.IsNewSnapshotToFollow(): + return case l.snapshotDone: l.logger.Trace("sending events", "index", first.Index, "sent", l.count, "batch_size", len(e)) } diff --git a/agent/rpc/subscribe/subscribe.go b/agent/rpc/subscribe/subscribe.go index 191638f29e..981c1714b0 100644 --- a/agent/rpc/subscribe/subscribe.go +++ b/agent/rpc/subscribe/subscribe.go @@ -169,8 +169,8 @@ func newEventFromStreamEvents(req *pbsubscribe.SubscribeRequest, events []stream case event.IsEndOfSnapshot(): e.Payload = &pbsubscribe.Event_EndOfSnapshot{EndOfSnapshot: true} return e - case event.IsEndOfEmptySnapshot(): - e.Payload = &pbsubscribe.Event_EndOfEmptySnapshot{EndOfEmptySnapshot: true} + case event.IsNewSnapshotToFollow(): + e.Payload = &pbsubscribe.Event_NewSnapshotToFollow{NewSnapshotToFollow: true} return e } From 21c21191f4cc699ad2a35245b1d8613520df8753 Mon Sep 17 00:00:00 2001 From: Daniel Nephin Date: Tue, 6 Oct 2020 19:08:36 -0400 Subject: [PATCH 11/11] structs: add CheckServiceNode.CanRead And use it from the subscribe endpoint. 
--- agent/rpc/subscribe/auth.go | 20 +------- agent/rpc/subscribe/subscribe.go | 4 +- agent/structs/structs.go | 20 ++++++++ agent/structs/structs_test.go | 83 ++++++++++++++++++++++++++++++-- 4 files changed, 102 insertions(+), 25 deletions(-) diff --git a/agent/rpc/subscribe/auth.go b/agent/rpc/subscribe/auth.go index 094ed4e3bf..b41b1fdc40 100644 --- a/agent/rpc/subscribe/auth.go +++ b/agent/rpc/subscribe/auth.go @@ -16,25 +16,7 @@ func enforceACL(authz acl.Authorizer, e stream.Event) acl.EnforcementDecision { switch p := e.Payload.(type) { case state.EventPayloadCheckServiceNode: - csn := p.Value - if csn.Node == nil || csn.Service == nil || csn.Node.Node == "" || csn.Service.Service == "" { - return acl.Deny - } - - // TODO: what about acl.Default? - // TODO(streaming): we need the AuthorizerContext for ent - if dec := authz.NodeRead(csn.Node.Node, nil); dec != acl.Allow { - return acl.Deny - } - - // TODO(streaming): we need the AuthorizerContext for ent - // Enterprise support for streaming events - they don't have enough data to - // populate it yet. - if dec := authz.ServiceRead(csn.Service.Service, nil); dec != acl.Allow { - return acl.Deny - } - return acl.Allow + return p.Value.CanRead(authz) } - return acl.Deny } diff --git a/agent/rpc/subscribe/subscribe.go b/agent/rpc/subscribe/subscribe.go index 981c1714b0..bcf87460e1 100644 --- a/agent/rpc/subscribe/subscribe.go +++ b/agent/rpc/subscribe/subscribe.go @@ -35,6 +35,8 @@ type Logger interface { var _ pbsubscribe.StateChangeSubscriptionServer = (*Server)(nil) type Backend interface { + // TODO(streaming): Use ResolveTokenAndDefaultMeta instead once SubscribeRequest + // has an EnterpriseMeta. ResolveToken(token string) (acl.Authorizer, error) Forward(dc string, f func(*grpc.ClientConn) error) (handled bool, err error) Subscribe(req *stream.SubscribeRequest) (*stream.Subscription, error) @@ -51,7 +53,6 @@ func (h *Server) Subscribe(req *pbsubscribe.SubscribeRequest, serverStream pbsub defer logger.Trace("subscription closed") // Resolve the token and create the ACL filter. - // TODO(streaming): handle token expiry gracefully... 
authz, err := h.Backend.ResolveToken(req.Token) if err != nil { return err @@ -64,7 +65,6 @@ func (h *Server) Subscribe(req *pbsubscribe.SubscribeRequest, serverStream pbsub defer sub.Unsubscribe() ctx := serverStream.Context() - elog := &eventLogger{logger: logger} for { events, err := sub.Next(ctx) diff --git a/agent/structs/structs.go b/agent/structs/structs.go index ca730e7449..eef34c3f60 100644 --- a/agent/structs/structs.go +++ b/agent/structs/structs.go @@ -19,6 +19,7 @@ import ( "github.com/hashicorp/serf/coordinate" "github.com/mitchellh/hashstructure" + "github.com/hashicorp/consul/acl" "github.com/hashicorp/consul/agent/cache" "github.com/hashicorp/consul/api" "github.com/hashicorp/consul/lib" @@ -1576,6 +1577,25 @@ func (csn *CheckServiceNode) BestAddress(wan bool) (string, int) { return addr, port } +func (csn *CheckServiceNode) CanRead(authz acl.Authorizer) acl.EnforcementDecision { + if csn.Node == nil || csn.Service == nil { + return acl.Deny + } + + // TODO(streaming): add enterprise test that uses namespaces + authzContext := new(acl.AuthorizerContext) + csn.Service.FillAuthzContext(authzContext) + + if authz.NodeRead(csn.Node.Node, authzContext) != acl.Allow { + return acl.Deny + } + + if authz.ServiceRead(csn.Service.Service, authzContext) != acl.Allow { + return acl.Deny + } + return acl.Allow +} + type CheckServiceNodes []CheckServiceNode // Shuffle does an in-place random shuffle using the Fisher-Yates algorithm. diff --git a/agent/structs/structs_test.go b/agent/structs/structs_test.go index a9c75bca00..0b4e9c497e 100644 --- a/agent/structs/structs_test.go +++ b/agent/structs/structs_test.go @@ -8,13 +8,15 @@ import ( "testing" "time" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/hashicorp/consul/acl" "github.com/hashicorp/consul/agent/cache" "github.com/hashicorp/consul/api" "github.com/hashicorp/consul/lib" "github.com/hashicorp/consul/sdk/testutil" "github.com/hashicorp/consul/types" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func TestEncodeDecode(t *testing.T) { @@ -1152,7 +1154,7 @@ func TestStructs_HealthCheck_Clone(t *testing.T) { } } -func TestStructs_CheckServiceNodes_Shuffle(t *testing.T) { +func TestCheckServiceNodes_Shuffle(t *testing.T) { // Make a huge list of nodes. 
var nodes CheckServiceNodes for i := 0; i < 100; i++ { @@ -1185,7 +1187,7 @@ func TestStructs_CheckServiceNodes_Shuffle(t *testing.T) { } } -func TestStructs_CheckServiceNodes_Filter(t *testing.T) { +func TestCheckServiceNodes_Filter(t *testing.T) { nodes := CheckServiceNodes{ CheckServiceNode{ Node: &Node{ @@ -1288,6 +1290,79 @@ func TestStructs_CheckServiceNodes_Filter(t *testing.T) { } } +func TestCheckServiceNodes_CanRead(t *testing.T) { + type testCase struct { + name string + csn CheckServiceNode + authz acl.Authorizer + expected acl.EnforcementDecision + } + + fn := func(t *testing.T, tc testCase) { + actual := tc.csn.CanRead(tc.authz) + require.Equal(t, tc.expected, actual) + } + + var testCases = []testCase{ + { + name: "empty", + expected: acl.Deny, + }, + { + name: "node read not authorized", + csn: CheckServiceNode{ + Node: &Node{Node: "name"}, + Service: &NodeService{Service: "service-name"}, + }, + authz: aclAuthorizerCheckServiceNode{allowService: true}, + expected: acl.Deny, + }, + { + name: "service read not authorized", + csn: CheckServiceNode{ + Node: &Node{Node: "name"}, + Service: &NodeService{Service: "service-name"}, + }, + authz: aclAuthorizerCheckServiceNode{allowNode: true}, + expected: acl.Deny, + }, + { + name: "read authorized", + csn: CheckServiceNode{ + Node: &Node{Node: "name"}, + Service: &NodeService{Service: "service-name"}, + }, + authz: acl.AllowAll(), + expected: acl.Allow, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + fn(t, tc) + }) + } +} + +type aclAuthorizerCheckServiceNode struct { + acl.Authorizer + allowNode bool + allowService bool +} + +func (a aclAuthorizerCheckServiceNode) ServiceRead(string, *acl.AuthorizerContext) acl.EnforcementDecision { + if a.allowService { + return acl.Allow + } + return acl.Deny +} + +func (a aclAuthorizerCheckServiceNode) NodeRead(string, *acl.AuthorizerContext) acl.EnforcementDecision { + if a.allowNode { + return acl.Allow + } + return acl.Deny +} + func TestStructs_DirEntry_Clone(t *testing.T) { e := &DirEntry{ LockIndex: 5,