diff --git a/agent/consul/state/store_integration_test.go b/agent/consul/state/store_integration_test.go index 7f2ae62ce4..a16c635e1d 100644 --- a/agent/consul/state/store_integration_test.go +++ b/agent/consul/state/store_integration_test.go @@ -64,7 +64,7 @@ func TestStore_IntegrationWithEventPublisher_ACLTokenUpdate(t *testing.T) { // Ensure the reset event was sent. err = assertErr(t, eventCh) - require.Equal(stream.ErrSubscriptionClosed, err) + require.Equal(stream.ErrSubForceClosed, err) // Register another subscription. subscription2 := &stream.SubscribeRequest{ @@ -93,7 +93,7 @@ func TestStore_IntegrationWithEventPublisher_ACLTokenUpdate(t *testing.T) { // Ensure the reset event was sent. err = assertErr(t, eventCh2) - require.Equal(stream.ErrSubscriptionClosed, err) + require.Equal(stream.ErrSubForceClosed, err) } func TestStore_IntegrationWithEventPublisher_ACLPolicyUpdate(t *testing.T) { @@ -162,6 +162,7 @@ func TestStore_IntegrationWithEventPublisher_ACLPolicyUpdate(t *testing.T) { } sub, err = publisher.Subscribe(subscription2) require.NoError(err) + defer sub.Unsubscribe() eventCh = testRunSub(sub) @@ -180,7 +181,7 @@ func TestStore_IntegrationWithEventPublisher_ACLPolicyUpdate(t *testing.T) { // Ensure the reload event was sent. err = assertErr(t, eventCh) - require.Equal(stream.ErrSubscriptionClosed, err) + require.Equal(stream.ErrSubForceClosed, err) // Register another subscription. subscription3 := &stream.SubscribeRequest{ @@ -367,7 +368,7 @@ func assertReset(t *testing.T, eventCh <-chan nextResult, allowEOS bool) { } } require.Error(t, next.Err) - require.Equal(t, stream.ErrSubscriptionClosed, next.Err) + require.Equal(t, stream.ErrSubForceClosed, next.Err) return case <-time.After(100 * time.Millisecond): t.Fatalf("no err after 100ms") diff --git a/agent/consul/stream/event_buffer.go b/agent/consul/stream/event_buffer.go index 1208b07b14..eca2dbec10 100644 --- a/agent/consul/stream/event_buffer.go +++ b/agent/consul/stream/event_buffer.go @@ -170,13 +170,13 @@ func newBufferItem(events []Event) *bufferItem { // Next return the next buffer item in the buffer. It may block until ctx is // cancelled or until the next item is published. -func (i *bufferItem) Next(ctx context.Context, forceClose <-chan struct{}) (*bufferItem, error) { +func (i *bufferItem) Next(ctx context.Context, closed <-chan struct{}) (*bufferItem, error) { // See if there is already a next value, block if so. Note we don't rely on // state change (chan nil) as that's not threadsafe but detecting close is. select { case <-ctx.Done(): return nil, ctx.Err() - case <-forceClose: + case <-closed: return nil, fmt.Errorf("subscription closed") case <-i.link.ch: } diff --git a/agent/consul/stream/event_publisher_test.go b/agent/consul/stream/event_publisher_test.go index 940a908c97..0dec574960 100644 --- a/agent/consul/stream/event_publisher_test.go +++ b/agent/consul/stream/event_publisher_test.go @@ -30,6 +30,7 @@ func TestEventPublisher_SubscribeWithIndex0(t *testing.T) { sub, err := publisher.Subscribe(req) require.NoError(t, err) + defer sub.Unsubscribe() eventCh := runSubscription(ctx, sub) next := getNextEvent(t, eventCh) @@ -141,10 +142,10 @@ func TestEventPublisher_ShutdownClosesSubscriptions(t *testing.T) { cancel() // Shutdown err = consumeSub(context.Background(), sub1) - require.Equal(t, err, ErrSubscriptionClosed) + require.Equal(t, err, ErrSubForceClosed) _, err = sub2.Next(context.Background()) - require.Equal(t, err, ErrSubscriptionClosed) + require.Equal(t, err, ErrSubForceClosed) } func consumeSub(ctx context.Context, sub *Subscription) error { @@ -169,14 +170,15 @@ func TestEventPublisher_SubscribeWithIndex0_FromCache(t *testing.T) { publisher := NewEventPublisher(newTestSnapshotHandlers(), time.Second) go publisher.Run(ctx) - _, err := publisher.Subscribe(req) + sub, err := publisher.Subscribe(req) require.NoError(t, err) + sub.Unsubscribe() publisher.snapshotHandlers[testTopic] = func(_ SubscribeRequest, _ SnapshotAppender) (uint64, error) { return 0, fmt.Errorf("error should not be seen, cache should have been used") } - sub, err := publisher.Subscribe(req) + sub, err = publisher.Subscribe(req) require.NoError(t, err) eventCh := runSubscription(ctx, sub) @@ -357,6 +359,7 @@ func TestEventPublisher_SubscribeWithIndexNotZero_NewSnapshotFromCache(t *testin newReq.Index = 1 sub, err := publisher.Subscribe(&newReq) require.NoError(t, err) + defer sub.Unsubscribe() eventCh := runSubscription(ctx, sub) next := getNextEvent(t, eventCh) @@ -379,3 +382,25 @@ func runStep(t *testing.T, name string, fn func(t *testing.T)) { t.FailNow() } } + +func TestEventPublisher_Unsubscribe_ClosesSubscription(t *testing.T) { + req := &SubscribeRequest{ + Topic: testTopic, + Key: "sub-key", + } + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + publisher := NewEventPublisher(newTestSnapshotHandlers(), time.Second) + + sub, err := publisher.Subscribe(req) + require.NoError(t, err) + + _, err = sub.Next(ctx) + require.NoError(t, err) + + sub.Unsubscribe() + _, err = sub.Next(ctx) + require.Error(t, err) + require.Contains(t, err.Error(), "subscription was closed by unsubscribe") +} diff --git a/agent/consul/stream/subscription.go b/agent/consul/stream/subscription.go index a602cad551..012b410928 100644 --- a/agent/consul/stream/subscription.go +++ b/agent/consul/stream/subscription.go @@ -3,28 +3,33 @@ package stream import ( "context" "errors" + "fmt" "sync/atomic" ) const ( - // subscriptionStateOpen is the default state of a subscription. An open - // subscription may receive new events. - subscriptionStateOpen uint32 = 0 + // subStateOpen is the default state of a subscription. An open subscription + // may return new events. + subStateOpen = 0 - // subscriptionStateClosed indicates that the subscription was closed, possibly - // as a result of a change to an ACL token, and will not receive new events. + // subStateForceClosed indicates the subscription was forced closed by + // the EventPublisher, possibly as a result of a change to an ACL token, and + // will not return new events. // The subscriber must issue a new Subscribe request. - subscriptionStateClosed uint32 = 1 + subStateForceClosed = 1 + + // subStateUnsub indicates the subscription was closed by the caller, and + // will not return new events. + subStateUnsub = 2 ) -// ErrSubscriptionClosed is a error signalling the subscription has been +// ErrSubForceClosed is a error signalling the subscription has been // closed. The client should Unsubscribe, then re-Subscribe. -var ErrSubscriptionClosed = errors.New("subscription closed by server, client must reset state and resubscribe") +var ErrSubForceClosed = errors.New("subscription closed by server, client must reset state and resubscribe") // Subscription provides events on a Topic. Events may be filtered by Key. // Events are returned by Next(), and may start with a Snapshot of events. type Subscription struct { - // state is accessed atomically 0 means open, 1 means closed with reload state uint32 // req is the requests that we are responding to @@ -34,9 +39,9 @@ type Subscription struct { // is mutated by calls to Next. currentItem *bufferItem - // forceClosed is closed when forceClose is called. It is used by - // EventPublisher to cancel Next(). - forceClosed chan struct{} + // closed is a channel which is closed when the subscription is closed. It + // is used to exit the blocking select. + closed chan struct{} // unsub is a function set by EventPublisher that is called to free resources // when the subscription is no longer needed. @@ -58,7 +63,7 @@ type SubscribeRequest struct { // calling Unsubscribe when it is done with the subscription, to free resources. func newSubscription(req SubscribeRequest, item *bufferItem, unsub func()) *Subscription { return &Subscription{ - forceClosed: make(chan struct{}), + closed: make(chan struct{}), req: req, currentItem: item, unsub: unsub, @@ -68,16 +73,16 @@ func newSubscription(req SubscribeRequest, item *bufferItem, unsub func()) *Subs // Next returns the next Event to deliver. It must only be called from a // single goroutine concurrently as it mutates the Subscription. func (s *Subscription) Next(ctx context.Context) (Event, error) { - if atomic.LoadUint32(&s.state) == subscriptionStateClosed { - return Event{}, ErrSubscriptionClosed - } - for { - next, err := s.currentItem.Next(ctx, s.forceClosed) - switch { - case err != nil && atomic.LoadUint32(&s.state) == subscriptionStateClosed: - return Event{}, ErrSubscriptionClosed - case err != nil: + if err := s.requireStateOpen(); err != nil { + return Event{}, err + } + + next, err := s.currentItem.Next(ctx, s.closed) + if err := s.requireStateOpen(); err != nil { + return Event{}, err + } + if err != nil { return Event{}, err } s.currentItem = next @@ -92,6 +97,17 @@ func (s *Subscription) Next(ctx context.Context) (Event, error) { } } +func (s *Subscription) requireStateOpen() error { + switch atomic.LoadUint32(&s.state) { + case subStateForceClosed: + return ErrSubForceClosed + case subStateUnsub: + return fmt.Errorf("subscription was closed by unsubscribe") + default: + return nil + } +} + func newEventFromBatch(req SubscribeRequest, events []Event) Event { first := events[0] if len(events) == 1 { @@ -121,13 +137,15 @@ func filterByKey(req SubscribeRequest, events []Event) (Event, bool) { // and will need to perform a new Subscribe request. // It is safe to call from any goroutine. func (s *Subscription) forceClose() { - swapped := atomic.CompareAndSwapUint32(&s.state, subscriptionStateOpen, subscriptionStateClosed) - if swapped { - close(s.forceClosed) + if atomic.CompareAndSwapUint32(&s.state, subStateOpen, subStateForceClosed) { + close(s.closed) } } // Unsubscribe the subscription, freeing resources. func (s *Subscription) Unsubscribe() { + if atomic.CompareAndSwapUint32(&s.state, subStateOpen, subStateUnsub) { + close(s.closed) + } s.unsub() } diff --git a/agent/consul/stream/subscription_test.go b/agent/consul/stream/subscription_test.go index a2f6fb106d..db15313f57 100644 --- a/agent/consul/stream/subscription_test.go +++ b/agent/consul/stream/subscription_test.go @@ -122,7 +122,7 @@ func TestSubscription_Close(t *testing.T) { _, err = sub.Next(ctx) elapsed = time.Since(start) require.Error(t, err) - require.Equal(t, ErrSubscriptionClosed, err) + require.Equal(t, ErrSubForceClosed, err) require.True(t, elapsed > 200*time.Millisecond, "Reload should have happened after blocking 200ms, took %s", elapsed) require.True(t, elapsed < 2*time.Second, diff --git a/agent/rpc/subscribe/subscribe.go b/agent/rpc/subscribe/subscribe.go index 934819e2ec..b7eda488e4 100644 --- a/agent/rpc/subscribe/subscribe.go +++ b/agent/rpc/subscribe/subscribe.go @@ -69,7 +69,7 @@ func (h *Server) Subscribe(req *pbsubscribe.SubscribeRequest, serverStream pbsub for { event, err := sub.Next(ctx) switch { - case errors.Is(err, stream.ErrSubscriptionClosed): + case errors.Is(err, stream.ErrSubForceClosed): logger.Trace("subscription reset by server") return status.Error(codes.Aborted, err.Error()) case err != nil: