diff --git a/api/lock.go b/api/lock.go index b8abbd476d..2b5aa2a290 100644 --- a/api/lock.go +++ b/api/lock.go @@ -130,7 +130,8 @@ func (l *Lock) Lock(stopCh <-chan struct{}) (<-chan struct{}, error) { } else { l.sessionRenew = make(chan struct{}) l.lockSession = s - go l.renewSession(s, l.sessionRenew) + session := l.c.Session() + go session.RenewPeriodic(l.opts.SessionTTL, s, nil, l.sessionRenew) // If we fail to acquire the lock, cleanup the session defer func() { @@ -302,30 +303,6 @@ func (l *Lock) lockEntry(session string) *KVPair { } } -// renewSession is a long running routine that maintians a session -// by doing a periodic Session renewal. -func (l *Lock) renewSession(id string, doneCh chan struct{}) { - session := l.c.Session() - ttl, _ := time.ParseDuration(l.opts.SessionTTL) - for { - select { - case <-time.After(ttl / 2): - entry, _, err := session.Renew(id, nil) - if err != nil || entry == nil { - return - } - - // Handle the server updating the TTL - ttl, _ = time.ParseDuration(entry.TTL) - - case <-doneCh: - // Attempt a session destroy - session.Destroy(id, nil) - return - } - } -} - // monitorLock is a long running routine to monitor a lock ownership // It closes the stopCh if we lose our leadership. func (l *Lock) monitorLock(session string, stopCh chan struct{}) { diff --git a/api/semaphore.go b/api/semaphore.go index 7139c40dba..957f884a4d 100644 --- a/api/semaphore.go +++ b/api/semaphore.go @@ -155,7 +155,8 @@ func (s *Semaphore) Acquire(stopCh <-chan struct{}) (<-chan struct{}, error) { } else { s.sessionRenew = make(chan struct{}) s.lockSession = sess - go s.renewSession(sess, s.sessionRenew) + session := s.c.Session() + go session.RenewPeriodic(s.opts.SessionTTL, sess, nil, s.sessionRenew) // If we fail to acquire the lock, cleanup the session defer func() { @@ -384,30 +385,6 @@ func (s *Semaphore) createSession() (string, error) { return id, nil } -// renewSession is a long running routine that maintians a session -// by doing a periodic Session renewal. -func (s *Semaphore) renewSession(id string, doneCh chan struct{}) { - session := s.c.Session() - ttl, _ := time.ParseDuration(s.opts.SessionTTL) - for { - select { - case <-time.After(ttl / 2): - entry, _, err := session.Renew(id, nil) - if err != nil || entry == nil { - return - } - - // Handle the server updating the TTL - ttl, _ = time.ParseDuration(entry.TTL) - - case <-doneCh: - // Attempt a session destroy - session.Destroy(id, nil) - return - } - } -} - // contenderEntry returns a formatted KVPair for the contender func (s *Semaphore) contenderEntry(session string) *KVPair { return &KVPair{ diff --git a/api/session.go b/api/session.go index e889bbe0de..bb84644fd9 100644 --- a/api/session.go +++ b/api/session.go @@ -147,6 +147,36 @@ func (s *Session) Renew(id string, q *WriteOptions) (*SessionEntry, *WriteMeta, return nil, wm, nil } +// RenewPeriodic is used to periodically invoke Session.Renew on a +// session until a doneCh is closed. This is meant to be used in a long running +// goroutine to ensure a session stays valid. +func (s *Session) RenewPeriodic(initialTTL string, id string, q *WriteOptions, doneCh chan struct{}) error { + ttl, err := time.ParseDuration(initialTTL) + if err != nil { + return err + } + for { + select { + case <-time.After(ttl / 2): + entry, _, err := s.Renew(id, q) + if err != nil { + return err + } + if entry == nil { + return nil + } + + // Handle the server updating the TTL + ttl, _ = time.ParseDuration(entry.TTL) + + case <-doneCh: + // Attempt a session destroy + s.Destroy(id, q) + return nil + } + } +} + // Info looks up a single session func (s *Session) Info(id string, q *QueryOptions) (*SessionEntry, *QueryMeta, error) { r := s.c.newRequest("GET", "/v1/session/info/"+id)