diff --git a/agent/consul/state/event_publisher.go b/agent/consul/state/event_publisher.go index 800fa8a78d..d4a97768fb 100644 --- a/agent/consul/state/event_publisher.go +++ b/agent/consul/state/event_publisher.go @@ -51,12 +51,16 @@ type EventPublisher struct { } type subscriptions struct { + // lock for byToken. If both subscription.lock and EventPublisher.lock need + // to be held, EventPublisher.lock MUST always be acquired first. lock sync.RWMutex - // subsByToken stores a list of Subscription objects outstanding indexed by a - // hash of the ACL token they used to subscribe so we can reload them if their - // ACL permissions change. - subsByToken map[string]map[*stream.SubscribeRequest]*stream.Subscription + // byToken is an mapping of active Subscriptions indexed by a the token and + // a pointer to the request. + // When the token is modified all subscriptions under that token will be + // reloaded. + // A subscription may be unsubscribed by using the pointer to the request. + byToken map[string]map[*stream.SubscribeRequest]*stream.Subscription } type commitUpdate struct { @@ -70,7 +74,7 @@ func NewEventPublisher(handlers map[stream.Topic]topicHandler, snapCacheTTL time snapCache: make(map[stream.Topic]map[string]*stream.EventSnapshot), publishCh: make(chan commitUpdate, 64), subscriptions: &subscriptions{ - subsByToken: make(map[string]map[*stream.SubscribeRequest]*stream.Subscription), + byToken: make(map[string]map[*stream.SubscribeRequest]*stream.Subscription), }, handlers: handlers, } @@ -160,8 +164,8 @@ func (s *subscriptions) handleACLUpdate(tx ReadTxn, event stream.Event) error { switch event.Topic { case stream.Topic_ACLTokens: token := event.Payload.(*structs.ACLToken) - for _, sub := range s.subsByToken[token.SecretID] { - sub.CloseReload() + for _, sub := range s.byToken[token.SecretID] { + sub.ForceReload() } case stream.Topic_ACLPolicies: @@ -199,13 +203,13 @@ func (s *subscriptions) handleACLUpdate(tx ReadTxn, event stream.Event) error { return nil } -// This method requires the EventPublisher.lock is held +// This method requires the subscriptions.lock.RLock is held (the read-only lock) func (s *subscriptions) closeSubscriptionsForTokens(tokens memdb.ResultIterator) { for token := tokens.Next(); token != nil; token = tokens.Next() { token := token.(*structs.ACLToken) - if subs, ok := s.subsByToken[token.SecretID]; ok { + if subs, ok := s.byToken[token.SecretID]; ok { for _, sub := range subs { - sub.CloseReload() + sub.ForceReload() } } } @@ -218,8 +222,8 @@ func (s *subscriptions) closeSubscriptionsForTokens(tokens memdb.ResultIterator) // decides it can no longer maintain correct operation for example if ACL // policies changed or the state store was restored. // -// When the called is finished with the subscription for any reason, it must -// call Unsubscribe to free ACL tracking resources. +// When the caller is finished with the subscription for any reason, it must +// call Subscription.Unsubscribe to free ACL tracking resources. func (e *EventPublisher) Subscribe( ctx context.Context, req *stream.SubscribeRequest, @@ -278,7 +282,12 @@ func (e *EventPublisher) Subscribe( } e.subscriptions.add(req, sub) - + // Set unsubscribe so that the caller doesn't need to keep track of the + // SubscriptionRequest, and can not accidentally call unsubscribe with the + // wrong value. + sub.Unsubscribe = func() { + e.subscriptions.unsubscribe(req) + } return sub, nil } @@ -286,28 +295,29 @@ func (s *subscriptions) add(req *stream.SubscribeRequest, sub *stream.Subscripti s.lock.Lock() defer s.lock.Unlock() - subsByToken, ok := s.subsByToken[req.Token] + subsByToken, ok := s.byToken[req.Token] if !ok { subsByToken = make(map[*stream.SubscribeRequest]*stream.Subscription) - s.subsByToken[req.Token] = subsByToken + s.byToken[req.Token] = subsByToken } subsByToken[req] = sub } -// Unsubscribe must be called when a client is no longer interested in a -// subscription to free resources monitoring changes in it's ACL token. The same -// request object passed to Subscribe must be used. -func (s *subscriptions) Unsubscribe(req *stream.SubscribeRequest) { +// unsubscribe must be called when a client is no longer interested in a +// subscription to free resources monitoring changes in it's ACL token. +// +// req MUST be the same pointer that was used to register the subscription. +func (s *subscriptions) unsubscribe(req *stream.SubscribeRequest) { s.lock.Lock() defer s.lock.Unlock() - subsByToken, ok := s.subsByToken[req.Token] + subsByToken, ok := s.byToken[req.Token] if !ok { return } delete(subsByToken, req) if len(subsByToken) == 0 { - delete(s.subsByToken, req.Token) + delete(s.byToken, req.Token) } } diff --git a/agent/consul/stream/subscription.go b/agent/consul/stream/subscription.go index 4f2fa1eee5..a129ddef15 100644 --- a/agent/consul/stream/subscription.go +++ b/agent/consul/stream/subscription.go @@ -42,6 +42,10 @@ type Subscription struct { // cancelFn stores the context cancel function that will wake up the // in-progress Next call on a server-initiated state change e.g. Reload. cancelFn func() + + // Unsubscribe is a function set by EventPublisher that is called to + // free resources when the subscription is no longer needed. + Unsubscribe func() } type SubscribeRequest struct { @@ -116,9 +120,9 @@ func (s *Subscription) Next() ([]Event, error) { } } -// CloseReload closes the stream and signals that the subscriber should reload. +// ForceReload closes the stream and signals that the subscriber should reload. // It is safe to call from any goroutine. -func (s *Subscription) CloseReload() { +func (s *Subscription) ForceReload() { swapped := atomic.CompareAndSwapUint32(&s.state, SubscriptionStateOpen, SubscriptionStateCloseReload) @@ -126,9 +130,3 @@ func (s *Subscription) CloseReload() { s.cancelFn() } } - -// Request returns the request object that started the subscription. -// TODO: remove -func (s *Subscription) Request() *SubscribeRequest { - return s.req -} diff --git a/agent/consul/stream/subscription_test.go b/agent/consul/stream/subscription_test.go index 56de6958fb..cd55785917 100644 --- a/agent/consul/stream/subscription_test.go +++ b/agent/consul/stream/subscription_test.go @@ -118,11 +118,11 @@ func TestSubscriptionCloseReload(t *testing.T) { require.Len(t, got, 1) require.Equal(t, index, got[0].Index) - // Schedule a CloseReload simulating the server deciding this subscroption + // Schedule a ForceReload simulating the server deciding this subscroption // needs to reset (e.g. on ACL perm change). start = time.Now() time.AfterFunc(200*time.Millisecond, func() { - sub.CloseReload() + sub.ForceReload() }) _, err = sub.Next()