stream: close the subscription on Unsubscribe

pull/8975/head
Daniel Nephin 2020-10-15 18:06:04 -04:00
parent a3f8aa20dd
commit fb57d9b26a
6 changed files with 81 additions and 37 deletions

View File

@ -64,7 +64,7 @@ func TestStore_IntegrationWithEventPublisher_ACLTokenUpdate(t *testing.T) {
// Ensure the reset event was sent. // Ensure the reset event was sent.
err = assertErr(t, eventCh) err = assertErr(t, eventCh)
require.Equal(stream.ErrSubscriptionClosed, err) require.Equal(stream.ErrSubForceClosed, err)
// Register another subscription. // Register another subscription.
subscription2 := &stream.SubscribeRequest{ subscription2 := &stream.SubscribeRequest{
@ -93,7 +93,7 @@ func TestStore_IntegrationWithEventPublisher_ACLTokenUpdate(t *testing.T) {
// Ensure the reset event was sent. // Ensure the reset event was sent.
err = assertErr(t, eventCh2) err = assertErr(t, eventCh2)
require.Equal(stream.ErrSubscriptionClosed, err) require.Equal(stream.ErrSubForceClosed, err)
} }
func TestStore_IntegrationWithEventPublisher_ACLPolicyUpdate(t *testing.T) { func TestStore_IntegrationWithEventPublisher_ACLPolicyUpdate(t *testing.T) {
@ -162,6 +162,7 @@ func TestStore_IntegrationWithEventPublisher_ACLPolicyUpdate(t *testing.T) {
} }
sub, err = publisher.Subscribe(subscription2) sub, err = publisher.Subscribe(subscription2)
require.NoError(err) require.NoError(err)
defer sub.Unsubscribe()
eventCh = testRunSub(sub) eventCh = testRunSub(sub)
@ -180,7 +181,7 @@ func TestStore_IntegrationWithEventPublisher_ACLPolicyUpdate(t *testing.T) {
// Ensure the reload event was sent. // Ensure the reload event was sent.
err = assertErr(t, eventCh) err = assertErr(t, eventCh)
require.Equal(stream.ErrSubscriptionClosed, err) require.Equal(stream.ErrSubForceClosed, err)
// Register another subscription. // Register another subscription.
subscription3 := &stream.SubscribeRequest{ subscription3 := &stream.SubscribeRequest{
@ -367,7 +368,7 @@ func assertReset(t *testing.T, eventCh <-chan nextResult, allowEOS bool) {
} }
} }
require.Error(t, next.Err) require.Error(t, next.Err)
require.Equal(t, stream.ErrSubscriptionClosed, next.Err) require.Equal(t, stream.ErrSubForceClosed, next.Err)
return return
case <-time.After(100 * time.Millisecond): case <-time.After(100 * time.Millisecond):
t.Fatalf("no err after 100ms") t.Fatalf("no err after 100ms")

View File

@ -170,13 +170,13 @@ func newBufferItem(events []Event) *bufferItem {
// Next return the next buffer item in the buffer. It may block until ctx is // Next return the next buffer item in the buffer. It may block until ctx is
// cancelled or until the next item is published. // cancelled or until the next item is published.
func (i *bufferItem) Next(ctx context.Context, forceClose <-chan struct{}) (*bufferItem, error) { func (i *bufferItem) Next(ctx context.Context, closed <-chan struct{}) (*bufferItem, error) {
// See if there is already a next value, block if so. Note we don't rely on // See if there is already a next value, block if so. Note we don't rely on
// state change (chan nil) as that's not threadsafe but detecting close is. // state change (chan nil) as that's not threadsafe but detecting close is.
select { select {
case <-ctx.Done(): case <-ctx.Done():
return nil, ctx.Err() return nil, ctx.Err()
case <-forceClose: case <-closed:
return nil, fmt.Errorf("subscription closed") return nil, fmt.Errorf("subscription closed")
case <-i.link.ch: case <-i.link.ch:
} }

View File

@ -30,6 +30,7 @@ func TestEventPublisher_SubscribeWithIndex0(t *testing.T) {
sub, err := publisher.Subscribe(req) sub, err := publisher.Subscribe(req)
require.NoError(t, err) require.NoError(t, err)
defer sub.Unsubscribe()
eventCh := runSubscription(ctx, sub) eventCh := runSubscription(ctx, sub)
next := getNextEvent(t, eventCh) next := getNextEvent(t, eventCh)
@ -141,10 +142,10 @@ func TestEventPublisher_ShutdownClosesSubscriptions(t *testing.T) {
cancel() // Shutdown cancel() // Shutdown
err = consumeSub(context.Background(), sub1) err = consumeSub(context.Background(), sub1)
require.Equal(t, err, ErrSubscriptionClosed) require.Equal(t, err, ErrSubForceClosed)
_, err = sub2.Next(context.Background()) _, err = sub2.Next(context.Background())
require.Equal(t, err, ErrSubscriptionClosed) require.Equal(t, err, ErrSubForceClosed)
} }
func consumeSub(ctx context.Context, sub *Subscription) error { func consumeSub(ctx context.Context, sub *Subscription) error {
@ -169,14 +170,15 @@ func TestEventPublisher_SubscribeWithIndex0_FromCache(t *testing.T) {
publisher := NewEventPublisher(newTestSnapshotHandlers(), time.Second) publisher := NewEventPublisher(newTestSnapshotHandlers(), time.Second)
go publisher.Run(ctx) go publisher.Run(ctx)
_, err := publisher.Subscribe(req) sub, err := publisher.Subscribe(req)
require.NoError(t, err) require.NoError(t, err)
sub.Unsubscribe()
publisher.snapshotHandlers[testTopic] = func(_ SubscribeRequest, _ SnapshotAppender) (uint64, error) { publisher.snapshotHandlers[testTopic] = func(_ SubscribeRequest, _ SnapshotAppender) (uint64, error) {
return 0, fmt.Errorf("error should not be seen, cache should have been used") return 0, fmt.Errorf("error should not be seen, cache should have been used")
} }
sub, err := publisher.Subscribe(req) sub, err = publisher.Subscribe(req)
require.NoError(t, err) require.NoError(t, err)
eventCh := runSubscription(ctx, sub) eventCh := runSubscription(ctx, sub)
@ -357,6 +359,7 @@ func TestEventPublisher_SubscribeWithIndexNotZero_NewSnapshotFromCache(t *testin
newReq.Index = 1 newReq.Index = 1
sub, err := publisher.Subscribe(&newReq) sub, err := publisher.Subscribe(&newReq)
require.NoError(t, err) require.NoError(t, err)
defer sub.Unsubscribe()
eventCh := runSubscription(ctx, sub) eventCh := runSubscription(ctx, sub)
next := getNextEvent(t, eventCh) next := getNextEvent(t, eventCh)
@ -379,3 +382,25 @@ func runStep(t *testing.T, name string, fn func(t *testing.T)) {
t.FailNow() t.FailNow()
} }
} }
func TestEventPublisher_Unsubscribe_ClosesSubscription(t *testing.T) {
req := &SubscribeRequest{
Topic: testTopic,
Key: "sub-key",
}
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
publisher := NewEventPublisher(newTestSnapshotHandlers(), time.Second)
sub, err := publisher.Subscribe(req)
require.NoError(t, err)
_, err = sub.Next(ctx)
require.NoError(t, err)
sub.Unsubscribe()
_, err = sub.Next(ctx)
require.Error(t, err)
require.Contains(t, err.Error(), "subscription was closed by unsubscribe")
}

View File

@ -3,28 +3,33 @@ package stream
import ( import (
"context" "context"
"errors" "errors"
"fmt"
"sync/atomic" "sync/atomic"
) )
const ( const (
// subscriptionStateOpen is the default state of a subscription. An open // subStateOpen is the default state of a subscription. An open subscription
// subscription may receive new events. // may return new events.
subscriptionStateOpen uint32 = 0 subStateOpen = 0
// subscriptionStateClosed indicates that the subscription was closed, possibly // subStateForceClosed indicates the subscription was forced closed by
// as a result of a change to an ACL token, and will not receive new events. // the EventPublisher, possibly as a result of a change to an ACL token, and
// will not return new events.
// The subscriber must issue a new Subscribe request. // The subscriber must issue a new Subscribe request.
subscriptionStateClosed uint32 = 1 subStateForceClosed = 1
// subStateUnsub indicates the subscription was closed by the caller, and
// will not return new events.
subStateUnsub = 2
) )
// ErrSubscriptionClosed is a error signalling the subscription has been // ErrSubForceClosed is a error signalling the subscription has been
// closed. The client should Unsubscribe, then re-Subscribe. // closed. The client should Unsubscribe, then re-Subscribe.
var ErrSubscriptionClosed = errors.New("subscription closed by server, client must reset state and resubscribe") var ErrSubForceClosed = errors.New("subscription closed by server, client must reset state and resubscribe")
// Subscription provides events on a Topic. Events may be filtered by Key. // Subscription provides events on a Topic. Events may be filtered by Key.
// Events are returned by Next(), and may start with a Snapshot of events. // Events are returned by Next(), and may start with a Snapshot of events.
type Subscription struct { type Subscription struct {
// state is accessed atomically 0 means open, 1 means closed with reload
state uint32 state uint32
// req is the requests that we are responding to // req is the requests that we are responding to
@ -34,9 +39,9 @@ type Subscription struct {
// is mutated by calls to Next. // is mutated by calls to Next.
currentItem *bufferItem currentItem *bufferItem
// forceClosed is closed when forceClose is called. It is used by // closed is a channel which is closed when the subscription is closed. It
// EventPublisher to cancel Next(). // is used to exit the blocking select.
forceClosed chan struct{} closed chan struct{}
// unsub is a function set by EventPublisher that is called to free resources // unsub is a function set by EventPublisher that is called to free resources
// when the subscription is no longer needed. // when the subscription is no longer needed.
@ -58,7 +63,7 @@ type SubscribeRequest struct {
// calling Unsubscribe when it is done with the subscription, to free resources. // calling Unsubscribe when it is done with the subscription, to free resources.
func newSubscription(req SubscribeRequest, item *bufferItem, unsub func()) *Subscription { func newSubscription(req SubscribeRequest, item *bufferItem, unsub func()) *Subscription {
return &Subscription{ return &Subscription{
forceClosed: make(chan struct{}), closed: make(chan struct{}),
req: req, req: req,
currentItem: item, currentItem: item,
unsub: unsub, unsub: unsub,
@ -68,16 +73,16 @@ func newSubscription(req SubscribeRequest, item *bufferItem, unsub func()) *Subs
// Next returns the next Event to deliver. It must only be called from a // Next returns the next Event to deliver. It must only be called from a
// single goroutine concurrently as it mutates the Subscription. // single goroutine concurrently as it mutates the Subscription.
func (s *Subscription) Next(ctx context.Context) (Event, error) { func (s *Subscription) Next(ctx context.Context) (Event, error) {
if atomic.LoadUint32(&s.state) == subscriptionStateClosed {
return Event{}, ErrSubscriptionClosed
}
for { for {
next, err := s.currentItem.Next(ctx, s.forceClosed) if err := s.requireStateOpen(); err != nil {
switch { return Event{}, err
case err != nil && atomic.LoadUint32(&s.state) == subscriptionStateClosed: }
return Event{}, ErrSubscriptionClosed
case err != nil: next, err := s.currentItem.Next(ctx, s.closed)
if err := s.requireStateOpen(); err != nil {
return Event{}, err
}
if err != nil {
return Event{}, err return Event{}, err
} }
s.currentItem = next s.currentItem = next
@ -92,6 +97,17 @@ func (s *Subscription) Next(ctx context.Context) (Event, error) {
} }
} }
func (s *Subscription) requireStateOpen() error {
switch atomic.LoadUint32(&s.state) {
case subStateForceClosed:
return ErrSubForceClosed
case subStateUnsub:
return fmt.Errorf("subscription was closed by unsubscribe")
default:
return nil
}
}
func newEventFromBatch(req SubscribeRequest, events []Event) Event { func newEventFromBatch(req SubscribeRequest, events []Event) Event {
first := events[0] first := events[0]
if len(events) == 1 { if len(events) == 1 {
@ -121,13 +137,15 @@ func filterByKey(req SubscribeRequest, events []Event) (Event, bool) {
// and will need to perform a new Subscribe request. // and will need to perform a new Subscribe request.
// It is safe to call from any goroutine. // It is safe to call from any goroutine.
func (s *Subscription) forceClose() { func (s *Subscription) forceClose() {
swapped := atomic.CompareAndSwapUint32(&s.state, subscriptionStateOpen, subscriptionStateClosed) if atomic.CompareAndSwapUint32(&s.state, subStateOpen, subStateForceClosed) {
if swapped { close(s.closed)
close(s.forceClosed)
} }
} }
// Unsubscribe the subscription, freeing resources. // Unsubscribe the subscription, freeing resources.
func (s *Subscription) Unsubscribe() { func (s *Subscription) Unsubscribe() {
if atomic.CompareAndSwapUint32(&s.state, subStateOpen, subStateUnsub) {
close(s.closed)
}
s.unsub() s.unsub()
} }

View File

@ -122,7 +122,7 @@ func TestSubscription_Close(t *testing.T) {
_, err = sub.Next(ctx) _, err = sub.Next(ctx)
elapsed = time.Since(start) elapsed = time.Since(start)
require.Error(t, err) require.Error(t, err)
require.Equal(t, ErrSubscriptionClosed, err) require.Equal(t, ErrSubForceClosed, err)
require.True(t, elapsed > 200*time.Millisecond, require.True(t, elapsed > 200*time.Millisecond,
"Reload should have happened after blocking 200ms, took %s", elapsed) "Reload should have happened after blocking 200ms, took %s", elapsed)
require.True(t, elapsed < 2*time.Second, require.True(t, elapsed < 2*time.Second,

View File

@ -69,7 +69,7 @@ func (h *Server) Subscribe(req *pbsubscribe.SubscribeRequest, serverStream pbsub
for { for {
event, err := sub.Next(ctx) event, err := sub.Next(ctx)
switch { switch {
case errors.Is(err, stream.ErrSubscriptionClosed): case errors.Is(err, stream.ErrSubForceClosed):
logger.Trace("subscription reset by server") logger.Trace("subscription reset by server")
return status.Error(codes.Aborted, err.Error()) return status.Error(codes.Aborted, err.Error())
case err != nil: case err != nil: