EventPublisher: Make Unsubscribe a function on Subscription

It is critical that Unsubscribe be called with the same pointer to a
SubscriptionRequest that was used to create the Subscription. The
docstring made that clear, but it sill allowed a caler to get it wrong by
creating a new SubscriptionRequest.

By hiding this detail from the caller, and only exposing an Unsubscribe
method, it should be impossible to fail to Unsubscribe.

Also update some godoc strings.
pull/8160/head
Daniel Nephin 2020-06-18 18:29:06 -04:00
parent 1622bb3a45
commit 2c8342f115
3 changed files with 39 additions and 31 deletions

View File

@ -51,12 +51,16 @@ type EventPublisher struct {
} }
type subscriptions struct { type subscriptions struct {
// lock for byToken. If both subscription.lock and EventPublisher.lock need
// to be held, EventPublisher.lock MUST always be acquired first.
lock sync.RWMutex lock sync.RWMutex
// subsByToken stores a list of Subscription objects outstanding indexed by a // byToken is an mapping of active Subscriptions indexed by a the token and
// hash of the ACL token they used to subscribe so we can reload them if their // a pointer to the request.
// ACL permissions change. // When the token is modified all subscriptions under that token will be
subsByToken map[string]map[*stream.SubscribeRequest]*stream.Subscription // reloaded.
// A subscription may be unsubscribed by using the pointer to the request.
byToken map[string]map[*stream.SubscribeRequest]*stream.Subscription
} }
type commitUpdate struct { type commitUpdate struct {
@ -70,7 +74,7 @@ func NewEventPublisher(handlers map[stream.Topic]topicHandler, snapCacheTTL time
snapCache: make(map[stream.Topic]map[string]*stream.EventSnapshot), snapCache: make(map[stream.Topic]map[string]*stream.EventSnapshot),
publishCh: make(chan commitUpdate, 64), publishCh: make(chan commitUpdate, 64),
subscriptions: &subscriptions{ subscriptions: &subscriptions{
subsByToken: make(map[string]map[*stream.SubscribeRequest]*stream.Subscription), byToken: make(map[string]map[*stream.SubscribeRequest]*stream.Subscription),
}, },
handlers: handlers, handlers: handlers,
} }
@ -160,8 +164,8 @@ func (s *subscriptions) handleACLUpdate(tx ReadTxn, event stream.Event) error {
switch event.Topic { switch event.Topic {
case stream.Topic_ACLTokens: case stream.Topic_ACLTokens:
token := event.Payload.(*structs.ACLToken) token := event.Payload.(*structs.ACLToken)
for _, sub := range s.subsByToken[token.SecretID] { for _, sub := range s.byToken[token.SecretID] {
sub.CloseReload() sub.ForceReload()
} }
case stream.Topic_ACLPolicies: case stream.Topic_ACLPolicies:
@ -199,13 +203,13 @@ func (s *subscriptions) handleACLUpdate(tx ReadTxn, event stream.Event) error {
return nil return nil
} }
// This method requires the EventPublisher.lock is held // This method requires the subscriptions.lock.RLock is held (the read-only lock)
func (s *subscriptions) closeSubscriptionsForTokens(tokens memdb.ResultIterator) { func (s *subscriptions) closeSubscriptionsForTokens(tokens memdb.ResultIterator) {
for token := tokens.Next(); token != nil; token = tokens.Next() { for token := tokens.Next(); token != nil; token = tokens.Next() {
token := token.(*structs.ACLToken) token := token.(*structs.ACLToken)
if subs, ok := s.subsByToken[token.SecretID]; ok { if subs, ok := s.byToken[token.SecretID]; ok {
for _, sub := range subs { for _, sub := range subs {
sub.CloseReload() sub.ForceReload()
} }
} }
} }
@ -218,8 +222,8 @@ func (s *subscriptions) closeSubscriptionsForTokens(tokens memdb.ResultIterator)
// decides it can no longer maintain correct operation for example if ACL // decides it can no longer maintain correct operation for example if ACL
// policies changed or the state store was restored. // policies changed or the state store was restored.
// //
// When the called is finished with the subscription for any reason, it must // When the caller is finished with the subscription for any reason, it must
// call Unsubscribe to free ACL tracking resources. // call Subscription.Unsubscribe to free ACL tracking resources.
func (e *EventPublisher) Subscribe( func (e *EventPublisher) Subscribe(
ctx context.Context, ctx context.Context,
req *stream.SubscribeRequest, req *stream.SubscribeRequest,
@ -278,7 +282,12 @@ func (e *EventPublisher) Subscribe(
} }
e.subscriptions.add(req, sub) e.subscriptions.add(req, sub)
// Set unsubscribe so that the caller doesn't need to keep track of the
// SubscriptionRequest, and can not accidentally call unsubscribe with the
// wrong value.
sub.Unsubscribe = func() {
e.subscriptions.unsubscribe(req)
}
return sub, nil return sub, nil
} }
@ -286,28 +295,29 @@ func (s *subscriptions) add(req *stream.SubscribeRequest, sub *stream.Subscripti
s.lock.Lock() s.lock.Lock()
defer s.lock.Unlock() defer s.lock.Unlock()
subsByToken, ok := s.subsByToken[req.Token] subsByToken, ok := s.byToken[req.Token]
if !ok { if !ok {
subsByToken = make(map[*stream.SubscribeRequest]*stream.Subscription) subsByToken = make(map[*stream.SubscribeRequest]*stream.Subscription)
s.subsByToken[req.Token] = subsByToken s.byToken[req.Token] = subsByToken
} }
subsByToken[req] = sub subsByToken[req] = sub
} }
// Unsubscribe must be called when a client is no longer interested in a // unsubscribe must be called when a client is no longer interested in a
// subscription to free resources monitoring changes in it's ACL token. The same // subscription to free resources monitoring changes in it's ACL token.
// request object passed to Subscribe must be used. //
func (s *subscriptions) Unsubscribe(req *stream.SubscribeRequest) { // req MUST be the same pointer that was used to register the subscription.
func (s *subscriptions) unsubscribe(req *stream.SubscribeRequest) {
s.lock.Lock() s.lock.Lock()
defer s.lock.Unlock() defer s.lock.Unlock()
subsByToken, ok := s.subsByToken[req.Token] subsByToken, ok := s.byToken[req.Token]
if !ok { if !ok {
return return
} }
delete(subsByToken, req) delete(subsByToken, req)
if len(subsByToken) == 0 { if len(subsByToken) == 0 {
delete(s.subsByToken, req.Token) delete(s.byToken, req.Token)
} }
} }

View File

@ -42,6 +42,10 @@ type Subscription struct {
// cancelFn stores the context cancel function that will wake up the // cancelFn stores the context cancel function that will wake up the
// in-progress Next call on a server-initiated state change e.g. Reload. // in-progress Next call on a server-initiated state change e.g. Reload.
cancelFn func() cancelFn func()
// Unsubscribe is a function set by EventPublisher that is called to
// free resources when the subscription is no longer needed.
Unsubscribe func()
} }
type SubscribeRequest struct { type SubscribeRequest struct {
@ -116,9 +120,9 @@ func (s *Subscription) Next() ([]Event, error) {
} }
} }
// CloseReload closes the stream and signals that the subscriber should reload. // ForceReload closes the stream and signals that the subscriber should reload.
// It is safe to call from any goroutine. // It is safe to call from any goroutine.
func (s *Subscription) CloseReload() { func (s *Subscription) ForceReload() {
swapped := atomic.CompareAndSwapUint32(&s.state, SubscriptionStateOpen, swapped := atomic.CompareAndSwapUint32(&s.state, SubscriptionStateOpen,
SubscriptionStateCloseReload) SubscriptionStateCloseReload)
@ -126,9 +130,3 @@ func (s *Subscription) CloseReload() {
s.cancelFn() s.cancelFn()
} }
} }
// Request returns the request object that started the subscription.
// TODO: remove
func (s *Subscription) Request() *SubscribeRequest {
return s.req
}

View File

@ -118,11 +118,11 @@ func TestSubscriptionCloseReload(t *testing.T) {
require.Len(t, got, 1) require.Len(t, got, 1)
require.Equal(t, index, got[0].Index) require.Equal(t, index, got[0].Index)
// Schedule a CloseReload simulating the server deciding this subscroption // Schedule a ForceReload simulating the server deciding this subscroption
// needs to reset (e.g. on ACL perm change). // needs to reset (e.g. on ACL perm change).
start = time.Now() start = time.Now()
time.AfterFunc(200*time.Millisecond, func() { time.AfterFunc(200*time.Millisecond, func() {
sub.CloseReload() sub.ForceReload()
}) })
_, err = sub.Next() _, err = sub.Next()