diff --git a/command/agent/session_endpoint_test.go b/command/agent/session_endpoint_test.go index 1843bcb837..edfa074402 100644 --- a/command/agent/session_endpoint_test.go +++ b/command/agent/session_endpoint_test.go @@ -255,6 +255,84 @@ func TestSessionTTL(t *testing.T) { }) } +func TestSessionBadTTL(t *testing.T) { + httpTest(t, func(srv *HTTPServer) { + badTTL := "10z" + + // Create Session with illegal TTL + body := bytes.NewBuffer(nil) + enc := json.NewEncoder(body) + raw := map[string]interface{}{ + "TTL": badTTL, + } + enc.Encode(raw) + + req, err := http.NewRequest("PUT", "/v1/session/create", body) + if err != nil { + t.Fatalf("err: %v", err) + } + resp := httptest.NewRecorder() + obj, err := srv.SessionCreate(resp, req) + if err != nil { + t.Fatalf("err: %v", err) + } + if obj != nil { + t.Fatalf("illegal TTL '%s' allowed", badTTL) + } + if resp.Code != 400 { + t.Fatalf("Bad response code, should be 400") + } + + // less than SessionTTLMin + body = bytes.NewBuffer(nil) + enc = json.NewEncoder(body) + raw = map[string]interface{}{ + "TTL": "5s", + } + enc.Encode(raw) + + req, err = http.NewRequest("PUT", "/v1/session/create", body) + if err != nil { + t.Fatalf("err: %v", err) + } + resp = httptest.NewRecorder() + obj, err = srv.SessionCreate(resp, req) + if err != nil { + t.Fatalf("err: %v", err) + } + if obj != nil { + t.Fatalf("illegal TTL '%s' allowed", badTTL) + } + if resp.Code != 400 { + t.Fatalf("Bad response code, should be 400") + } + + // more than SessionTTLMax + body = bytes.NewBuffer(nil) + enc = json.NewEncoder(body) + raw = map[string]interface{}{ + "TTL": "4000s", + } + enc.Encode(raw) + + req, err = http.NewRequest("PUT", "/v1/session/create", body) + if err != nil { + t.Fatalf("err: %v", err) + } + resp = httptest.NewRecorder() + obj, err = srv.SessionCreate(resp, req) + if err != nil { + t.Fatalf("err: %v", err) + } + if obj != nil { + t.Fatalf("illegal TTL '%s' allowed", badTTL) + } + if resp.Code != 400 { + t.Fatalf("Bad response code, should be 400") + } + }) +} + func TestSessionTTLRenew(t *testing.T) { httpTest(t, func(srv *HTTPServer) { TTL := "10s" // use the minimum legal ttl diff --git a/consul/session_ttl.go b/consul/session_ttl.go index 818573aeb0..06c572be25 100644 --- a/consul/session_ttl.go +++ b/consul/session_ttl.go @@ -26,6 +26,21 @@ func (s *Server) initializeSessionTimers() error { return nil } +// invalidate the session when timer expires, called by AfterFunc +func (s *Server) invalidateSession(id string) { + args := structs.SessionRequest{ + Datacenter: s.config.Datacenter, + Op: structs.SessionDestroy, + } + args.Session.ID = id + + // Apply the update to destroy the session + _, err := s.raftApply(structs.SessionRequestType, args) + if err != nil { + s.logger.Printf("[ERR] consul.session: Apply failed: %v", err) + } +} + func (s *Server) resetSessionTimer(id string, session *structs.Session) error { if session == nil { var err error @@ -60,17 +75,7 @@ func (s *Server) resetSessionTimer(id string, session *structs.Session) error { t.Reset(ttl * structs.SessionTTLMultiplier) } else { s.sessionTimers[session.ID] = time.AfterFunc(ttl*structs.SessionTTLMultiplier, func() { - args := structs.SessionRequest{ - Datacenter: s.config.Datacenter, - Op: structs.SessionDestroy, - } - args.Session.ID = session.ID - - // Apply the update to destroy the session - _, err := s.raftApply(structs.SessionRequestType, args) - if err != nil { - s.logger.Printf("[ERR] consul.session: Apply failed: %v", err) - } + s.invalidateSession(session.ID) }) } @@ -85,7 +90,6 @@ func (s *Server) clearSessionTimer(id string) error { s.sessionTimers[id].Stop() delete(s.sessionTimers, id) } - s.sessionTimers = nil return nil }