Browse Source

Adds a slightly more flexible mock system so we can test DNS.

pull/1389/head
James Phillips 9 years ago
parent
commit
5e7523ea4b
  1. 34
      command/agent/agent.go
  2. 14
      command/agent/prepared_query_endpoint.go
  3. 24
      command/agent/prepared_query_endpoint_test.go

34
command/agent/agent.go

@ -9,6 +9,7 @@ import (
"net"
"os"
"path/filepath"
"reflect"
"regexp"
"strconv"
"sync"
@ -104,6 +105,11 @@ type Agent struct {
shutdown bool
shutdownCh chan struct{}
shutdownLock sync.Mutex
// endpoints lets you override RPC endpoints for testing. Not all
// agent methods use this, so use with care and never override
// outside of a unit test.
endpoints map[string]string
}
// Create is used to create a new Agent. Returns
@ -158,6 +164,7 @@ func Create(config *Config, logOutput io.Writer) (*Agent, error) {
eventCh: make(chan serf.UserEvent, 1024),
eventBuf: make([]*UserEvent, 256),
shutdownCh: make(chan struct{}),
endpoints: make(map[string]string),
}
// Initialize the local state
@ -1456,3 +1463,30 @@ func (a *Agent) DisableNodeMaintenance() {
a.RemoveCheck(nodeMaintCheckID, true)
a.logger.Printf("[INFO] agent: Node left maintenance mode")
}
// InjectEndpoint overrides the given endpoint with a substitute one. Note
// that not all agent methods use this mechanism, and that is should only
// be used for testing.
func (a *Agent) InjectEndpoint(endpoint string, handler interface{}) error {
if a.server == nil {
return fmt.Errorf("agent must be a server")
}
if err := a.server.InjectEndpoint(handler); err != nil {
return err
}
name := reflect.Indirect(reflect.ValueOf(handler)).Type().Name()
a.endpoints[endpoint] = name
a.logger.Printf("[WARN] agent: endpoint injected; this should only be used for testing")
return nil
}
// getEndpoint returns the endpoint name to use for the given endpoint,
// which may be overridden.
func (a *Agent) getEndpoint(endpoint string) string {
if override, ok := a.endpoints[endpoint]; ok {
return override
}
return endpoint
}

14
command/agent/prepared_query_endpoint.go

@ -21,12 +21,7 @@ type preparedQueryCreateResponse struct {
// PreparedQueryGeneral handles all the general prepared query requests.
func (s *HTTPServer) PreparedQueryGeneral(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
return s.preparedQueryGeneral(preparedQueryEndpoint, resp, req)
}
// preparedQueryGeneral is the internal method that does the work on behalf of
// PreparedQueryGeneral. The RPC endpoint is parameterized to ease testing.
func (s *HTTPServer) preparedQueryGeneral(endpoint string, resp http.ResponseWriter, req *http.Request) (interface{}, error) {
endpoint := s.agent.getEndpoint(preparedQueryEndpoint)
switch req.Method {
case "POST": // Create a new prepared query.
args := structs.PreparedQueryRequest{
@ -82,12 +77,6 @@ func parseLimit(req *http.Request, limit *int) error {
// PreparedQuerySpecifc handles all the prepared query requests specific to a
// particular query.
func (s *HTTPServer) PreparedQuerySpecific(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
return s.preparedQuerySpecific(preparedQueryEndpoint, resp, req)
}
// preparedQuerySpecific is the internal method that does the work on behalf of
// PreparedQuerySpecific. The RPC endpoint is parameterized to ease testing.
func (s *HTTPServer) preparedQuerySpecific(endpoint string, resp http.ResponseWriter, req *http.Request) (interface{}, error) {
id := strings.TrimPrefix(req.URL.Path, "/v1/query/")
execute := false
if strings.HasSuffix(id, preparedQueryExecuteSuffix) {
@ -95,6 +84,7 @@ func (s *HTTPServer) preparedQuerySpecific(endpoint string, resp http.ResponseWr
id = strings.TrimSuffix(id, preparedQueryExecuteSuffix)
}
endpoint := s.agent.getEndpoint(preparedQueryEndpoint)
switch req.Method {
case "GET": // Execute or retrieve a prepared query.
if execute {

24
command/agent/prepared_query_endpoint_test.go

@ -62,7 +62,7 @@ func (m *MockPreparedQuery) Execute(args *structs.PreparedQueryExecuteRequest,
func TestPreparedQuery_Create(t *testing.T) {
httpTest(t, func(srv *HTTPServer) {
m := MockPreparedQuery{}
if err := srv.agent.server.InjectEndpoint(&m); err != nil {
if err := srv.agent.InjectEndpoint("PreparedQuery", &m); err != nil {
t.Fatalf("err: %v", err)
}
@ -126,7 +126,7 @@ func TestPreparedQuery_Create(t *testing.T) {
}
resp := httptest.NewRecorder()
obj, err := srv.preparedQueryGeneral("MockPreparedQuery", resp, req)
obj, err := srv.PreparedQueryGeneral(resp, req)
if err != nil {
t.Fatalf("err: %v", err)
}
@ -146,7 +146,7 @@ func TestPreparedQuery_Create(t *testing.T) {
func TestPreparedQuery_List(t *testing.T) {
httpTest(t, func(srv *HTTPServer) {
m := MockPreparedQuery{}
if err := srv.agent.server.InjectEndpoint(&m); err != nil {
if err := srv.agent.InjectEndpoint("PreparedQuery", &m); err != nil {
t.Fatalf("err: %v", err)
}
@ -176,7 +176,7 @@ func TestPreparedQuery_List(t *testing.T) {
}
resp := httptest.NewRecorder()
obj, err := srv.preparedQueryGeneral("MockPreparedQuery", resp, req)
obj, err := srv.PreparedQueryGeneral(resp, req)
if err != nil {
t.Fatalf("err: %v", err)
}
@ -196,7 +196,7 @@ func TestPreparedQuery_List(t *testing.T) {
func TestPreparedQuery_Execute(t *testing.T) {
httpTest(t, func(srv *HTTPServer) {
m := MockPreparedQuery{}
if err := srv.agent.server.InjectEndpoint(&m); err != nil {
if err := srv.agent.InjectEndpoint("PreparedQuery", &m); err != nil {
t.Fatalf("err: %v", err)
}
@ -230,7 +230,7 @@ func TestPreparedQuery_Execute(t *testing.T) {
}
resp := httptest.NewRecorder()
obj, err := srv.preparedQuerySpecific("MockPreparedQuery", resp, req)
obj, err := srv.PreparedQuerySpecific(resp, req)
if err != nil {
t.Fatalf("err: %v", err)
}
@ -250,7 +250,7 @@ func TestPreparedQuery_Execute(t *testing.T) {
func TestPreparedQuery_Get(t *testing.T) {
httpTest(t, func(srv *HTTPServer) {
m := MockPreparedQuery{}
if err := srv.agent.server.InjectEndpoint(&m); err != nil {
if err := srv.agent.InjectEndpoint("PreparedQuery", &m); err != nil {
t.Fatalf("err: %v", err)
}
@ -281,7 +281,7 @@ func TestPreparedQuery_Get(t *testing.T) {
}
resp := httptest.NewRecorder()
obj, err := srv.preparedQuerySpecific("MockPreparedQuery", resp, req)
obj, err := srv.PreparedQuerySpecific(resp, req)
if err != nil {
t.Fatalf("err: %v", err)
}
@ -301,7 +301,7 @@ func TestPreparedQuery_Get(t *testing.T) {
func TestPreparedQuery_Update(t *testing.T) {
httpTest(t, func(srv *HTTPServer) {
m := MockPreparedQuery{}
if err := srv.agent.server.InjectEndpoint(&m); err != nil {
if err := srv.agent.InjectEndpoint("PreparedQuery", &m); err != nil {
t.Fatalf("err: %v", err)
}
@ -367,7 +367,7 @@ func TestPreparedQuery_Update(t *testing.T) {
}
resp := httptest.NewRecorder()
_, err = srv.preparedQuerySpecific("MockPreparedQuery", resp, req)
_, err = srv.PreparedQuerySpecific(resp, req)
if err != nil {
t.Fatalf("err: %v", err)
}
@ -380,7 +380,7 @@ func TestPreparedQuery_Update(t *testing.T) {
func TestPreparedQuery_Delete(t *testing.T) {
httpTest(t, func(srv *HTTPServer) {
m := MockPreparedQuery{}
if err := srv.agent.server.InjectEndpoint(&m); err != nil {
if err := srv.agent.InjectEndpoint("PreparedQuery", &m); err != nil {
t.Fatalf("err: %v", err)
}
@ -418,7 +418,7 @@ func TestPreparedQuery_Delete(t *testing.T) {
}
resp := httptest.NewRecorder()
_, err = srv.preparedQuerySpecific("MockPreparedQuery", resp, req)
_, err = srv.PreparedQuerySpecific(resp, req)
if err != nil {
t.Fatalf("err: %v", err)
}

Loading…
Cancel
Save