mirror of https://github.com/hashicorp/consul
rpc: refactor sessionTimers and fix racy tests
The sessionTimers map was secured by a lock which wasn't used properly in the tests. This led to data races and failing tests when accessing the length or the members of the map. This patch adds a separate SessionTimers struct which is safe for concurrent use and which encapsulates the behavior of the sessionTimers map.
pull/3241/head
parent
05f756853e
commit
13eeeb720d
|
@ -172,8 +172,7 @@ type Server struct {
|
|||
// sessionTimers track the expiration time of each Session that has
|
||||
// a TTL. On expiration, a SessionDestroy event will occur, and
|
||||
// destroy the session via standard session destroy processing
|
||||
sessionTimers map[string]*time.Timer
|
||||
sessionTimersLock sync.Mutex
|
||||
sessionTimers *SessionTimers
|
||||
|
||||
// statsFetcher is used by autopilot to check the status of the other
|
||||
// Consul servers.
|
||||
|
@ -296,6 +295,7 @@ func NewServerLogger(config *Config, logger *log.Logger) (*Server, error) {
|
|||
rpcServer: rpc.NewServer(),
|
||||
rpcTLS: incomingTLS,
|
||||
reassertLeaderCh: make(chan chan error),
|
||||
sessionTimers: NewSessionTimers(),
|
||||
tombstoneGC: gc,
|
||||
shutdownCh: shutdownCh,
|
||||
}
|
||||
|
|
|
@ -514,7 +514,7 @@ func TestSession_ApplyTimers(t *testing.T) {
|
|||
}
|
||||
|
||||
// Check the session map
|
||||
if _, ok := s1.sessionTimers[out]; !ok {
|
||||
if s1.sessionTimers.Get(out) == nil {
|
||||
t.Fatalf("missing session timer")
|
||||
}
|
||||
|
||||
|
@ -526,7 +526,7 @@ func TestSession_ApplyTimers(t *testing.T) {
|
|||
}
|
||||
|
||||
// Check the session map
|
||||
if _, ok := s1.sessionTimers[out]; ok {
|
||||
if s1.sessionTimers.Get(out) != nil {
|
||||
t.Fatalf("session timer exists")
|
||||
}
|
||||
}
|
||||
|
@ -564,7 +564,7 @@ func TestSession_Renew(t *testing.T) {
|
|||
}
|
||||
|
||||
// Verify the timer map is setup
|
||||
if len(s1.sessionTimers) != 5 {
|
||||
if s1.sessionTimers.Len() != 5 {
|
||||
t.Fatalf("missing session timers")
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,82 @@
|
|||
package consul
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// SessionTimers provides a map of named timers which
// is safe for concurrent use.
//
// The zero value is usable: the underlying map is allocated lazily by the
// first call that stores a timer. NewSessionTimers remains the usual way
// to obtain an instance.
type SessionTimers struct {
	sync.RWMutex
	m map[string]*time.Timer
}

// NewSessionTimers returns an empty SessionTimers whose map is already
// allocated.
func NewSessionTimers() *SessionTimers {
	return &SessionTimers{m: make(map[string]*time.Timer)}
}

// Get returns the timer with the given id or nil.
func (t *SessionTimers) Get(id string) *time.Timer {
	t.RLock()
	defer t.RUnlock()
	return t.m[id]
}

// Set stores the timer under given id. If tm is nil the timer
// with the given id is removed.
//
// Removal only drops the map entry; the removed timer is not stopped, so
// a timer that is still pending will fire. Use Stop to stop and remove.
func (t *SessionTimers) Set(id string, tm *time.Timer) {
	t.Lock()
	defer t.Unlock()

	if tm == nil {
		// todo(fs): shouldn't we call Stop() here?
		delete(t.m, id)
		return
	}

	// A zero-value SessionTimers has a nil map and writing to a nil map
	// panics, so allocate it on first use.
	if t.m == nil {
		t.m = make(map[string]*time.Timer)
	}
	t.m[id] = tm
}

// Del removes the timer with the given id. Like Set(id, nil) it does not
// stop the timer.
func (t *SessionTimers) Del(id string) {
	t.Set(id, nil)
}

// Len returns the number of registered timers.
func (t *SessionTimers) Len() int {
	t.RLock()
	defer t.RUnlock()
	return len(t.m)
}

// ResetOrCreate sets the ttl of the timer with the given id or creates a new
// one if it does not exist.
//
// When a timer is already registered under id it is re-armed with ttl via
// Timer.Reset and afterFunc is ignored. Otherwise afterFunc is handed to
// time.AfterFunc, so it must be non-nil in that case.
func (t *SessionTimers) ResetOrCreate(id string, ttl time.Duration, afterFunc func()) {
	t.Lock()
	defer t.Unlock()

	if tm := t.m[id]; tm != nil {
		tm.Reset(ttl)
		return
	}

	// Allocate the map on first use so the zero value does not panic on
	// the assignment below.
	if t.m == nil {
		t.m = make(map[string]*time.Timer)
	}
	t.m[id] = time.AfterFunc(ttl, afterFunc)
}

// Stop stops the timer with the given id and removes it. Ids without a
// registered timer are ignored.
func (t *SessionTimers) Stop(id string) {
	t.Lock()
	defer t.Unlock()

	if tm := t.m[id]; tm != nil {
		tm.Stop()
		delete(t.m, id)
	}
}

// StopAll stops and removes all registered timers.
func (t *SessionTimers) StopAll() {
	t.Lock()
	defer t.Unlock()

	for _, tm := range t.m {
		tm.Stop()
	}
	// Start over with a fresh map instead of deleting entry by entry.
	t.m = make(map[string]*time.Timer)
}
|
|
@ -0,0 +1,105 @@
|
|||
package consul
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestSessionTimers(t *testing.T) {
|
||||
m := NewSessionTimers()
|
||||
ch := make(chan int)
|
||||
newTm := func(d time.Duration) *time.Timer {
|
||||
return time.AfterFunc(d, func() { ch <- 1 })
|
||||
}
|
||||
|
||||
waitForTimer := func() {
|
||||
select {
|
||||
case <-ch:
|
||||
return
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatal("timer did not fire")
|
||||
}
|
||||
}
|
||||
|
||||
// check that non-existent id returns nil
|
||||
if got, want := m.Get("foo"), (*time.Timer)(nil); got != want {
|
||||
t.Fatalf("got %v want %v", got, want)
|
||||
}
|
||||
|
||||
// add a timer and look it up and delete via Set(id, nil)
|
||||
tm := newTm(time.Millisecond)
|
||||
m.Set("foo", tm)
|
||||
if got, want := m.Len(), 1; got != want {
|
||||
t.Fatalf("got len %d want %d", got, want)
|
||||
}
|
||||
if got, want := m.Get("foo"), tm; got != want {
|
||||
t.Fatalf("got %v want %v", got, want)
|
||||
}
|
||||
m.Set("foo", nil)
|
||||
if got, want := m.Get("foo"), (*time.Timer)(nil); got != want {
|
||||
t.Fatalf("got %v want %v", got, want)
|
||||
}
|
||||
waitForTimer()
|
||||
|
||||
// same thing via Del(id)
|
||||
tm = newTm(time.Millisecond)
|
||||
m.Set("foo", tm)
|
||||
if got, want := m.Get("foo"), tm; got != want {
|
||||
t.Fatalf("got %v want %v", got, want)
|
||||
}
|
||||
m.Del("foo")
|
||||
if got, want := m.Len(), 0; got != want {
|
||||
t.Fatalf("got len %d want %d", got, want)
|
||||
}
|
||||
waitForTimer()
|
||||
|
||||
// create timer via ResetOrCreate
|
||||
m.ResetOrCreate("foo", time.Millisecond, func() { ch <- 1 })
|
||||
if got, want := m.Len(), 1; got != want {
|
||||
t.Fatalf("got len %d want %d", got, want)
|
||||
}
|
||||
waitForTimer()
|
||||
|
||||
// timer is still there
|
||||
if got, want := m.Len(), 1; got != want {
|
||||
t.Fatalf("got len %d want %d", got, want)
|
||||
}
|
||||
|
||||
// reset the timer and check that it fires again
|
||||
m.ResetOrCreate("foo", time.Millisecond, nil)
|
||||
waitForTimer()
|
||||
|
||||
// reset the timer with a long ttl and then stop it
|
||||
m.ResetOrCreate("foo", 20*time.Millisecond, func() { ch <- 1 })
|
||||
m.Stop("foo")
|
||||
select {
|
||||
case <-ch:
|
||||
t.Fatal("timer fired although it shouldn't")
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
// want
|
||||
}
|
||||
|
||||
// stopping a stopped timer should not break
|
||||
m.Stop("foo")
|
||||
|
||||
// stop should also remove the timer
|
||||
if got, want := m.Len(), 0; got != want {
|
||||
t.Fatalf("got len %d want %d", got, want)
|
||||
}
|
||||
|
||||
// create two timers and stop and then stop all
|
||||
m.ResetOrCreate("foo1", 20*time.Millisecond, func() { ch <- 1 })
|
||||
m.ResetOrCreate("foo2", 30*time.Millisecond, func() { ch <- 2 })
|
||||
m.StopAll()
|
||||
select {
|
||||
case x := <-ch:
|
||||
t.Fatalf("timer %d fired although it shouldn't", x)
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
// want
|
||||
}
|
||||
|
||||
// stopall should remove all timers
|
||||
if got, want := m.Len(), 0; got != want {
|
||||
t.Fatalf("got len %d want %d", got, want)
|
||||
}
|
||||
}
|
|
@ -66,49 +66,28 @@ func (s *Server) resetSessionTimer(id string, session *structs.Session) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// Reset the session timer
|
||||
s.sessionTimersLock.Lock()
|
||||
defer s.sessionTimersLock.Unlock()
|
||||
s.resetSessionTimerLocked(id, ttl)
|
||||
s.createSessionTimer(session.ID, ttl)
|
||||
return nil
|
||||
}
|
||||
|
||||
// resetSessionTimerLocked is used to reset a session timer
|
||||
// assuming the sessionTimerLock is already held
|
||||
func (s *Server) resetSessionTimerLocked(id string, ttl time.Duration) {
|
||||
// Ensure a timer map exists
|
||||
if s.sessionTimers == nil {
|
||||
s.sessionTimers = make(map[string]*time.Timer)
|
||||
}
|
||||
|
||||
func (s *Server) createSessionTimer(id string, ttl time.Duration) {
|
||||
// Reset the session timer
|
||||
// Adjust the given TTL by the TTL multiplier. This is done
|
||||
// to give a client a grace period and to compensate for network
|
||||
// and processing delays. The contract is that a session is not expired
|
||||
// before the TTL, but there is no explicit promise about the upper
|
||||
// bound so this is allowable.
|
||||
ttl = ttl * structs.SessionTTLMultiplier
|
||||
|
||||
// Renew the session timer if it exists
|
||||
if timer, ok := s.sessionTimers[id]; ok {
|
||||
timer.Reset(ttl)
|
||||
return
|
||||
}
|
||||
|
||||
// Create a new timer to track expiration of this ssession
|
||||
timer := time.AfterFunc(ttl, func() {
|
||||
s.invalidateSession(id)
|
||||
})
|
||||
s.sessionTimers[id] = timer
|
||||
s.sessionTimers.ResetOrCreate(id, ttl, func() { s.invalidateSession(id) })
|
||||
}
|
||||
|
||||
// invalidateSession is invoked when a session TTL is reached and we
|
||||
// need to invalidate the session.
|
||||
func (s *Server) invalidateSession(id string) {
|
||||
defer metrics.MeasureSince([]string{"consul", "session_ttl", "invalidate"}, time.Now())
|
||||
|
||||
// Clear the session timer
|
||||
s.sessionTimersLock.Lock()
|
||||
delete(s.sessionTimers, id)
|
||||
s.sessionTimersLock.Unlock()
|
||||
s.sessionTimers.Del(id)
|
||||
|
||||
// Create a session destroy request
|
||||
args := structs.SessionRequest{
|
||||
|
@ -137,26 +116,14 @@ func (s *Server) invalidateSession(id string) {
|
|||
// a single session. This is used when a session is destroyed
|
||||
// explicitly and no longer needed.
|
||||
func (s *Server) clearSessionTimer(id string) error {
|
||||
s.sessionTimersLock.Lock()
|
||||
defer s.sessionTimersLock.Unlock()
|
||||
|
||||
if timer, ok := s.sessionTimers[id]; ok {
|
||||
timer.Stop()
|
||||
delete(s.sessionTimers, id)
|
||||
}
|
||||
s.sessionTimers.Stop(id)
|
||||
return nil
|
||||
}
|
||||
|
||||
// clearAllSessionTimers is used when a leader is stepping
|
||||
// down and we no longer need to track any session timers.
|
||||
func (s *Server) clearAllSessionTimers() error {
|
||||
s.sessionTimersLock.Lock()
|
||||
defer s.sessionTimersLock.Unlock()
|
||||
|
||||
for _, t := range s.sessionTimers {
|
||||
t.Stop()
|
||||
}
|
||||
s.sessionTimers = nil
|
||||
s.sessionTimers.StopAll()
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -166,10 +133,7 @@ func (s *Server) sessionStats() {
|
|||
for {
|
||||
select {
|
||||
case <-time.After(5 * time.Second):
|
||||
s.sessionTimersLock.Lock()
|
||||
num := len(s.sessionTimers)
|
||||
s.sessionTimersLock.Unlock()
|
||||
metrics.SetGauge([]string{"consul", "session_ttl", "active"}, float32(num))
|
||||
metrics.SetGauge([]string{"consul", "session_ttl", "active"}, float32(s.sessionTimers.Len()))
|
||||
|
||||
case <-s.shutdownCh:
|
||||
return
|
||||
|
|
|
@ -39,8 +39,7 @@ func TestInitializeSessionTimers(t *testing.T) {
|
|||
}
|
||||
|
||||
// Check that we have a timer
|
||||
_, ok := s1.sessionTimers[session.ID]
|
||||
if !ok {
|
||||
if s1.sessionTimers.Get(session.ID) == nil {
|
||||
t.Fatalf("missing session timer")
|
||||
}
|
||||
}
|
||||
|
@ -79,8 +78,7 @@ func TestResetSessionTimer_Fault(t *testing.T) {
|
|||
}
|
||||
|
||||
// Check that we have a timer
|
||||
_, ok := s1.sessionTimers[session.ID]
|
||||
if !ok {
|
||||
if s1.sessionTimers.Get(session.ID) == nil {
|
||||
t.Fatalf("missing session timer")
|
||||
}
|
||||
}
|
||||
|
@ -113,8 +111,7 @@ func TestResetSessionTimer_NoTTL(t *testing.T) {
|
|||
}
|
||||
|
||||
// Check that we have a timer
|
||||
_, ok := s1.sessionTimers[session.ID]
|
||||
if ok {
|
||||
if s1.sessionTimers.Get(session.ID) != nil {
|
||||
t.Fatalf("should not have session timer")
|
||||
}
|
||||
}
|
||||
|
@ -145,17 +142,13 @@ func TestResetSessionTimerLocked(t *testing.T) {
|
|||
|
||||
testrpc.WaitForLeader(t, s1.RPC, "dc1")
|
||||
|
||||
s1.sessionTimersLock.Lock()
|
||||
s1.resetSessionTimerLocked("foo", 5*time.Millisecond)
|
||||
s1.sessionTimersLock.Unlock()
|
||||
|
||||
if _, ok := s1.sessionTimers["foo"]; !ok {
|
||||
s1.createSessionTimer("foo", 5*time.Millisecond)
|
||||
if s1.sessionTimers.Get("foo") == nil {
|
||||
t.Fatalf("missing timer")
|
||||
}
|
||||
|
||||
time.Sleep(10 * time.Millisecond * structs.SessionTTLMultiplier)
|
||||
|
||||
if _, ok := s1.sessionTimers["foo"]; ok {
|
||||
if s1.sessionTimers.Get("foo") != nil {
|
||||
t.Fatalf("timer should be gone")
|
||||
}
|
||||
}
|
||||
|
@ -165,39 +158,46 @@ func TestResetSessionTimerLocked_Renew(t *testing.T) {
|
|||
defer os.RemoveAll(dir1)
|
||||
defer s1.Shutdown()
|
||||
|
||||
testrpc.WaitForLeader(t, s1.RPC, "dc1")
|
||||
ttl := 100 * time.Millisecond
|
||||
|
||||
s1.sessionTimersLock.Lock()
|
||||
s1.resetSessionTimerLocked("foo", 5*time.Millisecond)
|
||||
s1.sessionTimersLock.Unlock()
|
||||
|
||||
if _, ok := s1.sessionTimers["foo"]; !ok {
|
||||
// create the timer
|
||||
s1.createSessionTimer("foo", ttl)
|
||||
if s1.sessionTimers.Get("foo") == nil {
|
||||
t.Fatalf("missing timer")
|
||||
}
|
||||
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
// wait until it is "expired" but at this point
|
||||
// the session still exists.
|
||||
time.Sleep(ttl)
|
||||
if s1.sessionTimers.Get("foo") == nil {
|
||||
t.Fatal("missing timer")
|
||||
}
|
||||
|
||||
// Renew the session
|
||||
s1.sessionTimersLock.Lock()
|
||||
renew := time.Now()
|
||||
s1.resetSessionTimerLocked("foo", 5*time.Millisecond)
|
||||
s1.sessionTimersLock.Unlock()
|
||||
// renew the session which will reset the TTL to 2*ttl
|
||||
// since that is the current SessionTTLMultiplier
|
||||
s1.createSessionTimer("foo", ttl)
|
||||
|
||||
// Watch for invalidation
|
||||
for time.Now().Sub(renew) < 20*time.Millisecond {
|
||||
s1.sessionTimersLock.Lock()
|
||||
_, ok := s1.sessionTimers["foo"]
|
||||
s1.sessionTimersLock.Unlock()
|
||||
if !ok {
|
||||
end := time.Now()
|
||||
if end.Sub(renew) < 5*time.Millisecond {
|
||||
t.Fatalf("early invalidate")
|
||||
}
|
||||
return
|
||||
renew := time.Now()
|
||||
deadline := renew.Add(2 * structs.SessionTTLMultiplier * ttl)
|
||||
for {
|
||||
now := time.Now()
|
||||
if now.After(deadline) {
|
||||
t.Fatal("should have expired by now")
|
||||
}
|
||||
time.Sleep(time.Millisecond)
|
||||
|
||||
// timer still exists
|
||||
if s1.sessionTimers.Get("foo") != nil {
|
||||
time.Sleep(time.Millisecond)
|
||||
continue
|
||||
}
|
||||
|
||||
// timer gone
|
||||
if now.Sub(renew) < ttl {
|
||||
t.Fatalf("early invalidate")
|
||||
}
|
||||
break
|
||||
}
|
||||
t.Fatalf("should have expired")
|
||||
}
|
||||
|
||||
func TestInvalidateSession(t *testing.T) {
|
||||
|
@ -239,16 +239,14 @@ func TestClearSessionTimer(t *testing.T) {
|
|||
defer os.RemoveAll(dir1)
|
||||
defer s1.Shutdown()
|
||||
|
||||
s1.sessionTimersLock.Lock()
|
||||
s1.resetSessionTimerLocked("foo", 5*time.Millisecond)
|
||||
s1.sessionTimersLock.Unlock()
|
||||
s1.createSessionTimer("foo", 5*time.Millisecond)
|
||||
|
||||
err := s1.clearSessionTimer("foo")
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
if _, ok := s1.sessionTimers["foo"]; ok {
|
||||
if s1.sessionTimers.Get("foo") != nil {
|
||||
t.Fatalf("timer should be gone")
|
||||
}
|
||||
}
|
||||
|
@ -258,18 +256,17 @@ func TestClearAllSessionTimers(t *testing.T) {
|
|||
defer os.RemoveAll(dir1)
|
||||
defer s1.Shutdown()
|
||||
|
||||
s1.sessionTimersLock.Lock()
|
||||
s1.resetSessionTimerLocked("foo", 10*time.Millisecond)
|
||||
s1.resetSessionTimerLocked("bar", 10*time.Millisecond)
|
||||
s1.resetSessionTimerLocked("baz", 10*time.Millisecond)
|
||||
s1.sessionTimersLock.Unlock()
|
||||
s1.createSessionTimer("foo", 10*time.Millisecond)
|
||||
s1.createSessionTimer("bar", 10*time.Millisecond)
|
||||
s1.createSessionTimer("baz", 10*time.Millisecond)
|
||||
|
||||
err := s1.clearAllSessionTimers()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
if len(s1.sessionTimers) != 0 {
|
||||
// sessionTimers is guarded by the lock
|
||||
if s1.sessionTimers.Len() != 0 {
|
||||
t.Fatalf("timers should be gone")
|
||||
}
|
||||
}
|
||||
|
@ -297,7 +294,7 @@ func TestServer_SessionTTL_Failover(t *testing.T) {
|
|||
var leader *Server
|
||||
for _, s := range servers {
|
||||
// Check that s.sessionTimers is empty
|
||||
if len(s.sessionTimers) != 0 {
|
||||
if s.sessionTimers.Len() != 0 {
|
||||
t.Fatalf("should have no sessionTimers")
|
||||
}
|
||||
// Find the leader too
|
||||
|
@ -338,7 +335,7 @@ func TestServer_SessionTTL_Failover(t *testing.T) {
|
|||
}
|
||||
|
||||
// Check that sessionTimers has the session ID
|
||||
if _, ok := leader.sessionTimers[id1]; !ok {
|
||||
if leader.sessionTimers.Get(id1) == nil {
|
||||
t.Fatalf("missing session timer")
|
||||
}
|
||||
|
||||
|
@ -346,12 +343,11 @@ func TestServer_SessionTTL_Failover(t *testing.T) {
|
|||
leader.Shutdown()
|
||||
|
||||
// sessionTimers should be cleared on leader shutdown
|
||||
if len(leader.sessionTimers) != 0 {
|
||||
if leader.sessionTimers.Len() != 0 {
|
||||
t.Fatalf("session timers should be empty on the shutdown leader")
|
||||
}
|
||||
// Find the new leader
|
||||
retry.Run(t, func(r *retry.R) {
|
||||
|
||||
leader = nil
|
||||
for _, s := range servers {
|
||||
if s.IsLeader() {
|
||||
|
@ -363,7 +359,7 @@ func TestServer_SessionTTL_Failover(t *testing.T) {
|
|||
}
|
||||
|
||||
// Ensure session timer is restored
|
||||
if _, ok := leader.sessionTimers[id1]; !ok {
|
||||
if leader.sessionTimers.Get(id1) == nil {
|
||||
r.Fatal("missing session timer")
|
||||
}
|
||||
})
|
||||
|
|
|
@ -211,10 +211,10 @@ func TestSnapshot_LeaderState(t *testing.T) {
|
|||
}
|
||||
|
||||
// Make sure the leader has timers setup.
|
||||
if _, ok := s1.sessionTimers[before]; !ok {
|
||||
if s1.sessionTimers.Get(before) == nil {
|
||||
t.Fatalf("missing session timer")
|
||||
}
|
||||
if _, ok := s1.sessionTimers[after]; !ok {
|
||||
if s1.sessionTimers.Get(after) == nil {
|
||||
t.Fatalf("missing session timer")
|
||||
}
|
||||
|
||||
|
@ -229,10 +229,10 @@ func TestSnapshot_LeaderState(t *testing.T) {
|
|||
|
||||
// Make sure the before time is still there, and that the after timer
|
||||
// got reverted. This proves we fully cycled the leader state.
|
||||
if _, ok := s1.sessionTimers[before]; !ok {
|
||||
if s1.sessionTimers.Get(before) == nil {
|
||||
t.Fatalf("missing session timer")
|
||||
}
|
||||
if _, ok := s1.sessionTimers[after]; ok {
|
||||
if s1.sessionTimers.Get(after) != nil {
|
||||
t.Fatalf("unexpected session timer")
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue