Add send mutex to protect against concurrent sends (#13805)

pull/13806/head
Luke Kysow 2 years ago committed by GitHub
parent 281892ab7c
commit c411e6b326
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -5,6 +5,7 @@ import (
"fmt"
"io"
"strings"
"sync"
"github.com/golang/protobuf/jsonpb"
"github.com/golang/protobuf/proto"
@ -204,6 +205,25 @@ func (s *Server) HandleStream(streamReq HandleStreamRequest) error {
)
subCh := mgr.subscribe(streamReq.Stream.Context(), streamReq.LocalID, streamReq.PeerName, streamReq.Partition)
// We need a mutex to protect against simultaneous sends to the client.
var sendMutex sync.Mutex
// streamSend is a helper function that sends msg over the stream
// respecting the send mutex. It also logs the send and calls status.TrackSendError
// on error.
streamSend := func(msg *pbpeerstream.ReplicationMessage) error {
logTraceSend(logger, msg)
sendMutex.Lock()
err := streamReq.Stream.Send(msg)
sendMutex.Unlock()
if err != nil {
status.TrackSendError(err.Error())
}
return err
}
// Subscribe to all relevant resource types.
for _, resourceURL := range []string{
pbpeerstream.TypeURLExportedService,
@ -213,16 +233,12 @@ func (s *Server) HandleStream(streamReq HandleStreamRequest) error {
ResourceURL: resourceURL,
PeerID: streamReq.RemoteID,
})
logTraceSend(logger, sub)
if err := streamReq.Stream.Send(sub); err != nil {
if err := streamSend(sub); err != nil {
if err == io.EOF {
logger.Info("stream ended by peer")
status.TrackReceiveError(err.Error())
return nil
}
// TODO(peering) Test error handling in calls to Send/Recv
status.TrackSendError(err.Error())
return fmt.Errorf("failed to send subscription for %q to stream: %w", resourceURL, err)
}
}
@ -261,10 +277,7 @@ func (s *Server) HandleStream(streamReq HandleStreamRequest) error {
Terminated: &pbpeerstream.ReplicationMessage_Terminated{},
},
}
logTraceSend(logger, term)
if err := streamReq.Stream.Send(term); err != nil {
status.TrackSendError(err.Error())
if err := streamSend(term); err != nil {
return fmt.Errorf("failed to send to stream: %v", err)
}
@ -401,9 +414,7 @@ func (s *Server) HandleStream(streamReq HandleStreamRequest) error {
status.TrackReceiveSuccess()
}
logTraceSend(logger, reply)
if err := streamReq.Stream.Send(reply); err != nil {
status.TrackSendError(err.Error())
if err := streamSend(reply); err != nil {
return fmt.Errorf("failed to send to stream: %v", err)
}
@ -451,10 +462,7 @@ func (s *Server) HandleStream(streamReq HandleStreamRequest) error {
}
replResp := makeReplicationResponse(resp)
logTraceSend(logger, replResp)
if err := streamReq.Stream.Send(replResp); err != nil {
status.TrackSendError(err.Error())
if err := streamSend(replResp); err != nil {
return fmt.Errorf("failed to push data for %q: %w", update.CorrelationID, err)
}
}

Loading…
Cancel
Save