diff --git a/consul/fsm.go b/consul/fsm.go index 9810a289be..654c5f16c2 100644 --- a/consul/fsm.go +++ b/consul/fsm.go @@ -85,29 +85,11 @@ func (c *consulFSM) decodeRegister(buf []byte, index uint64) interface{} { } func (c *consulFSM) applyRegister(req *structs.RegisterRequest, index uint64) interface{} { - // Ensure the node - node := structs.Node{req.Node, req.Address} - if err := c.state.EnsureNode(index, node); err != nil { - c.logger.Printf("[INFO] consul.fsm: EnsureNode failed: %v", err) + // Apply all updates in a single transaction + if err := c.state.EnsureRegistration(index, req); err != nil { + c.logger.Printf("[INFO] consul.fsm: EnsureRegistration failed: %v", err) return err } - - // Ensure the service if provided - if req.Service != nil { - if err := c.state.EnsureService(index, req.Node, req.Service); err != nil { - c.logger.Printf("[INFO] consul.fsm: EnsureService failed: %v", err) - return err - } - } - - // Ensure the check if provided - if req.Check != nil { - if err := c.state.EnsureCheck(index, req.Check); err != nil { - c.logger.Printf("[INFO] consul.fsm: EnsureCheck failed: %v", err) - return err - } - } - return nil } diff --git a/consul/issue_test.go b/consul/issue_test.go new file mode 100644 index 0000000000..6bcd66e8a3 --- /dev/null +++ b/consul/issue_test.go @@ -0,0 +1,77 @@ +package consul + +import ( + "os" + "reflect" + "testing" + + "github.com/hashicorp/consul/consul/structs" +) + +// Testing for GH-300 and GH-279 +func TestHealthCheckRace(t *testing.T) { + fsm, err := NewFSM(os.Stderr) + if err != nil { + t.Fatalf("err: %v", err) + } + defer fsm.Close() + state := fsm.State() + + req := structs.RegisterRequest{ + Datacenter: "dc1", + Node: "foo", + Address: "127.0.0.1", + Service: &structs.NodeService{ + ID: "db", + Service: "db", + }, + Check: &structs.HealthCheck{ + Node: "foo", + CheckID: "db", + Name: "db connectivity", + Status: structs.HealthPassing, + ServiceID: "db", + }, + } + buf, err := structs.Encode(structs.RegisterRequestType, req) + if err != nil { + t.Fatalf("err: %v", err) + } + + log := makeLog(buf) + log.Index = 10 + resp := fsm.Apply(log) + if resp != nil { + t.Fatalf("resp: %v", resp) + } + + // Verify the index + idx, out1 := state.CheckServiceNodes("db") + if idx != 10 { + t.Fatalf("Bad index") + } + + // Update the check state + req.Check.Status = structs.HealthCritical + buf, err = structs.Encode(structs.RegisterRequestType, req) + if err != nil { + t.Fatalf("err: %v", err) + } + + log = makeLog(buf) + log.Index = 20 + resp = fsm.Apply(log) + if resp != nil { + t.Fatalf("resp: %v", resp) + } + + // Verify the index changed + idx, out2 := state.CheckServiceNodes("db") + if idx != 20 { + t.Fatalf("Bad index") + } + + if reflect.DeepEqual(out1, out2) { + t.Fatalf("match: %#v %#v", *out1[0].Checks[0], *out2[0].Checks[0]) + } +} diff --git a/consul/state_store.go b/consul/state_store.go index 80870ac5a8..420d9bc6ac 100644 --- a/consul/state_store.go +++ b/consul/state_store.go @@ -375,6 +375,39 @@ func (s *StateStore) QueryTables(q string) MDBTables { return s.queryTables[q] } +// EnsureRegistration is used to make sure a node, service, and check registration +// is performed within a single transaction to avoid race conditions on state updates. +func (s *StateStore) EnsureRegistration(index uint64, req *structs.RegisterRequest) error { + tx, err := s.tables.StartTxn(false) + if err != nil { + panic(fmt.Errorf("Failed to start txn: %v", err)) + } + defer tx.Abort() + + // Ensure the node + node := structs.Node{req.Node, req.Address} + if err := s.ensureNodeTxn(index, node, tx); err != nil { + return err + } + + // Ensure the service if provided + if req.Service != nil { + if err := s.ensureServiceTxn(index, req.Node, req.Service, tx); err != nil { + return err + } + } + + // Ensure the check if provided + if req.Check != nil { + if err := s.ensureCheckTxn(index, req.Check, tx); err != nil { + return err + } + } + + // Commit as one unit + return tx.Commit() +} + // EnsureNode is used to ensure a given node exists, with the provided address func (s *StateStore) EnsureNode(index uint64, node structs.Node) error { tx, err := s.nodeTable.StartTxn(false, nil) diff --git a/consul/state_store_test.go b/consul/state_store_test.go index f0cdae90f3..2c2a0fd012 100644 --- a/consul/state_store_test.go +++ b/consul/state_store_test.go @@ -13,6 +13,57 @@ func testStateStore() (*StateStore, error) { return NewStateStore(os.Stderr) } +func TestEnsureRegistration(t *testing.T) { + store, err := testStateStore() + if err != nil { + t.Fatalf("err: %v", err) + } + defer store.Close() + + reg := &structs.RegisterRequest{ + Node: "foo", + Address: "127.0.0.1", + Service: &structs.NodeService{"api", "api", nil, 5000}, + Check: &structs.HealthCheck{ + Node: "foo", + CheckID: "api", + Name: "Can connect", + Status: structs.HealthPassing, + ServiceID: "api", + }, + } + + if err := store.EnsureRegistration(13, reg); err != nil { + t.Fatalf("err: %v") + } + + idx, found, addr := store.GetNode("foo") + if idx != 13 || !found || addr != "127.0.0.1" { + t.Fatalf("Bad: %v %v %v", idx, found, addr) + } + + idx, services := store.NodeServices("foo") + if idx != 13 { + t.Fatalf("bad: %v", idx) + } + + entry, ok := services.Services["api"] + if !ok { + t.Fatalf("missing api: %#v", services) + } + if entry.Tags != nil || entry.Port != 5000 { + t.Fatalf("Bad entry: %#v", entry) + } + + idx, checks := store.NodeChecks("foo") + if idx != 13 { + t.Fatalf("bad: %v", idx) + } + if len(checks) != 1 { + t.Fatalf("check: %#v", checks) + } +} + func TestEnsureNode(t *testing.T) { store, err := testStateStore() if err != nil {