diff --git a/command/agent/agent.go b/command/agent/agent.go index 14db1ea3de..a9e3363285 100644 --- a/command/agent/agent.go +++ b/command/agent/agent.go @@ -9,6 +9,7 @@ import ( "net" "os" "path/filepath" + "reflect" "regexp" "strconv" "sync" @@ -104,6 +105,11 @@ type Agent struct { shutdown bool shutdownCh chan struct{} shutdownLock sync.Mutex + + // endpoints lets you override RPC endpoints for testing. Not all + // agent methods use this, so use with care and never override + // outside of a unit test. + endpoints map[string]string } // Create is used to create a new Agent. Returns @@ -158,6 +164,7 @@ func Create(config *Config, logOutput io.Writer) (*Agent, error) { eventCh: make(chan serf.UserEvent, 1024), eventBuf: make([]*UserEvent, 256), shutdownCh: make(chan struct{}), + endpoints: make(map[string]string), } // Initialize the local state @@ -1456,3 +1463,30 @@ func (a *Agent) DisableNodeMaintenance() { a.RemoveCheck(nodeMaintCheckID, true) a.logger.Printf("[INFO] agent: Node left maintenance mode") } + +// InjectEndpoint overrides the given endpoint with a substitute one. Note +// that not all agent methods use this mechanism, and that is should only +// be used for testing. +func (a *Agent) InjectEndpoint(endpoint string, handler interface{}) error { + if a.server == nil { + return fmt.Errorf("agent must be a server") + } + + if err := a.server.InjectEndpoint(handler); err != nil { + return err + } + name := reflect.Indirect(reflect.ValueOf(handler)).Type().Name() + a.endpoints[endpoint] = name + + a.logger.Printf("[WARN] agent: endpoint injected; this should only be used for testing") + return nil +} + +// getEndpoint returns the endpoint name to use for the given endpoint, +// which may be overridden. +func (a *Agent) getEndpoint(endpoint string) string { + if override, ok := a.endpoints[endpoint]; ok { + return override + } + return endpoint +} diff --git a/command/agent/prepared_query_endpoint.go b/command/agent/prepared_query_endpoint.go index a27978c89a..5d9c07212d 100644 --- a/command/agent/prepared_query_endpoint.go +++ b/command/agent/prepared_query_endpoint.go @@ -21,12 +21,7 @@ type preparedQueryCreateResponse struct { // PreparedQueryGeneral handles all the general prepared query requests. func (s *HTTPServer) PreparedQueryGeneral(resp http.ResponseWriter, req *http.Request) (interface{}, error) { - return s.preparedQueryGeneral(preparedQueryEndpoint, resp, req) -} - -// preparedQueryGeneral is the internal method that does the work on behalf of -// PreparedQueryGeneral. The RPC endpoint is parameterized to ease testing. -func (s *HTTPServer) preparedQueryGeneral(endpoint string, resp http.ResponseWriter, req *http.Request) (interface{}, error) { + endpoint := s.agent.getEndpoint(preparedQueryEndpoint) switch req.Method { case "POST": // Create a new prepared query. args := structs.PreparedQueryRequest{ @@ -82,12 +77,6 @@ func parseLimit(req *http.Request, limit *int) error { // PreparedQuerySpecifc handles all the prepared query requests specific to a // particular query. func (s *HTTPServer) PreparedQuerySpecific(resp http.ResponseWriter, req *http.Request) (interface{}, error) { - return s.preparedQuerySpecific(preparedQueryEndpoint, resp, req) -} - -// preparedQuerySpecific is the internal method that does the work on behalf of -// PreparedQuerySpecific. The RPC endpoint is parameterized to ease testing. -func (s *HTTPServer) preparedQuerySpecific(endpoint string, resp http.ResponseWriter, req *http.Request) (interface{}, error) { id := strings.TrimPrefix(req.URL.Path, "/v1/query/") execute := false if strings.HasSuffix(id, preparedQueryExecuteSuffix) { @@ -95,6 +84,7 @@ func (s *HTTPServer) preparedQuerySpecific(endpoint string, resp http.ResponseWr id = strings.TrimSuffix(id, preparedQueryExecuteSuffix) } + endpoint := s.agent.getEndpoint(preparedQueryEndpoint) switch req.Method { case "GET": // Execute or retrieve a prepared query. if execute { diff --git a/command/agent/prepared_query_endpoint_test.go b/command/agent/prepared_query_endpoint_test.go index 01384d0284..23a39b196c 100644 --- a/command/agent/prepared_query_endpoint_test.go +++ b/command/agent/prepared_query_endpoint_test.go @@ -62,7 +62,7 @@ func (m *MockPreparedQuery) Execute(args *structs.PreparedQueryExecuteRequest, func TestPreparedQuery_Create(t *testing.T) { httpTest(t, func(srv *HTTPServer) { m := MockPreparedQuery{} - if err := srv.agent.server.InjectEndpoint(&m); err != nil { + if err := srv.agent.InjectEndpoint("PreparedQuery", &m); err != nil { t.Fatalf("err: %v", err) } @@ -126,7 +126,7 @@ func TestPreparedQuery_Create(t *testing.T) { } resp := httptest.NewRecorder() - obj, err := srv.preparedQueryGeneral("MockPreparedQuery", resp, req) + obj, err := srv.PreparedQueryGeneral(resp, req) if err != nil { t.Fatalf("err: %v", err) } @@ -146,7 +146,7 @@ func TestPreparedQuery_Create(t *testing.T) { func TestPreparedQuery_List(t *testing.T) { httpTest(t, func(srv *HTTPServer) { m := MockPreparedQuery{} - if err := srv.agent.server.InjectEndpoint(&m); err != nil { + if err := srv.agent.InjectEndpoint("PreparedQuery", &m); err != nil { t.Fatalf("err: %v", err) } @@ -176,7 +176,7 @@ func TestPreparedQuery_List(t *testing.T) { } resp := httptest.NewRecorder() - obj, err := srv.preparedQueryGeneral("MockPreparedQuery", resp, req) + obj, err := srv.PreparedQueryGeneral(resp, req) if err != nil { t.Fatalf("err: %v", err) } @@ -196,7 +196,7 @@ func TestPreparedQuery_List(t *testing.T) { func TestPreparedQuery_Execute(t *testing.T) { httpTest(t, func(srv *HTTPServer) { m := MockPreparedQuery{} - if err := srv.agent.server.InjectEndpoint(&m); err != nil { + if err := srv.agent.InjectEndpoint("PreparedQuery", &m); err != nil { t.Fatalf("err: %v", err) } @@ -230,7 +230,7 @@ func TestPreparedQuery_Execute(t *testing.T) { } resp := httptest.NewRecorder() - obj, err := srv.preparedQuerySpecific("MockPreparedQuery", resp, req) + obj, err := srv.PreparedQuerySpecific(resp, req) if err != nil { t.Fatalf("err: %v", err) } @@ -250,7 +250,7 @@ func TestPreparedQuery_Execute(t *testing.T) { func TestPreparedQuery_Get(t *testing.T) { httpTest(t, func(srv *HTTPServer) { m := MockPreparedQuery{} - if err := srv.agent.server.InjectEndpoint(&m); err != nil { + if err := srv.agent.InjectEndpoint("PreparedQuery", &m); err != nil { t.Fatalf("err: %v", err) } @@ -281,7 +281,7 @@ func TestPreparedQuery_Get(t *testing.T) { } resp := httptest.NewRecorder() - obj, err := srv.preparedQuerySpecific("MockPreparedQuery", resp, req) + obj, err := srv.PreparedQuerySpecific(resp, req) if err != nil { t.Fatalf("err: %v", err) } @@ -301,7 +301,7 @@ func TestPreparedQuery_Get(t *testing.T) { func TestPreparedQuery_Update(t *testing.T) { httpTest(t, func(srv *HTTPServer) { m := MockPreparedQuery{} - if err := srv.agent.server.InjectEndpoint(&m); err != nil { + if err := srv.agent.InjectEndpoint("PreparedQuery", &m); err != nil { t.Fatalf("err: %v", err) } @@ -367,7 +367,7 @@ func TestPreparedQuery_Update(t *testing.T) { } resp := httptest.NewRecorder() - _, err = srv.preparedQuerySpecific("MockPreparedQuery", resp, req) + _, err = srv.PreparedQuerySpecific(resp, req) if err != nil { t.Fatalf("err: %v", err) } @@ -380,7 +380,7 @@ func TestPreparedQuery_Update(t *testing.T) { func TestPreparedQuery_Delete(t *testing.T) { httpTest(t, func(srv *HTTPServer) { m := MockPreparedQuery{} - if err := srv.agent.server.InjectEndpoint(&m); err != nil { + if err := srv.agent.InjectEndpoint("PreparedQuery", &m); err != nil { t.Fatalf("err: %v", err) } @@ -418,7 +418,7 @@ func TestPreparedQuery_Delete(t *testing.T) { } resp := httptest.NewRecorder() - _, err = srv.preparedQuerySpecific("MockPreparedQuery", resp, req) + _, err = srv.PreparedQuerySpecific(resp, req) if err != nil { t.Fatalf("err: %v", err) }