mirror of https://github.com/hashicorp/consul
Fixed clearSessionTimer, created invalidateSession, added invalid TTL test
parent
60915629f6
commit
ac54010027
|
@ -255,6 +255,84 @@ func TestSessionTTL(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestSessionBadTTL(t *testing.T) {
|
||||
httpTest(t, func(srv *HTTPServer) {
|
||||
badTTL := "10z"
|
||||
|
||||
// Create Session with illegal TTL
|
||||
body := bytes.NewBuffer(nil)
|
||||
enc := json.NewEncoder(body)
|
||||
raw := map[string]interface{}{
|
||||
"TTL": badTTL,
|
||||
}
|
||||
enc.Encode(raw)
|
||||
|
||||
req, err := http.NewRequest("PUT", "/v1/session/create", body)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
resp := httptest.NewRecorder()
|
||||
obj, err := srv.SessionCreate(resp, req)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if obj != nil {
|
||||
t.Fatalf("illegal TTL '%s' allowed", badTTL)
|
||||
}
|
||||
if resp.Code != 400 {
|
||||
t.Fatalf("Bad response code, should be 400")
|
||||
}
|
||||
|
||||
// less than SessionTTLMin
|
||||
body = bytes.NewBuffer(nil)
|
||||
enc = json.NewEncoder(body)
|
||||
raw = map[string]interface{}{
|
||||
"TTL": "5s",
|
||||
}
|
||||
enc.Encode(raw)
|
||||
|
||||
req, err = http.NewRequest("PUT", "/v1/session/create", body)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
resp = httptest.NewRecorder()
|
||||
obj, err = srv.SessionCreate(resp, req)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if obj != nil {
|
||||
t.Fatalf("illegal TTL '%s' allowed", badTTL)
|
||||
}
|
||||
if resp.Code != 400 {
|
||||
t.Fatalf("Bad response code, should be 400")
|
||||
}
|
||||
|
||||
// more than SessionTTLMax
|
||||
body = bytes.NewBuffer(nil)
|
||||
enc = json.NewEncoder(body)
|
||||
raw = map[string]interface{}{
|
||||
"TTL": "4000s",
|
||||
}
|
||||
enc.Encode(raw)
|
||||
|
||||
req, err = http.NewRequest("PUT", "/v1/session/create", body)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
resp = httptest.NewRecorder()
|
||||
obj, err = srv.SessionCreate(resp, req)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if obj != nil {
|
||||
t.Fatalf("illegal TTL '%s' allowed", badTTL)
|
||||
}
|
||||
if resp.Code != 400 {
|
||||
t.Fatalf("Bad response code, should be 400")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSessionTTLRenew(t *testing.T) {
|
||||
httpTest(t, func(srv *HTTPServer) {
|
||||
TTL := "10s" // use the minimum legal ttl
|
||||
|
|
|
@ -26,6 +26,21 @@ func (s *Server) initializeSessionTimers() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// invalidate the session when timer expires, called by AfterFunc
|
||||
func (s *Server) invalidateSession(id string) {
|
||||
args := structs.SessionRequest{
|
||||
Datacenter: s.config.Datacenter,
|
||||
Op: structs.SessionDestroy,
|
||||
}
|
||||
args.Session.ID = id
|
||||
|
||||
// Apply the update to destroy the session
|
||||
_, err := s.raftApply(structs.SessionRequestType, args)
|
||||
if err != nil {
|
||||
s.logger.Printf("[ERR] consul.session: Apply failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) resetSessionTimer(id string, session *structs.Session) error {
|
||||
if session == nil {
|
||||
var err error
|
||||
|
@ -60,17 +75,7 @@ func (s *Server) resetSessionTimer(id string, session *structs.Session) error {
|
|||
t.Reset(ttl * structs.SessionTTLMultiplier)
|
||||
} else {
|
||||
s.sessionTimers[session.ID] = time.AfterFunc(ttl*structs.SessionTTLMultiplier, func() {
|
||||
args := structs.SessionRequest{
|
||||
Datacenter: s.config.Datacenter,
|
||||
Op: structs.SessionDestroy,
|
||||
}
|
||||
args.Session.ID = session.ID
|
||||
|
||||
// Apply the update to destroy the session
|
||||
_, err := s.raftApply(structs.SessionRequestType, args)
|
||||
if err != nil {
|
||||
s.logger.Printf("[ERR] consul.session: Apply failed: %v", err)
|
||||
}
|
||||
s.invalidateSession(session.ID)
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -85,7 +90,6 @@ func (s *Server) clearSessionTimer(id string) error {
|
|||
s.sessionTimers[id].Stop()
|
||||
delete(s.sessionTimers, id)
|
||||
}
|
||||
s.sessionTimers = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue