diff --git a/command/agent/session_endpoint.go b/command/agent/session_endpoint.go index 760255f22c..cd9fa7ecdb 100644 --- a/command/agent/session_endpoint.go +++ b/command/agent/session_endpoint.go @@ -32,13 +32,14 @@ func (s *HTTPServer) SessionCreate(resp http.ResponseWriter, req *http.Request) return nil, nil } - // Default the session to our node + serf check + // Default the session to our node + serf check + release session invalidate behavior args := structs.SessionRequest{ Op: structs.SessionCreate, Session: structs.Session{ Node: s.agent.config.NodeName, Checks: []string{consul.SerfCheckID}, LockDelay: 15 * time.Second, + Behavior: structs.SessionKeysRelease, }, } s.parseDC(req, &args.Datacenter) diff --git a/command/agent/session_endpoint_test.go b/command/agent/session_endpoint_test.go index b5a93eea3a..5f32bc1737 100644 --- a/command/agent/session_endpoint_test.go +++ b/command/agent/session_endpoint_test.go @@ -59,6 +59,55 @@ func TestSessionCreate(t *testing.T) { }) } +func TestSessionCreateDelete(t *testing.T) { + httpTest(t, func(srv *HTTPServer) { + // Create a health check + args := &structs.RegisterRequest{ + Datacenter: "dc1", + Node: srv.agent.config.NodeName, + Address: "127.0.0.1", + Check: &structs.HealthCheck{ + CheckID: "consul", + Node: srv.agent.config.NodeName, + Name: "consul", + ServiceID: "consul", + Status: structs.HealthPassing, + }, + } + var out struct{} + if err := srv.agent.RPC("Catalog.Register", args, &out); err != nil { + t.Fatalf("err: %v", err) + } + + // Associate session with node and 2 health checks, and make it delete on session destroy + body := bytes.NewBuffer(nil) + enc := json.NewEncoder(body) + raw := map[string]interface{}{ + "Name": "my-cool-session", + "Node": srv.agent.config.NodeName, + "Checks": []string{consul.SerfCheckID, "consul"}, + "LockDelay": "20s", + "Behavior": "delete", + } + enc.Encode(raw) + + req, err := http.NewRequest("PUT", "/v1/session/create", body) + if err != nil { + t.Fatalf("err: %v", err) + } + + resp := httptest.NewRecorder() + obj, err := srv.SessionCreate(resp, req) + if err != nil { + t.Fatalf("err: %v", err) + } + + if _, ok := obj.(sessionCreateResponse); !ok { + t.Fatalf("should work") + } + }) +} + func TestFixupLockDelay(t *testing.T) { inp := map[string]interface{}{ "lockdelay": float64(15), @@ -105,6 +154,28 @@ func makeTestSession(t *testing.T, srv *HTTPServer) string { return sessResp.ID } +func makeTestSessionDelete(t *testing.T, srv *HTTPServer) string { + // Create Session with delete behavior + body := bytes.NewBuffer(nil) + enc := json.NewEncoder(body) + raw := map[string]interface{}{ + "Behavior": "delete", + } + enc.Encode(raw) + + req, err := http.NewRequest("PUT", "/v1/session/create", body) + if err != nil { + t.Fatalf("err: %v", err) + } + resp := httptest.NewRecorder() + obj, err := srv.SessionCreate(resp, req) + if err != nil { + t.Fatalf("err: %v", err) + } + sessResp := obj.(sessionCreateResponse) + return sessResp.ID +} + func TestSessionDestroy(t *testing.T) { httpTest(t, func(srv *HTTPServer) { id := makeTestSession(t, srv) @@ -188,3 +259,113 @@ func TestSessionsForNode(t *testing.T) { } }) } + +func TestSessionDeleteDestroy(t *testing.T) { + httpTest(t, func(srv *HTTPServer) { + id := makeTestSessionDelete(t, srv) + + // now create a new key for the session and acquire it + buf := bytes.NewBuffer([]byte("test")) + req, err := http.NewRequest("PUT", "/v1/kv/ephemeral?acquire="+id, buf) + if err != nil { + t.Fatalf("err: %v", err) + } + resp := httptest.NewRecorder() + obj, err := srv.KVSEndpoint(resp, req) + if err != nil { + t.Fatalf("err: %v", err) + } + + if res := obj.(bool); !res { + t.Fatalf("should work") + } + + // now destroy the session, this should delete the key create above + req, err = http.NewRequest("PUT", "/v1/session/destroy/"+id, nil) + resp = httptest.NewRecorder() + obj, err = srv.SessionDestroy(resp, req) + if err != nil { + t.Fatalf("err: %v", err) + } + if resp := obj.(bool); !resp { + t.Fatalf("should work") + } + + // Verify that the key is gone + req, _ = http.NewRequest("GET", "/v1/kv/ephemeral", nil) + resp = httptest.NewRecorder() + obj, _ = srv.KVSEndpoint(resp, req) + res, found := obj.(structs.DirEntries) + if found || len(res) != 0 { + t.Fatalf("bad: %v found, should be nothing", res) + } + }) +} + +func TestSessionDeleteGet(t *testing.T) { + httpTest(t, func(srv *HTTPServer) { + id := makeTestSessionDelete(t, srv) + + req, err := http.NewRequest("GET", + "/v1/session/info/"+id, nil) + resp := httptest.NewRecorder() + obj, err := srv.SessionGet(resp, req) + if err != nil { + t.Fatalf("err: %v", err) + } + respObj, ok := obj.(structs.Sessions) + if !ok { + t.Fatalf("should work") + } + if len(respObj) != 1 { + t.Fatalf("bad: %v", respObj) + } + }) +} + +func TestSessionDeleteList(t *testing.T) { + httpTest(t, func(srv *HTTPServer) { + var ids []string + for i := 0; i < 10; i++ { + ids = append(ids, makeTestSessionDelete(t, srv)) + } + + req, err := http.NewRequest("GET", "/v1/session/list", nil) + resp := httptest.NewRecorder() + obj, err := srv.SessionList(resp, req) + if err != nil { + t.Fatalf("err: %v", err) + } + respObj, ok := obj.(structs.Sessions) + if !ok { + t.Fatalf("should work") + } + if len(respObj) != 10 { + t.Fatalf("bad: %v", respObj) + } + }) +} + +func TestSessionsDeleteForNode(t *testing.T) { + httpTest(t, func(srv *HTTPServer) { + var ids []string + for i := 0; i < 10; i++ { + ids = append(ids, makeTestSessionDelete(t, srv)) + } + + req, err := http.NewRequest("GET", + "/v1/session/node/"+srv.agent.config.NodeName, nil) + resp := httptest.NewRecorder() + obj, err := srv.SessionsForNode(resp, req) + if err != nil { + t.Fatalf("err: %v", err) + } + respObj, ok := obj.(structs.Sessions) + if !ok { + t.Fatalf("should work") + } + if len(respObj) != 10 { + t.Fatalf("bad: %v", respObj) + } + }) +} diff --git a/consul/session_endpoint.go b/consul/session_endpoint.go index c08bef39ea..297e32f328 100644 --- a/consul/session_endpoint.go +++ b/consul/session_endpoint.go @@ -28,6 +28,9 @@ func (s *Session) Apply(args *structs.SessionRequest, reply *string) error { if args.Session.Node == "" && args.Op == structs.SessionCreate { return fmt.Errorf("Must provide Node") } + if args.Session.Behavior == "" { + args.Session.Behavior = structs.SessionKeysRelease // force default behavior + } // If this is a create, we must generate the Session ID. This must // be done prior to appending to the raft log, because the ID is not diff --git a/consul/session_endpoint_test.go b/consul/session_endpoint_test.go index ab3bfb8435..5e2b03caa9 100644 --- a/consul/session_endpoint_test.go +++ b/consul/session_endpoint_test.go @@ -66,6 +66,66 @@ func TestSessionEndpoint_Apply(t *testing.T) { } } +func TestSessionEndpoint_DeleteApply(t *testing.T) { + dir1, s1 := testServer(t) + defer os.RemoveAll(dir1) + defer s1.Shutdown() + client := rpcClient(t, s1) + defer client.Close() + + testutil.WaitForLeader(t, client.Call, "dc1") + + // Just add a node + s1.fsm.State().EnsureNode(1, structs.Node{"foo", "127.0.0.1"}) + + arg := structs.SessionRequest{ + Datacenter: "dc1", + Op: structs.SessionCreate, + Session: structs.Session{ + Node: "foo", + Name: "my-session", + Behavior: structs.SessionKeysDelete, + }, + } + var out string + if err := client.Call("Session.Apply", &arg, &out); err != nil { + t.Fatalf("err: %v", err) + } + id := out + + // Verify + state := s1.fsm.State() + _, s, err := state.SessionGet(out) + if err != nil { + t.Fatalf("err: %v", err) + } + if s == nil { + t.Fatalf("should not be nil") + } + if s.Node != "foo" { + t.Fatalf("bad: %v", s) + } + if s.Name != "my-session" { + t.Fatalf("bad: %v", s) + } + + // Do a delete + arg.Op = structs.SessionDestroy + arg.Session.ID = out + if err := client.Call("Session.Apply", &arg, &out); err != nil { + t.Fatalf("err: %v", err) + } + + // Verify + _, s, err = state.SessionGet(id) + if err != nil { + t.Fatalf("err: %v", err) + } + if s != nil { + t.Fatalf("bad: %v", s) + } +} + func TestSessionEndpoint_Get(t *testing.T) { dir1, s1 := testServer(t) defer os.RemoveAll(dir1) @@ -80,7 +140,51 @@ func TestSessionEndpoint_Get(t *testing.T) { Datacenter: "dc1", Op: structs.SessionCreate, Session: structs.Session{ - Node: "foo", + Node: "foo", + }, + } + var out string + if err := client.Call("Session.Apply", &arg, &out); err != nil { + t.Fatalf("err: %v", err) + } + + getR := structs.SessionSpecificRequest{ + Datacenter: "dc1", + Session: out, + } + var sessions structs.IndexedSessions + if err := client.Call("Session.Get", &getR, &sessions); err != nil { + t.Fatalf("err: %v", err) + } + + if sessions.Index == 0 { + t.Fatalf("Bad: %v", sessions) + } + if len(sessions.Sessions) != 1 { + t.Fatalf("Bad: %v", sessions) + } + s := sessions.Sessions[0] + if s.ID != out { + t.Fatalf("bad: %v", s) + } +} + +func TestSessionEndpoint_DeleteGet(t *testing.T) { + dir1, s1 := testServer(t) + defer os.RemoveAll(dir1) + defer s1.Shutdown() + client := rpcClient(t, s1) + defer client.Close() + + testutil.WaitForLeader(t, client.Call, "dc1") + + s1.fsm.State().EnsureNode(1, structs.Node{"foo", "127.0.0.1"}) + arg := structs.SessionRequest{ + Datacenter: "dc1", + Op: structs.SessionCreate, + Session: structs.Session{ + Node: "foo", + Behavior: structs.SessionKeysDelete, }, } var out string @@ -160,6 +264,58 @@ func TestSessionEndpoint_List(t *testing.T) { } } +func TestSessionEndpoint_DeleteList(t *testing.T) { + dir1, s1 := testServer(t) + defer os.RemoveAll(dir1) + defer s1.Shutdown() + client := rpcClient(t, s1) + defer client.Close() + + testutil.WaitForLeader(t, client.Call, "dc1") + + s1.fsm.State().EnsureNode(1, structs.Node{"foo", "127.0.0.1"}) + ids := []string{} + for i := 0; i < 5; i++ { + arg := structs.SessionRequest{ + Datacenter: "dc1", + Op: structs.SessionCreate, + Session: structs.Session{ + Node: "foo", + Behavior: structs.SessionKeysDelete, + }, + } + var out string + if err := client.Call("Session.Apply", &arg, &out); err != nil { + t.Fatalf("err: %v", err) + } + ids = append(ids, out) + } + + getR := structs.DCSpecificRequest{ + Datacenter: "dc1", + } + var sessions structs.IndexedSessions + if err := client.Call("Session.List", &getR, &sessions); err != nil { + t.Fatalf("err: %v", err) + } + + if sessions.Index == 0 { + t.Fatalf("Bad: %v", sessions) + } + if len(sessions.Sessions) != 5 { + t.Fatalf("Bad: %v", sessions.Sessions) + } + for i := 0; i < len(sessions.Sessions); i++ { + s := sessions.Sessions[i] + if !strContains(ids, s.ID) { + t.Fatalf("bad: %v", s) + } + if s.Node != "foo" { + t.Fatalf("bad: %v", s) + } + } +} + func TestSessionEndpoint_NodeSessions(t *testing.T) { dir1, s1 := testServer(t) defer os.RemoveAll(dir1) @@ -217,3 +373,62 @@ func TestSessionEndpoint_NodeSessions(t *testing.T) { } } } + +func TestSessionEndpoint_DeleteNodeSessions(t *testing.T) { + dir1, s1 := testServer(t) + defer os.RemoveAll(dir1) + defer s1.Shutdown() + client := rpcClient(t, s1) + defer client.Close() + + testutil.WaitForLeader(t, client.Call, "dc1") + + s1.fsm.State().EnsureNode(1, structs.Node{"foo", "127.0.0.1"}) + s1.fsm.State().EnsureNode(1, structs.Node{"bar", "127.0.0.1"}) + ids := []string{} + for i := 0; i < 10; i++ { + arg := structs.SessionRequest{ + Datacenter: "dc1", + Op: structs.SessionCreate, + Session: structs.Session{ + Node: "bar", + Behavior: structs.SessionKeysDelete, + }, + } + if i < 5 { + arg.Session.Node = "foo" + } + var out string + if err := client.Call("Session.Apply", &arg, &out); err != nil { + t.Fatalf("err: %v", err) + } + if i < 5 { + ids = append(ids, out) + } + } + + getR := structs.NodeSpecificRequest{ + Datacenter: "dc1", + Node: "foo", + } + var sessions structs.IndexedSessions + if err := client.Call("Session.NodeSessions", &getR, &sessions); err != nil { + t.Fatalf("err: %v", err) + } + + if sessions.Index == 0 { + t.Fatalf("Bad: %v", sessions) + } + if len(sessions.Sessions) != 5 { + t.Fatalf("Bad: %v", sessions.Sessions) + } + for i := 0; i < len(sessions.Sessions); i++ { + s := sessions.Sessions[i] + if !strContains(ids, s.ID) { + t.Fatalf("bad: %v", s) + } + if s.Node != "foo" { + t.Fatalf("bad: %v", s) + } + } +} diff --git a/consul/state_store.go b/consul/state_store.go index 5bbadd4238..bc0da3639f 100644 --- a/consul/state_store.go +++ b/consul/state_store.go @@ -1327,6 +1327,11 @@ func (s *StateStore) SessionCreate(index uint64, session *structs.Session) error return fmt.Errorf("Missing Session ID") } + // make sure we have a default set for session.Behavior + if session.Behavior == "" { + session.Behavior = structs.SessionKeysRelease + } + // Assign the create index session.CreateIndex = index @@ -1454,7 +1459,7 @@ func (s *StateStore) SessionDestroy(index uint64, id string) error { } defer tx.Abort() - log.Printf("[DEBUG] consul.state: Invalidating session %s due to session destroy", + s.logger.Printf("[DEBUG] consul.state: Invalidating session %s due to session destroy", id) if err := s.invalidateSession(index, tx, id); err != nil { return err @@ -1471,7 +1476,7 @@ func (s *StateStore) invalidateNode(index uint64, tx *MDBTxn, node string) error } for _, sess := range sessions { session := sess.(*structs.Session).ID - log.Printf("[DEBUG] consul.state: Invalidating session %s due to node '%s' invalidation", + s.logger.Printf("[DEBUG] consul.state: Invalidating session %s due to node '%s' invalidation", session, node) if err := s.invalidateSession(index, tx, session); err != nil { return err @@ -1489,7 +1494,7 @@ func (s *StateStore) invalidateCheck(index uint64, tx *MDBTxn, node, check strin } for _, sc := range sessionChecks { session := sc.(*sessionCheck).Session - log.Printf("[DEBUG] consul.state: Invalidating session %s due to check '%s' invalidation", + s.logger.Printf("[DEBUG] consul.state: Invalidating session %s due to check '%s' invalidation", session, check) if err := s.invalidateSession(index, tx, session); err != nil { return err @@ -1513,15 +1518,23 @@ func (s *StateStore) invalidateSession(index uint64, tx *MDBTxn, id string) erro } session := res[0].(*structs.Session) - // Enforce the MaxLockDelay - delay := session.LockDelay - if delay > structs.MaxLockDelay { - delay = structs.MaxLockDelay - } + if session.Behavior == structs.SessionKeysDelete { + // delete the keys held by the session + if err := s.deleteKeys(index, tx, id); err != nil { + return err + } - // Invalidate any held locks - if err := s.invalidateLocks(index, tx, delay, id); err != nil { - return err + } else { // default to release + // Enforce the MaxLockDelay + delay := session.LockDelay + if delay > structs.MaxLockDelay { + delay = structs.MaxLockDelay + } + + // Invalidate any held locks + if err := s.invalidateLocks(index, tx, delay, id); err != nil { + return err + } } // Nuke the session @@ -1588,6 +1601,23 @@ func (s *StateStore) invalidateLocks(index uint64, tx *MDBTxn, return nil } +// deleteKeys is used to delete all the keys created by a session +// within a given txn. All tables should be locked in the tx. +func (s *StateStore) deleteKeys(index uint64, tx *MDBTxn, id string) error { + num, err := s.kvsTable.DeleteTxn(tx, "session", id) + if err != nil { + return err + } + + if num > 0 { + if err := s.kvsTable.SetLastIndexTxn(tx, index); err != nil { + return err + } + tx.Defer(func() { s.watch[s.kvsTable].Notify() }) + } + return nil +} + // ACLSet is used to create or update an ACL entry func (s *StateStore) ACLSet(index uint64, acl *structs.ACL) error { // Check for an ID diff --git a/consul/state_store_test.go b/consul/state_store_test.go index 9480cf1b13..0dba19abda 100644 --- a/consul/state_store_test.go +++ b/consul/state_store_test.go @@ -2279,6 +2279,53 @@ func TestSessionInvalidate_KeyUnlock(t *testing.T) { } } +func TestSessionInvalidate_KeyDelete(t *testing.T) { + store, err := testStateStore() + if err != nil { + t.Fatalf("err: %v", err) + } + defer store.Close() + + if err := store.EnsureNode(3, structs.Node{"foo", "127.0.0.1"}); err != nil { + t.Fatalf("err: %v", err) + } + session := &structs.Session{ + ID: generateUUID(), + Node: "foo", + LockDelay: 50 * time.Millisecond, + Behavior: structs.SessionKeysDelete, + } + if err := store.SessionCreate(4, session); err != nil { + t.Fatalf("err: %v", err) + } + + // Lock a key with the session + d := &structs.DirEntry{ + Key: "/bar", + Flags: 42, + Value: []byte("test"), + Session: session.ID, + } + ok, err := store.KVSLock(5, d) + if err != nil { + t.Fatalf("err: %v", err) + } + if !ok { + t.Fatalf("unexpected fail") + } + + // Delete the node + if err := store.DeleteNode(6, "foo"); err != nil { + t.Fatalf("err: %v", err) + } + + // Key should be deleted + _, d2, err := store.KVSGet("/bar") + if d2 != nil { + t.Fatalf("unexpected undeleted key") + } +} + func TestACLSet_Get(t *testing.T) { store, err := testStateStore() if err != nil { diff --git a/consul/structs/structs.go b/consul/structs/structs.go index b1f315271d..ced8567d26 100644 --- a/consul/structs/structs.go +++ b/consul/structs/structs.go @@ -378,6 +378,13 @@ type IndexedKeyList struct { QueryMeta } +type SessionBehavior string + +const ( + SessionKeysRelease SessionBehavior = "release" + SessionKeysDelete = "delete" +) + // Session is used to represent an open session in the KV store. // This issued to associate node checks with acquired locks. type Session struct { @@ -387,6 +394,7 @@ type Session struct { Node string Checks []string LockDelay time.Duration + Behavior SessionBehavior // What to do when session is invalidated } type Sessions []*Session