diff --git a/agent/subscribe/subscribe.go b/agent/subscribe/subscribe.go index 3fb2ae2672..7908a410b3 100644 --- a/agent/subscribe/subscribe.go +++ b/agent/subscribe/subscribe.go @@ -4,15 +4,15 @@ import ( "errors" "fmt" - "github.com/hashicorp/consul/agent/consul/state" - "github.com/hashicorp/consul/proto/pbservice" "github.com/hashicorp/go-uuid" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "github.com/hashicorp/consul/acl" + "github.com/hashicorp/consul/agent/consul/state" "github.com/hashicorp/consul/agent/consul/stream" + "github.com/hashicorp/consul/proto/pbservice" "github.com/hashicorp/consul/proto/pbsubscribe" ) @@ -107,6 +107,7 @@ func (h *Server) Subscribe(req *pbsubscribe.SubscribeRequest, serverStream pbsub snapshotDone = true h.Logger.Trace("snapshot complete", "index", first.Index, "sent", sentCount, "stream_id", streamID) + case snapshotDone: h.Logger.Trace("sending events", "index", first.Index, @@ -208,8 +209,20 @@ func newEventFromStreamEvents(req *pbsubscribe.SubscribeRequest, events []stream Key: req.Key, Index: events[0].Index, } + if len(events) == 1 { - setPayload(e, events[0].Payload) + event := events[0] + // TODO: refactor so these are only checked once, instead of 3 times. + switch { + case event.IsEndOfSnapshot(): + e.Payload = &pbsubscribe.Event_EndOfSnapshot{EndOfSnapshot: true} + return e + case event.IsEndOfEmptySnapshot(): + e.Payload = &pbsubscribe.Event_EndOfEmptySnapshot{EndOfEmptySnapshot: true} + return e + } + + setPayload(e, event.Payload) return e } diff --git a/agent/subscribe/subscribe_test.go b/agent/subscribe/subscribe_test.go index 3c490bd323..bac2ff2303 100644 --- a/agent/subscribe/subscribe_test.go +++ b/agent/subscribe/subscribe_test.go @@ -1,35 +1,42 @@ package subscribe -/* TODO -func TestStreaming_Subscribe(t *testing.T) { - t.Parallel() +import ( + "context" + "errors" + "io" + "net" + "testing" + "time" - require := require.New(t) - dir1, server := testServerWithConfig(t, func(c *Config) { - c.Datacenter = "dc1" - c.Bootstrap = true - c.GRPCEnabled = true - }) - defer os.RemoveAll(dir1) - defer server.Shutdown() - codec := rpcClient(t, server) - defer codec.Close() + "github.com/google/go-cmp/cmp" + "github.com/hashicorp/go-hclog" + "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" + gogrpc "google.golang.org/grpc" - dir2, client := testClientWithConfig(t, func(c *Config) { - c.Datacenter = "dc1" - c.NodeName = uniqueNodeName(t.Name()) - c.GRPCEnabled = true - }) - defer os.RemoveAll(dir2) - defer client.Shutdown() + "github.com/hashicorp/consul/acl" + "github.com/hashicorp/consul/agent/consul/state" + "github.com/hashicorp/consul/agent/consul/stream" + "github.com/hashicorp/consul/agent/grpc" + "github.com/hashicorp/consul/agent/structs" + "github.com/hashicorp/consul/proto/pbcommon" + "github.com/hashicorp/consul/proto/pbservice" + "github.com/hashicorp/consul/proto/pbsubscribe" + "github.com/hashicorp/consul/types" +) - // Try to join - testrpc.WaitForLeader(t, server.RPC, "dc1") - joinLAN(t, client, server) - testrpc.WaitForTestAgent(t, client.RPC, "dc1") +func TestServer_Subscribe_IntegrationWithBackend(t *testing.T) { + if testing.Short() { + t.Skip("too slow for -short run") + } + + backend, err := newTestBackend() + require.NoError(t, err) + srv := &Server{Backend: backend, Logger: hclog.New(nil)} + + addr := newTestServer(t, srv) + ids := newCounter() - // Register a dummy node with a service we don't care about, to make sure - // we don't see updates for it. { req := &structs.RegisterRequest{ Node: "other", @@ -42,11 +49,8 @@ func TestStreaming_Subscribe(t *testing.T) { Port: 9000, }, } - var out struct{} - require.NoError(msgpackrpc.CallWithCodec(codec, "Catalog.Register", &req, &out)) + require.NoError(t, backend.store.EnsureRegistration(ids.Next("other"), req)) } - - // Register a dummy node with our service on it. { req := &structs.RegisterRequest{ Node: "node1", @@ -59,11 +63,8 @@ func TestStreaming_Subscribe(t *testing.T) { Port: 8080, }, } - var out struct{} - require.NoError(msgpackrpc.CallWithCodec(codec, "Catalog.Register", &req, &out)) + require.NoError(t, backend.store.EnsureRegistration(ids.Next("reg2"), req)) } - - // Register a test node to be updated later. req := &structs.RegisterRequest{ Node: "node2", Address: "1.2.3.4", @@ -75,62 +76,58 @@ func TestStreaming_Subscribe(t *testing.T) { Port: 8080, }, } - var out struct{} - require.NoError(msgpackrpc.CallWithCodec(codec, "Catalog.Register", &req, &out)) + require.NoError(t, backend.store.EnsureRegistration(ids.Next("reg3"), req)) - // Start a Subscribe call to our streaming endpoint. - conn, err := client.GRPCConn() - require.NoError(err) - - streamClient := pbsubscribe.NewConsulClient(conn) ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() + conn, err := gogrpc.DialContext(ctx, addr.String(), gogrpc.WithInsecure()) + require.NoError(t, err) + t.Cleanup(logError(t, conn.Close)) + + streamClient := pbsubscribe.NewStateChangeSubscriptionClient(conn) streamHandle, err := streamClient.Subscribe(ctx, &pbsubscribe.SubscribeRequest{ Topic: pbsubscribe.Topic_ServiceHealth, Key: "redis", }) - require.NoError(err) + require.NoError(t, err) - // Start a goroutine to read updates off the pbsubscribe. - eventCh := make(chan *pbsubscribe.Event, 0) - go testSendEvents(t, eventCh, streamHandle) + chEvents := make(chan eventOrError, 0) + go recvEvents(chEvents, streamHandle) var snapshotEvents []*pbsubscribe.Event for i := 0; i < 3; i++ { - select { - case event := <-eventCh: - snapshotEvents = append(snapshotEvents, event) - case <-time.After(3 * time.Second): - t.Fatalf("did not receive events past %d", len(snapshotEvents)) - } + snapshotEvents = append(snapshotEvents, getEvent(t, chEvents)) } expected := []*pbsubscribe.Event{ { Topic: pbsubscribe.Topic_ServiceHealth, Key: "redis", + Index: ids.Last(), Payload: &pbsubscribe.Event_ServiceHealth{ ServiceHealth: &pbsubscribe.ServiceHealthUpdate{ Op: pbsubscribe.CatalogOp_Register, - CheckServiceNode: &pbsubscribe.CheckServiceNode{ - Node: &pbsubscribe.Node{ + CheckServiceNode: &pbservice.CheckServiceNode{ + Node: &pbservice.Node{ Node: "node1", Datacenter: "dc1", Address: "3.4.5.6", + RaftIndex: raftIndex(ids, "reg2", "reg2"), }, - Service: &pbsubscribe.NodeService{ + Service: &pbservice.NodeService{ ID: "redis1", Service: "redis", Address: "3.4.5.6", Port: 8080, - Weights: &pbsubscribe.Weights{Passing: 1, Warning: 1}, + Weights: &pbservice.Weights{Passing: 1, Warning: 1}, // Sad empty state - Proxy: pbsubscribe.ConnectProxyConfig{ - MeshGateway: &pbsubscribe.MeshGatewayConfig{}, - Expose: &pbsubscribe.ExposeConfig{}, + Proxy: pbservice.ConnectProxyConfig{ + MeshGateway: pbservice.MeshGatewayConfig{}, + Expose: pbservice.ExposeConfig{}, }, - EnterpriseMeta: &pbsubscribe.EnterpriseMeta{}, + RaftIndex: raftIndex(ids, "reg2", "reg2"), + EnterpriseMeta: pbcommon.EnterpriseMeta{}, }, }, }, @@ -139,27 +136,30 @@ func TestStreaming_Subscribe(t *testing.T) { { Topic: pbsubscribe.Topic_ServiceHealth, Key: "redis", + Index: ids.Last(), Payload: &pbsubscribe.Event_ServiceHealth{ ServiceHealth: &pbsubscribe.ServiceHealthUpdate{ Op: pbsubscribe.CatalogOp_Register, - CheckServiceNode: &pbsubscribe.CheckServiceNode{ - Node: &pbsubscribe.Node{ + CheckServiceNode: &pbservice.CheckServiceNode{ + Node: &pbservice.Node{ Node: "node2", Datacenter: "dc1", Address: "1.2.3.4", + RaftIndex: raftIndex(ids, "reg3", "reg3"), }, - Service: &pbsubscribe.NodeService{ + Service: &pbservice.NodeService{ ID: "redis1", Service: "redis", Address: "1.1.1.1", Port: 8080, - Weights: &pbsubscribe.Weights{Passing: 1, Warning: 1}, + Weights: &pbservice.Weights{Passing: 1, Warning: 1}, // Sad empty state - Proxy: pbsubscribe.ConnectProxyConfig{ - MeshGateway: &pbsubscribe.MeshGatewayConfig{}, - Expose: &pbsubscribe.ExposeConfig{}, + Proxy: pbservice.ConnectProxyConfig{ + MeshGateway: pbservice.MeshGatewayConfig{}, + Expose: pbservice.ExposeConfig{}, }, - EnterpriseMeta: &pbsubscribe.EnterpriseMeta{}, + RaftIndex: raftIndex(ids, "reg3", "reg3"), + EnterpriseMeta: pbcommon.EnterpriseMeta{}, }, }, }, @@ -168,22 +168,11 @@ func TestStreaming_Subscribe(t *testing.T) { { Topic: pbsubscribe.Topic_ServiceHealth, Key: "redis", + Index: ids.Last(), Payload: &pbsubscribe.Event_EndOfSnapshot{EndOfSnapshot: true}, }, } - - require.Len(snapshotEvents, 3) - for i := 0; i < 2; i++ { - // Fix up the index - expected[i].Index = snapshotEvents[i].Index - node := expected[i].GetServiceHealth().CheckServiceNode - node.Node.RaftIndex = snapshotEvents[i].GetServiceHealth().CheckServiceNode.Node.RaftIndex - node.Service.RaftIndex = snapshotEvents[i].GetServiceHealth().CheckServiceNode.Service.RaftIndex - } - // Fix index on snapshot event - expected[2].Index = snapshotEvents[2].Index - - requireEqualProtos(t, expected, snapshotEvents) + assertDeepEqual(t, expected, snapshotEvents) // Update the registration by adding a check. req.Check = &structs.HealthCheck{ @@ -193,73 +182,189 @@ func TestStreaming_Subscribe(t *testing.T) { ServiceName: "redis", Name: "check 1", } - require.NoError(msgpackrpc.CallWithCodec(codec, "Catalog.Register", &req, &out)) + require.NoError(t, backend.store.EnsureRegistration(ids.Next("update"), req)) - // Make sure we get the event for the diff. - select { - case event := <-eventCh: - expected := &pbsubscribe.Event{ - Topic: pbsubscribe.Topic_ServiceHealth, - Key: "redis", - Payload: &pbsubscribe.Event_ServiceHealth{ - ServiceHealth: &pbsubscribe.ServiceHealthUpdate{ - Op: pbsubscribe.CatalogOp_Register, - CheckServiceNode: &pbsubscribe.CheckServiceNode{ - Node: &pbsubscribe.Node{ - Node: "node2", - Datacenter: "dc1", - Address: "1.2.3.4", - RaftIndex: pbsubscribe.RaftIndex{CreateIndex: 13, ModifyIndex: 13}, + event := getEvent(t, chEvents) + expectedEvent := &pbsubscribe.Event{ + Topic: pbsubscribe.Topic_ServiceHealth, + Key: "redis", + Index: ids.Last(), + Payload: &pbsubscribe.Event_ServiceHealth{ + ServiceHealth: &pbsubscribe.ServiceHealthUpdate{ + Op: pbsubscribe.CatalogOp_Register, + CheckServiceNode: &pbservice.CheckServiceNode{ + Node: &pbservice.Node{ + Node: "node2", + Datacenter: "dc1", + Address: "1.2.3.4", + RaftIndex: raftIndex(ids, "reg3", "reg3"), + }, + Service: &pbservice.NodeService{ + ID: "redis1", + Service: "redis", + Address: "1.1.1.1", + Port: 8080, + Weights: &pbservice.Weights{Passing: 1, Warning: 1}, + // Sad empty state + Proxy: pbservice.ConnectProxyConfig{ + MeshGateway: pbservice.MeshGatewayConfig{}, + Expose: pbservice.ExposeConfig{}, }, - Service: &pbsubscribe.NodeService{ - ID: "redis1", - Service: "redis", - Address: "1.1.1.1", - Port: 8080, - RaftIndex: pbsubscribe.RaftIndex{CreateIndex: 13, ModifyIndex: 13}, - Weights: &pbsubscribe.Weights{Passing: 1, Warning: 1}, - // Sad empty state - Proxy: pbsubscribe.ConnectProxyConfig{ - MeshGateway: &pbsubscribe.MeshGatewayConfig{}, - Expose: &pbsubscribe.ExposeConfig{}, - }, - EnterpriseMeta: &pbsubscribe.EnterpriseMeta{}, - }, - Checks: []*pbsubscribe.HealthCheck{ - { - CheckID: "check1", - Name: "check 1", - Node: "node2", - Status: "critical", - ServiceID: "redis1", - ServiceName: "redis", - RaftIndex: pbsubscribe.RaftIndex{CreateIndex: 14, ModifyIndex: 14}, - EnterpriseMeta: &pbsubscribe.EnterpriseMeta{}, - }, + RaftIndex: raftIndex(ids, "reg3", "reg3"), + EnterpriseMeta: pbcommon.EnterpriseMeta{}, + }, + Checks: []*pbservice.HealthCheck{ + { + CheckID: "check1", + Name: "check 1", + Node: "node2", + Status: "critical", + ServiceID: "redis1", + ServiceName: "redis", + RaftIndex: raftIndex(ids, "update", "update"), + EnterpriseMeta: pbcommon.EnterpriseMeta{}, }, }, }, }, - } - // Fix up the index - expected.Index = event.Index - node := expected.GetServiceHealth().CheckServiceNode - node.Node.RaftIndex = event.GetServiceHealth().CheckServiceNode.Node.RaftIndex - node.Service.RaftIndex = event.GetServiceHealth().CheckServiceNode.Service.RaftIndex - node.Checks[0].RaftIndex = event.GetServiceHealth().CheckServiceNode.Checks[0].RaftIndex - requireEqualProtos(t, expected, event) - case <-time.After(3 * time.Second): - t.Fatal("never got event") + }, } + assertDeepEqual(t, expectedEvent, event) +} - // Wait and make sure there aren't any more events coming. - select { - case event := <-eventCh: - t.Fatalf("got another event: %v", event) - case <-time.After(500 * time.Millisecond): +type eventOrError struct { + event *pbsubscribe.Event + err error +} + +// recvEvents from handle and sends them to the provided channel. +func recvEvents(ch chan eventOrError, handle pbsubscribe.StateChangeSubscription_SubscribeClient) { + defer close(ch) + for { + event, err := handle.Recv() + switch { + case errors.Is(err, io.EOF): + return + case errors.Is(err, context.Canceled), errors.Is(err, context.DeadlineExceeded): + return + case err != nil: + ch <- eventOrError{err: err} + return + default: + ch <- eventOrError{event: event} + } } } +func getEvent(t *testing.T, ch chan eventOrError) *pbsubscribe.Event { + t.Helper() + select { + case item := <-ch: + require.NoError(t, item.err) + return item.event + case <-time.After(10 * time.Second): + t.Fatalf("timeout waiting on event from server") + } + return nil +} + +func assertDeepEqual(t *testing.T, x, y interface{}) { + t.Helper() + if diff := cmp.Diff(x, y); diff != "" { + t.Fatalf("assertion failed: values are not equal\n--- expected\n+++ actual\n%v", diff) + } +} + +type testBackend struct { + store *state.Store + authorizer acl.Authorizer +} + +func (b testBackend) ResolveToken(_ string) (acl.Authorizer, error) { + return b.authorizer, nil +} + +func (b testBackend) Forward(_ string, _ func(*gogrpc.ClientConn) error) (handled bool, err error) { + return false, nil +} + +func (b testBackend) Subscribe(req *stream.SubscribeRequest) (*stream.Subscription, error) { + return b.store.EventPublisher().Subscribe(req) +} + +func newTestBackend() (*testBackend, error) { + gc, err := state.NewTombstoneGC(time.Second, time.Millisecond) + if err != nil { + return nil, err + } + store, err := state.NewStateStore(gc) + if err != nil { + return nil, err + } + return &testBackend{store: store, authorizer: acl.AllowAll()}, nil +} + +var _ Backend = (*testBackend)(nil) + +func newTestServer(t *testing.T, server *Server) net.Addr { + addr := &net.IPAddr{IP: net.ParseIP("127.0.0.1")} + var grpcServer *gogrpc.Server + handler := grpc.NewHandler(addr, func(srv *gogrpc.Server) { + grpcServer = srv + pbsubscribe.RegisterStateChangeSubscriptionServer(srv, server) + }) + + lis, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + t.Cleanup(logError(t, lis.Close)) + + go grpcServer.Serve(lis) + g := new(errgroup.Group) + g.Go(func() error { + return grpcServer.Serve(lis) + }) + t.Cleanup(func() { + if err := handler.Shutdown(); err != nil { + t.Logf("grpc server shutdown: %v", err) + } + if err := g.Wait(); err != nil { + t.Logf("grpc server error: %v", err) + } + }) + return lis.Addr() +} + +type counter struct { + value uint64 + labels map[string]uint64 +} + +func (c *counter) Next(label string) uint64 { + c.value++ + c.labels[label] = c.value + return c.value +} + +func (c *counter) For(label string) uint64 { + return c.labels[label] +} + +func (c *counter) Last() uint64 { + return c.value +} + +func newCounter() *counter { + return &counter{labels: make(map[string]uint64)} +} + +func raftIndex(ids *counter, created, modified string) pbcommon.RaftIndex { + return pbcommon.RaftIndex{ + CreateIndex: ids.For(created), + ModifyIndex: ids.For(modified), + } +} + +/* TODO func TestStreaming_Subscribe_MultiDC(t *testing.T) { t.Parallel() @@ -365,7 +470,7 @@ func TestStreaming_Subscribe_MultiDC(t *testing.T) { // Start a goroutine to read updates off the pbsubscribe. eventCh := make(chan *pbsubscribe.Event, 0) - go testSendEvents(t, eventCh, streamHandle) + go recvEvents(t, eventCh, streamHandle) var snapshotEvents []*pbsubscribe.Event for i := 0; i < 3; i++ { @@ -452,7 +557,7 @@ func TestStreaming_Subscribe_MultiDC(t *testing.T) { node.Service.RaftIndex = snapshotEvents[i].GetServiceHealth().CheckServiceNode.Service.RaftIndex } expected[2].Index = snapshotEvents[2].Index - requireEqualProtos(t, expected, snapshotEvents) + assertDeepEqual(t, expected, snapshotEvents) // Update the registration by adding a check. req.Check = &structs.HealthCheck{ @@ -516,7 +621,7 @@ func TestStreaming_Subscribe_MultiDC(t *testing.T) { node.Node.RaftIndex = event.GetServiceHealth().CheckServiceNode.Node.RaftIndex node.Service.RaftIndex = event.GetServiceHealth().CheckServiceNode.Service.RaftIndex node.Checks[0].RaftIndex = event.GetServiceHealth().CheckServiceNode.Checks[0].RaftIndex - requireEqualProtos(t, expected, event) + assertDeepEqual(t, expected, event) case <-time.After(3 * time.Second): t.Fatal("never got event") } @@ -588,7 +693,7 @@ func TestStreaming_Subscribe_SkipSnapshot(t *testing.T) { // Start a goroutine to read updates off the pbsubscribe. eventCh := make(chan *pbsubscribe.Event, 0) - go testSendEvents(t, eventCh, streamHandle) + go recvEvents(t, eventCh, streamHandle) var snapshotEvents []*pbsubscribe.Event for i := 0; i < 2; i++ { @@ -617,7 +722,7 @@ func TestStreaming_Subscribe_SkipSnapshot(t *testing.T) { // Start a goroutine to read updates off the pbsubscribe. eventCh := make(chan *pbsubscribe.Event, 0) - go testSendEvents(t, eventCh, streamHandle) + go recvEvents(t, eventCh, streamHandle) // We should get no snapshot and the first event should be "resume stream" select { @@ -737,7 +842,7 @@ func TestStreaming_Subscribe_FilterACL(t *testing.T) { // Start a goroutine to read updates off the pbsubscribe. eventCh := make(chan *pbsubscribe.Event, 0) - go testSendEvents(t, eventCh, streamHandle) + go recvEvents(t, eventCh, streamHandle) // Read events off the pbsubscribe. We should not see any events for the filtered node. var snapshotEvents []*pbsubscribe.Event @@ -823,7 +928,7 @@ func TestStreaming_Subscribe_FilterACL(t *testing.T) { // Start a goroutine to read updates off the pbsubscribe. eventCh := make(chan *pbsubscribe.Event, 0) - go testSendEvents(t, eventCh, streamHandle) + go recvEvents(t, eventCh, streamHandle) select { case event := <-eventCh: @@ -939,7 +1044,7 @@ func TestStreaming_Subscribe_ACLUpdate(t *testing.T) { // Start a goroutine to read updates off the pbsubscribe. eventCh := make(chan *pbsubscribe.Event, 0) - go testSendEvents(t, eventCh, streamHandle) + go recvEvents(t, eventCh, streamHandle) // Read events off the pbsubscribe. var snapshotEvents []*pbsubscribe.Event @@ -998,25 +1103,6 @@ func TestStreaming_Subscribe_ACLUpdate(t *testing.T) { } } -// testSendEvents receives pbsubscribe.Events from a given handle and sends them to the provided -// channel. This is meant to be run in a separate goroutine from the main test. -func testSendEvents(t *testing.T, ch chan *pbsubscribe.Event, handle pbsubscribe.StateChangeSubscription_SubscribeClient) { - for { - event, err := handle.Recv() - if err == io.EOF { - break - } - if err != nil { - if strings.Contains(err.Error(), "context deadline exceeded") || - strings.Contains(err.Error(), "context canceled") { - break - } - t.Log(err) - } - ch <- event - } -} - func TestStreaming_TLSEnabled(t *testing.T) { t.Parallel() @@ -1079,7 +1165,7 @@ func TestStreaming_TLSEnabled(t *testing.T) { // Start a goroutine to read updates off the pbsubscribe. eventCh := make(chan *pbsubscribe.Event, 0) - go testSendEvents(t, eventCh, streamHandle) + go recvEvents(t, eventCh, streamHandle) var snapshotEvents []*pbsubscribe.Event for i := 0; i < 2; i++ { @@ -1110,7 +1196,7 @@ func TestStreaming_TLSEnabled(t *testing.T) { // Start a goroutine to read updates off the pbsubscribe. eventCh := make(chan *pbsubscribe.Event, 0) - go testSendEvents(t, eventCh, streamHandle) + go recvEvents(t, eventCh, streamHandle) var snapshotEvents []*pbsubscribe.Event for i := 0; i < 2; i++ { @@ -1418,17 +1504,12 @@ func svcOrErr(event *pbsubscribe.Event) (*pbservice.NodeService, error) { } return csn.Service, nil } - -// requireEqualProtos is a helper that runs arrays or structures containing -// proto buf messages through JSON encoding before comparing/diffing them. This -// is necessary because require.Equal doesn't compare them equal and generates -// really unhelpful output in this case for some reason. -func requireEqualProtos(t *testing.T, want, got interface{}) { - t.Helper() - gotJSON, err := json.Marshal(got) - require.NoError(t, err) - expectJSON, err := json.Marshal(want) - require.NoError(t, err) - require.JSONEq(t, string(expectJSON), string(gotJSON)) -} */ + +func logError(t *testing.T, f func() error) func() { + return func() { + if err := f(); err != nil { + t.Logf(err.Error()) + } + } +}