diff --git a/agent/grpc-external/services/peerstream/server.go b/agent/grpc-external/services/peerstream/server.go index a71c30d31a..96694d63ef 100644 --- a/agent/grpc-external/services/peerstream/server.go +++ b/agent/grpc-external/services/peerstream/server.go @@ -1,6 +1,8 @@ package peerstream import ( + "time" + "github.com/hashicorp/go-hclog" "github.com/hashicorp/go-memdb" "google.golang.org/grpc" @@ -17,6 +19,11 @@ import ( // TODO(peering): fix up these interfaces to be more testable now that they are // extracted from private peering +const ( + defaultOutgoingHeartbeatInterval = 15 * time.Second + defaultIncomingHeartbeatTimeout = 2 * time.Minute +) + type Server struct { Config } @@ -30,6 +37,12 @@ type Config struct { // Datacenter of the Consul server this gRPC server is hosted on Datacenter string ConnectEnabled bool + + // outgoingHeartbeatInterval is how often we send a heartbeat. + outgoingHeartbeatInterval time.Duration + + // incomingHeartbeatTimeout is how long we'll wait between receiving heartbeats before we close the connection. + incomingHeartbeatTimeout time.Duration } //go:generate mockery --name ACLResolver --inpackage @@ -46,6 +59,12 @@ func NewServer(cfg Config) *Server { if cfg.Datacenter == "" { panic("Datacenter is required") } + if cfg.outgoingHeartbeatInterval == 0 { + cfg.outgoingHeartbeatInterval = defaultOutgoingHeartbeatInterval + } + if cfg.incomingHeartbeatTimeout == 0 { + cfg.incomingHeartbeatTimeout = defaultIncomingHeartbeatTimeout + } return &Server{ Config: cfg, } diff --git a/agent/grpc-external/services/peerstream/stream_resources.go b/agent/grpc-external/services/peerstream/stream_resources.go index 5c69d08a72..3d10cdfa0c 100644 --- a/agent/grpc-external/services/peerstream/stream_resources.go +++ b/agent/grpc-external/services/peerstream/stream_resources.go @@ -6,6 +6,7 @@ import ( "io" "strings" "sync" + "time" "github.com/golang/protobuf/jsonpb" "github.com/golang/protobuf/proto" @@ -266,6 +267,40 @@ func (s *Server) HandleStream(streamReq HandleStreamRequest) error { } }() + // Heartbeat sender. + go func() { + tick := time.NewTicker(s.outgoingHeartbeatInterval) + defer tick.Stop() + + for { + select { + case <-streamReq.Stream.Context().Done(): + return + + case <-tick.C: + } + + heartbeat := &pbpeerstream.ReplicationMessage{ + Payload: &pbpeerstream.ReplicationMessage_Heartbeat_{ + Heartbeat: &pbpeerstream.ReplicationMessage_Heartbeat{}, + }, + } + if err := streamSend(heartbeat); err != nil { + logger.Warn("error sending heartbeat", "err", err) + } + } + }() + + // incomingHeartbeatCtx will complete if incoming heartbeats time out. + incomingHeartbeatCtx, incomingHeartbeatCtxCancel := + context.WithTimeout(context.Background(), s.incomingHeartbeatTimeout) + // NOTE: It's important that we wrap the call to cancel in a wrapper func because during the loop we're + // re-assigning the value of incomingHeartbeatCtxCancel and we want the defer to run on the last assigned + // value, not the current value. + defer func() { + incomingHeartbeatCtxCancel() + }() + for { select { // When the doneCh is closed that means that the peering was deleted locally. @@ -278,6 +313,9 @@ func (s *Server) HandleStream(streamReq HandleStreamRequest) error { }, } if err := streamSend(term); err != nil { + // Nolint directive needed due to bug in govet that doesn't see that the cancel + // func of the incomingHeartbeatTimer _does_ get called. + //nolint:govet return fmt.Errorf("failed to send to stream: %v", err) } @@ -286,6 +324,11 @@ func (s *Server) HandleStream(streamReq HandleStreamRequest) error { return nil + // We haven't received a heartbeat within the expected interval. Kill the stream. + case <-incomingHeartbeatCtx.Done(): + logger.Error("ending stream due to heartbeat timeout") + return fmt.Errorf("heartbeat timeout") + case msg, open := <-recvChan: if !open { // The only time we expect the stream to end is when we've received a "Terminated" message. @@ -431,6 +474,20 @@ func (s *Server) HandleStream(streamReq HandleStreamRequest) error { return nil } + if msg.GetHeartbeat() != nil { + // Reset the heartbeat timeout by creating a new context. + // We first must cancel the old context so there's no leaks. This is safe to do because we're only + // reading that context within this for{} loop, and so we won't accidentally trigger the heartbeat + // timeout. + incomingHeartbeatCtxCancel() + // NOTE: IDEs and govet think that the reassigned cancel below never gets + // called, but it does by the defer when the heartbeat ctx is first created. + // They just can't trace the execution properly for some reason (possibly golang/go#29587). + //nolint:govet + incomingHeartbeatCtx, incomingHeartbeatCtxCancel = + context.WithTimeout(context.Background(), s.incomingHeartbeatTimeout) + } + case update := <-subCh: var resp *pbpeerstream.ReplicationMessage_Response switch { diff --git a/agent/grpc-external/services/peerstream/stream_test.go b/agent/grpc-external/services/peerstream/stream_test.go index d5f9e2c36c..41c24dc1b4 100644 --- a/agent/grpc-external/services/peerstream/stream_test.go +++ b/agent/grpc-external/services/peerstream/stream_test.go @@ -272,7 +272,6 @@ func TestStreamResources_Server_FirstRequest(t *testing.T) { run(t, tc) }) } - } func TestStreamResources_Server_Terminate(t *testing.T) { @@ -869,6 +868,197 @@ func TestStreamResources_Server_CARootUpdates(t *testing.T) { }) } +// Test that when the client doesn't send a heartbeat in time, the stream is terminated. +func TestStreamResources_Server_TerminatesOnHeartbeatTimeout(t *testing.T) { + it := incrementalTime{ + base: time.Date(2000, time.January, 1, 0, 0, 0, 0, time.UTC), + } + + srv, store := newTestServer(t, func(c *Config) { + c.Tracker.SetClock(it.Now) + c.incomingHeartbeatTimeout = 5 * time.Millisecond + }) + + p := writePeeringToBeDialed(t, store, 1, "my-peer") + require.Empty(t, p.PeerID, "should be empty if being dialed") + peerID := p.ID + + // Set the initial roots and CA configuration. + _, _ = writeInitialRootsAndCA(t, store) + + client := makeClient(t, srv, peerID) + + // TODO(peering): test fails if we don't drain the stream with this call because the + // server gets blocked sending the termination message. Figure out a way to let + // messages queue and filter replication messages. + receiveRoots, err := client.Recv() + require.NoError(t, err) + require.NotNil(t, receiveRoots.GetResponse()) + require.Equal(t, pbpeerstream.TypeURLPeeringTrustBundle, receiveRoots.GetResponse().ResourceURL) + + testutil.RunStep(t, "new stream gets tracked", func(t *testing.T) { + retry.Run(t, func(r *retry.R) { + status, ok := srv.StreamStatus(peerID) + require.True(r, ok) + require.True(r, status.Connected) + }) + }) + + testutil.RunStep(t, "stream is disconnected due to heartbeat timeout", func(t *testing.T) { + retry.Run(t, func(r *retry.R) { + status, ok := srv.StreamStatus(peerID) + require.True(r, ok) + require.False(r, status.Connected) + }) + }) +} + +// Test that the server sends heartbeats at the expected interval. +func TestStreamResources_Server_SendsHeartbeats(t *testing.T) { + it := incrementalTime{ + base: time.Date(2000, time.January, 1, 0, 0, 0, 0, time.UTC), + } + outgoingHeartbeatInterval := 5 * time.Millisecond + + srv, store := newTestServer(t, func(c *Config) { + c.Tracker.SetClock(it.Now) + c.outgoingHeartbeatInterval = outgoingHeartbeatInterval + }) + + p := writePeeringToBeDialed(t, store, 1, "my-peer") + require.Empty(t, p.PeerID, "should be empty if being dialed") + peerID := p.ID + + // Set the initial roots and CA configuration. + _, _ = writeInitialRootsAndCA(t, store) + + client := makeClient(t, srv, peerID) + + // TODO(peering): test fails if we don't drain the stream with this call because the + // server gets blocked sending the termination message. Figure out a way to let + // messages queue and filter replication messages. + receiveRoots, err := client.Recv() + require.NoError(t, err) + require.NotNil(t, receiveRoots.GetResponse()) + require.Equal(t, pbpeerstream.TypeURLPeeringTrustBundle, receiveRoots.GetResponse().ResourceURL) + + testutil.RunStep(t, "new stream gets tracked", func(t *testing.T) { + retry.Run(t, func(r *retry.R) { + status, ok := srv.StreamStatus(peerID) + require.True(r, ok) + require.True(r, status.Connected) + }) + }) + + testutil.RunStep(t, "sends first heartbeat", func(t *testing.T) { + retry.RunWith(&retry.Timer{ + Timeout: outgoingHeartbeatInterval * 2, + Wait: outgoingHeartbeatInterval / 2, + }, t, func(r *retry.R) { + heartbeat, err := client.Recv() + require.NoError(t, err) + require.NotNil(t, heartbeat.GetHeartbeat()) + }) + }) + + testutil.RunStep(t, "sends second heartbeat", func(t *testing.T) { + retry.RunWith(&retry.Timer{ + Timeout: outgoingHeartbeatInterval * 2, + Wait: outgoingHeartbeatInterval / 2, + }, t, func(r *retry.R) { + heartbeat, err := client.Recv() + require.NoError(t, err) + require.NotNil(t, heartbeat.GetHeartbeat()) + }) + }) +} + +// Test that as long as the server receives heartbeats it keeps the connection open. +func TestStreamResources_Server_KeepsConnectionOpenWithHeartbeat(t *testing.T) { + it := incrementalTime{ + base: time.Date(2000, time.January, 1, 0, 0, 0, 0, time.UTC), + } + incomingHeartbeatTimeout := 10 * time.Millisecond + + srv, store := newTestServer(t, func(c *Config) { + c.Tracker.SetClock(it.Now) + c.incomingHeartbeatTimeout = incomingHeartbeatTimeout + }) + + p := writePeeringToBeDialed(t, store, 1, "my-peer") + require.Empty(t, p.PeerID, "should be empty if being dialed") + peerID := p.ID + + // Set the initial roots and CA configuration. + _, _ = writeInitialRootsAndCA(t, store) + + client := makeClient(t, srv, peerID) + + // TODO(peering): test fails if we don't drain the stream with this call because the + // server gets blocked sending the termination message. Figure out a way to let + // messages queue and filter replication messages. + receiveRoots, err := client.Recv() + require.NoError(t, err) + require.NotNil(t, receiveRoots.GetResponse()) + require.Equal(t, pbpeerstream.TypeURLPeeringTrustBundle, receiveRoots.GetResponse().ResourceURL) + + testutil.RunStep(t, "new stream gets tracked", func(t *testing.T) { + retry.Run(t, func(r *retry.R) { + status, ok := srv.StreamStatus(peerID) + require.True(r, ok) + require.True(r, status.Connected) + }) + }) + + heartbeatMsg := &pbpeerstream.ReplicationMessage{ + Payload: &pbpeerstream.ReplicationMessage_Heartbeat_{ + Heartbeat: &pbpeerstream.ReplicationMessage_Heartbeat{}}} + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + // errCh is used to collect any send errors from within the goroutine. + errCh := make(chan error) + + // Set up a goroutine to send the heartbeat every 1/2 of the timeout. + go func() { + // This is just a do while loop. We want to send the heartbeat right away to start + // because the test setup above takes some time and we might be close to the heartbeat + // timeout already. + for { + err := client.Send(heartbeatMsg) + if err != nil { + select { + case errCh <- err: + case <-ctx.Done(): + } + return + } + select { + case <-time.After(incomingHeartbeatTimeout / 2): + case <-ctx.Done(): + close(errCh) + return + } + } + }() + + // Assert that the stream remains connected for 5 heartbeat timeouts. + require.Never(t, func() bool { + status, ok := srv.StreamStatus(peerID) + if !ok { + return true + } + return !status.Connected + }, incomingHeartbeatTimeout*5, incomingHeartbeatTimeout) + + // Kill the heartbeat sending goroutine and check if it had any errors. + cancel() + err, ok := <-errCh + if ok { + require.NoError(t, err) + } +} + // makeClient sets up a *MockClient with the initial subscription // message handshake. func makeClient(t *testing.T, srv pbpeerstream.PeerStreamServiceServer, peerID string) *MockClient {