diff --git a/agent/coordinate_endpoint.go b/agent/coordinate_endpoint.go index 1a16964fe2..4d51a9b0fb 100644 --- a/agent/coordinate_endpoint.go +++ b/agent/coordinate_endpoint.go @@ -9,12 +9,16 @@ import ( "github.com/hashicorp/consul/agent/structs" ) -// coordinateDisabled handles all the endpoints when coordinates are not enabled, -// returning an error message. -func coordinateDisabled(resp http.ResponseWriter, req *http.Request) (interface{}, error) { +// checkCoordinateDisabled will return a standard response if coordinates are +// disabled. This returns true if they are disabled and we should not continue. +func (s *HTTPServer) checkCoordinateDisabled(resp http.ResponseWriter, req *http.Request) bool { + if !s.agent.config.DisableCoordinates { + return false + } + resp.WriteHeader(http.StatusUnauthorized) fmt.Fprint(resp, "Coordinate support disabled") - return nil, nil + return true } // sorter wraps a coordinate list and implements the sort.Interface to sort by @@ -41,6 +45,9 @@ func (s *sorter) Less(i, j int) bool { // CoordinateDatacenters returns the WAN nodes in each datacenter, along with // raw network coordinates. func (s *HTTPServer) CoordinateDatacenters(resp http.ResponseWriter, req *http.Request) (interface{}, error) { + if s.checkCoordinateDisabled(resp, req) { + return nil, nil + } if req.Method != "GET" { return nil, MethodNotAllowedError{req.Method, []string{"GET"}} } @@ -70,6 +77,9 @@ func (s *HTTPServer) CoordinateDatacenters(resp http.ResponseWriter, req *http.R // CoordinateNodes returns the LAN nodes in the given datacenter, along with // raw network coordinates. func (s *HTTPServer) CoordinateNodes(resp http.ResponseWriter, req *http.Request) (interface{}, error) { + if s.checkCoordinateDisabled(resp, req) { + return nil, nil + } if req.Method != "GET" { return nil, MethodNotAllowedError{req.Method, []string{"GET"}} } @@ -92,6 +102,9 @@ func (s *HTTPServer) CoordinateNodes(resp http.ResponseWriter, req *http.Request // CoordinateNode returns the LAN node in the given datacenter, along with // raw network coordinates. func (s *HTTPServer) CoordinateNode(resp http.ResponseWriter, req *http.Request) (interface{}, error) { + if s.checkCoordinateDisabled(resp, req) { + return nil, nil + } if req.Method != "GET" { return nil, MethodNotAllowedError{req.Method, []string{"GET"}} } @@ -141,6 +154,9 @@ func filterCoordinates(req *http.Request, in structs.Coordinates) structs.Coordi // CoordinateUpdate inserts or updates the LAN coordinate of a node. func (s *HTTPServer) CoordinateUpdate(resp http.ResponseWriter, req *http.Request) (interface{}, error) { + if s.checkCoordinateDisabled(resp, req) { + return nil, nil + } if req.Method != "PUT" { return nil, MethodNotAllowedError{req.Method, []string{"PUT"}} } diff --git a/agent/coordinate_endpoint_test.go b/agent/coordinate_endpoint_test.go index 09001ccabe..deb812a9b9 100644 --- a/agent/coordinate_endpoint_test.go +++ b/agent/coordinate_endpoint_test.go @@ -1,8 +1,10 @@ package agent import ( + "fmt" "net/http" "net/http/httptest" + "strings" "testing" "time" @@ -10,6 +12,40 @@ import ( "github.com/hashicorp/serf/coordinate" ) +func TestCoordinate_Disabled_Response(t *testing.T) { + t.Parallel() + a := NewTestAgent(t.Name(), ` + disable_coordinates = true +`) + defer a.Shutdown() + + tests := []func(resp http.ResponseWriter, req *http.Request) (interface{}, error){ + a.srv.CoordinateDatacenters, + a.srv.CoordinateNodes, + a.srv.CoordinateNode, + a.srv.CoordinateUpdate, + } + for i, tt := range tests { + t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { + req, _ := http.NewRequest("PUT", "/should/not/care", nil) + resp := httptest.NewRecorder() + obj, err := tt(resp, req) + if err != nil { + t.Fatalf("err: %v", err) + } + if obj != nil { + t.Fatalf("bad: %#v", obj) + } + if got, want := resp.Code, http.StatusUnauthorized; got != want { + t.Fatalf("got %d want %d", got, want) + } + if !strings.Contains(resp.Body.String(), "Coordinate support disabled") { + t.Fatalf("bad: %#v", resp) + } + }) + } +} + func TestCoordinate_Datacenters(t *testing.T) { t.Parallel() a := NewTestAgent(t.Name(), "") diff --git a/agent/http.go b/agent/http.go index ca3085fd42..ed6f350922 100644 --- a/agent/http.go +++ b/agent/http.go @@ -113,17 +113,10 @@ func (s *HTTPServer) handler(enableDebug bool) http.Handler { handleFuncMetrics("/v1/catalog/services", s.wrap(s.CatalogServices)) handleFuncMetrics("/v1/catalog/service/", s.wrap(s.CatalogServiceNodes)) handleFuncMetrics("/v1/catalog/node/", s.wrap(s.CatalogNodeServices)) - if !s.agent.config.DisableCoordinates { - handleFuncMetrics("/v1/coordinate/datacenters", s.wrap(s.CoordinateDatacenters)) - handleFuncMetrics("/v1/coordinate/nodes", s.wrap(s.CoordinateNodes)) - handleFuncMetrics("/v1/coordinate/node/", s.wrap(s.CoordinateNode)) - handleFuncMetrics("/v1/coordinate/update", s.wrap(s.CoordinateUpdate)) - } else { - handleFuncMetrics("/v1/coordinate/datacenters", s.wrap(coordinateDisabled)) - handleFuncMetrics("/v1/coordinate/nodes", s.wrap(coordinateDisabled)) - handleFuncMetrics("/v1/coordinate/node/", s.wrap(coordinateDisabled)) - handleFuncMetrics("/v1/coordinate/update", s.wrap(coordinateDisabled)) - } + handleFuncMetrics("/v1/coordinate/datacenters", s.wrap(s.CoordinateDatacenters)) + handleFuncMetrics("/v1/coordinate/nodes", s.wrap(s.CoordinateNodes)) + handleFuncMetrics("/v1/coordinate/node/", s.wrap(s.CoordinateNode)) + handleFuncMetrics("/v1/coordinate/update", s.wrap(s.CoordinateUpdate)) handleFuncMetrics("/v1/event/fire/", s.wrap(s.EventFire)) handleFuncMetrics("/v1/event/list", s.wrap(s.EventList)) handleFuncMetrics("/v1/health/node/", s.wrap(s.HealthNodeChecks))