From 106d781dc99cb5f7f616d1de5bf1f1832942ec48 Mon Sep 17 00:00:00 2001 From: Daniel Nephin Date: Tue, 8 Sep 2020 15:22:35 -0400 Subject: [PATCH] subscribe: add initial impl from streaming-rpc-final branch Co-authored-by: Paul Banks --- agent/subscribe/subscribe.go | 187 +++++++++++++++++++++++++++++++++++ 1 file changed, 187 insertions(+) create mode 100644 agent/subscribe/subscribe.go diff --git a/agent/subscribe/subscribe.go b/agent/subscribe/subscribe.go new file mode 100644 index 0000000000..d2d7083889 --- /dev/null +++ b/agent/subscribe/subscribe.go @@ -0,0 +1,187 @@ +package subscribe + +import ( + "github.com/hashicorp/consul/agent/consul/stream" + "github.com/hashicorp/consul/proto/pbevent" + "github.com/hashicorp/go-hclog" + "github.com/hashicorp/go-uuid" +) + +// Server implements a StateChangeSubscriptionServer for accepting SubscribeRequests, +// and sending events to the subscription topic. +type Server struct { + srv *Server + logger hclog.Logger +} + +var _ pbevent.StateChangeSubscriptionServer = (*Server)(nil) + +func (h *Server) Subscribe(req *pbevent.SubscribeRequest, serverStream pbevent.StateChangeSubscription_SubscribeServer) error { + // streamID is just used for message correlation in trace logs and not + // populated normally. + var streamID string + var err error + + if h.logger.IsTrace() { + // TODO(banks) it might be nice one day to replace this with OpenTracing ID + // if one is set etc. but probably pointless until we support that properly + // in other places so it's actually propagated properly. For now this just + // makes lifetime of a stream more traceable in our regular server logs for + // debugging/dev. + streamID, err = uuid.GenerateUUID() + if err != nil { + return err + } + } + + // Forward the request to a remote DC if applicable. + if req.Datacenter != "" && req.Datacenter != h.srv.config.Datacenter { + return h.forwardAndProxy(req, serverStream, streamID) + } + + h.srv.logger.Trace("new subscription", + "topic", req.Topic.String(), + "key", req.Key, + "index", req.Index, + "stream_id", streamID, + ) + + var sentCount uint64 + defer h.srv.logger.Trace("subscription closed", "stream_id", streamID) + + // Resolve the token and create the ACL filter. + // TODO: handle token expiry gracefully... + authz, err := h.srv.ResolveToken(req.Token) + if err != nil { + return err + } + aclFilter := newACLFilter(authz, h.srv.logger, h.srv.config.ACLEnforceVersion8) + + state := h.srv.fsm.State() + + // Register a subscription on this topic/key with the FSM. + sub, err := state.Subscribe(serverStream.Context(), req) + if err != nil { + return err + } + defer state.Unsubscribe(req) + + // Deliver the events + for { + events, err := sub.Next() + if err == stream.ErrSubscriptionReload { + event := pbevent.Event{ + Payload: &pbevent.Event_ResetStream{ResetStream: true}, + } + if err := serverStream.Send(&event); err != nil { + return err + } + h.srv.logger.Trace("subscription reloaded", + "stream_id", streamID, + ) + return nil + } + if err != nil { + return err + } + + aclFilter.filterStreamEvents(&events) + + snapshotDone := false + if len(events) == 1 { + if events[0].GetEndOfSnapshot() { + snapshotDone = true + h.srv.logger.Trace("snapshot complete", + "index", events[0].Index, + "sent", sentCount, + "stream_id", streamID, + ) + } else if events[0].GetResumeStream() { + snapshotDone = true + h.srv.logger.Trace("resuming stream", + "index", events[0].Index, + "sent", sentCount, + "stream_id", streamID, + ) + } else if snapshotDone { + // Count this event too in the normal case as "sent" the above cases + // only show the number of events sent _before_ the snapshot ended. + h.srv.logger.Trace("sending events", + "index", events[0].Index, + "sent", sentCount, + "batch_size", 1, + "stream_id", streamID, + ) + } + sentCount++ + if err := serverStream.Send(&events[0]); err != nil { + return err + } + } else if len(events) > 1 { + e := &pbevent.Event{ + Topic: req.Topic, + Key: req.Key, + Index: events[0].Index, + Payload: &pbevent.Event_EventBatch{ + EventBatch: &pbevent.EventBatch{ + Events: pbevent.EventBatchEventsFromEventSlice(events), + }, + }, + } + sentCount += uint64(len(events)) + h.srv.logger.Trace("sending events", + "index", events[0].Index, + "sent", sentCount, + "batch_size", len(events), + "stream_id", streamID, + ) + if err := serverStream.Send(e); err != nil { + return err + } + } + } +} + +func (h *Server) forwardAndProxy( + req *pbevent.SubscribeRequest, + serverStream pbevent.StateChangeSubscription_SubscribeServer, + streamID string) error { + + conn, err := h.srv.grpcClient.GRPCConn(req.Datacenter) + if err != nil { + return err + } + + h.logger.Trace("forwarding to another DC", + "dc", req.Datacenter, + "topic", req.Topic.String(), + "key", req.Key, + "index", req.Index, + "stream_id", streamID, + ) + + defer func() { + h.logger.Trace("forwarded stream closed", + "dc", req.Datacenter, + "stream_id", streamID, + ) + }() + + // Open a Subscribe call to the remote DC. + client := pbevent.NewConsulClient(conn) + streamHandle, err := client.Subscribe(serverStream.Context(), req) + if err != nil { + return err + } + + // Relay the events back to the client. + for { + event, err := streamHandle.Recv() + if err != nil { + return err + } + if err := serverStream.Send(event); err != nil { + return err + } + } +}