diff --git a/api/agent_test.go b/api/agent_test.go index a6e2fc542b..e27887d6f8 100644 --- a/api/agent_test.go +++ b/api/agent_test.go @@ -1,6 +1,7 @@ package api import ( + "strings" "testing" ) @@ -282,7 +283,7 @@ func TestServiceMaintenance(t *testing.T) { } found := false for _, check := range checks { - if check.ServiceName == "redis" { + if strings.Contains(check.CheckID, "maintenance") { found = true if check.Status != "critical" { t.Fatalf("bad: %#v", checks) @@ -304,7 +305,53 @@ func TestServiceMaintenance(t *testing.T) { t.Fatalf("err: %s", err) } for _, check := range checks { - if check.ServiceID == "redis" { + if strings.Contains(check.CheckID, "maintenance") { + t.Fatalf("should have removed health check") + } + } +} + +func TestNodeMaintenance(t *testing.T) { + c, s := makeClient(t) + defer s.stop() + + agent := c.Agent() + + // Enable maintenance mode + if err := agent.EnableNodeMaintenance(); err != nil { + t.Fatalf("err: %s", err) + } + + // Check that a critical check was added + checks, err := agent.Checks() + if err != nil { + t.Fatalf("err: %s", err) + } + found := false + for _, check := range checks { + if strings.Contains(check.CheckID, "maintenance") { + found = true + if check.Status != "critical" { + t.Fatalf("bad: %#v", checks) + } + } + } + if !found { + t.Fatalf("bad: %#v", checks) + } + + // Disable maintenance mode + if err := agent.DisableNodeMaintenance(); err != nil { + t.Fatalf("err: %s", err) + } + + // Ensure the check was removed + checks, err = agent.Checks() + if err != nil { + t.Fatalf("err: %s", err) + } + for _, check := range checks { + if strings.Contains(check.CheckID, "maintenance") { t.Fatalf("should have removed health check") } }