mirror of https://github.com/hashicorp/consul
agent: make the RPC endpoint overwrite mechanism more transparent
This patch hides the RPC handler overwrite mechanism from the rest of the code so that it works in all cases and that there is no cooperation required from the tested code, i.e. we can drop a.getEndpoint().pull/3163/head
parent
e15f9f9d90
commit
2b41f2e3a3
|
@ -174,7 +174,7 @@ func (m *aclManager) lookupACL(a *Agent, id string) (acl.ACL, error) {
|
|||
args.ETag = cached.ETag
|
||||
}
|
||||
var reply structs.ACLPolicy
|
||||
err := a.RPC(a.getEndpoint("ACL")+".GetPolicy", &args, &reply)
|
||||
err := a.RPC("ACL.GetPolicy", &args, &reply)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), aclDisabled) {
|
||||
a.logger.Printf("[DEBUG] agent: ACLs disabled on servers, will check again after %s", a.config.ACLDisabledTTL)
|
||||
|
|
|
@ -47,7 +47,7 @@ func TestACL_Version8(t *testing.T) {
|
|||
defer a.Shutdown()
|
||||
|
||||
m := MockServer{}
|
||||
if err := a.InjectEndpoint("ACL", &m); err != nil {
|
||||
if err := a.RegisterEndpoint("ACL", &m); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
|
@ -70,7 +70,7 @@ func TestACL_Disabled(t *testing.T) {
|
|||
defer a.Shutdown()
|
||||
|
||||
m := MockServer{}
|
||||
if err := a.InjectEndpoint("ACL", &m); err != nil {
|
||||
if err := a.RegisterEndpoint("ACL", &m); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
|
@ -123,7 +123,7 @@ func TestACL_Special_IDs(t *testing.T) {
|
|||
defer a.Shutdown()
|
||||
|
||||
m := MockServer{}
|
||||
if err := a.InjectEndpoint("ACL", &m); err != nil {
|
||||
if err := a.RegisterEndpoint("ACL", &m); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
|
@ -176,7 +176,7 @@ func TestACL_Down_Deny(t *testing.T) {
|
|||
defer a.Shutdown()
|
||||
|
||||
m := MockServer{}
|
||||
if err := a.InjectEndpoint("ACL", &m); err != nil {
|
||||
if err := a.RegisterEndpoint("ACL", &m); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
|
@ -206,7 +206,7 @@ func TestACL_Down_Allow(t *testing.T) {
|
|||
defer a.Shutdown()
|
||||
|
||||
m := MockServer{}
|
||||
if err := a.InjectEndpoint("ACL", &m); err != nil {
|
||||
if err := a.RegisterEndpoint("ACL", &m); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
|
@ -236,7 +236,7 @@ func TestACL_Down_Extend(t *testing.T) {
|
|||
defer a.Shutdown()
|
||||
|
||||
m := MockServer{}
|
||||
if err := a.InjectEndpoint("ACL", &m); err != nil {
|
||||
if err := a.RegisterEndpoint("ACL", &m); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
|
@ -313,7 +313,7 @@ func TestACL_Cache(t *testing.T) {
|
|||
defer a.Shutdown()
|
||||
|
||||
m := MockServer{}
|
||||
if err := a.InjectEndpoint("ACL", &m); err != nil {
|
||||
if err := a.RegisterEndpoint("ACL", &m); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
|
@ -495,7 +495,7 @@ func TestACL_vetServiceRegister(t *testing.T) {
|
|||
defer a.Shutdown()
|
||||
|
||||
m := MockServer{catalogPolicy}
|
||||
if err := a.InjectEndpoint("ACL", &m); err != nil {
|
||||
if err := a.RegisterEndpoint("ACL", &m); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
|
@ -541,7 +541,7 @@ func TestACL_vetServiceUpdate(t *testing.T) {
|
|||
defer a.Shutdown()
|
||||
|
||||
m := MockServer{catalogPolicy}
|
||||
if err := a.InjectEndpoint("ACL", &m); err != nil {
|
||||
if err := a.RegisterEndpoint("ACL", &m); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
|
@ -577,7 +577,7 @@ func TestACL_vetCheckRegister(t *testing.T) {
|
|||
defer a.Shutdown()
|
||||
|
||||
m := MockServer{catalogPolicy}
|
||||
if err := a.InjectEndpoint("ACL", &m); err != nil {
|
||||
if err := a.RegisterEndpoint("ACL", &m); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
|
@ -660,7 +660,7 @@ func TestACL_vetCheckUpdate(t *testing.T) {
|
|||
defer a.Shutdown()
|
||||
|
||||
m := MockServer{catalogPolicy}
|
||||
if err := a.InjectEndpoint("ACL", &m); err != nil {
|
||||
if err := a.RegisterEndpoint("ACL", &m); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
|
@ -716,7 +716,7 @@ func TestACL_filterMembers(t *testing.T) {
|
|||
defer a.Shutdown()
|
||||
|
||||
m := MockServer{catalogPolicy}
|
||||
if err := a.InjectEndpoint("ACL", &m); err != nil {
|
||||
if err := a.RegisterEndpoint("ACL", &m); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
|
@ -752,7 +752,7 @@ func TestACL_filterServices(t *testing.T) {
|
|||
defer a.Shutdown()
|
||||
|
||||
m := MockServer{catalogPolicy}
|
||||
if err := a.InjectEndpoint("ACL", &m); err != nil {
|
||||
if err := a.RegisterEndpoint("ACL", &m); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
|
@ -783,7 +783,7 @@ func TestACL_filterChecks(t *testing.T) {
|
|||
defer a.Shutdown()
|
||||
|
||||
m := MockServer{catalogPolicy}
|
||||
if err := a.InjectEndpoint("ACL", &m); err != nil {
|
||||
if err := a.RegisterEndpoint("ACL", &m); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
|
|
|
@ -14,7 +14,6 @@ import (
|
|||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
@ -146,9 +145,9 @@ type Agent struct {
|
|||
// attempts.
|
||||
retryJoinCh chan error
|
||||
|
||||
// endpoints lets you override RPC endpoints for testing. Not all
|
||||
// agent methods use this, so use with care and never override
|
||||
// outside of a unit test.
|
||||
// endpoints maps unique RPC endpoint names to common ones
|
||||
// to allow overriding of RPC handlers since the golang
|
||||
// net/rpc server does not allow this.
|
||||
endpoints map[string]string
|
||||
endpointsLock sync.RWMutex
|
||||
|
||||
|
@ -1068,9 +1067,34 @@ LOAD:
|
|||
return nil
|
||||
}
|
||||
|
||||
// RegisterEndpoint registers a handler for the consul RPC server
|
||||
// under a unique name while making it accessible under the provided
|
||||
// name. This allows overwriting handlers for the golang net/rpc
|
||||
// service which does not allow this.
|
||||
func (a *Agent) RegisterEndpoint(name string, handler interface{}) error {
|
||||
srv, ok := a.delegate.(*consul.Server)
|
||||
if !ok {
|
||||
panic("agent must be a server")
|
||||
}
|
||||
realname := fmt.Sprintf("%s-%d", name, time.Now().UnixNano())
|
||||
a.endpointsLock.Lock()
|
||||
a.endpoints[name] = realname
|
||||
a.endpointsLock.Unlock()
|
||||
return srv.RegisterEndpoint(realname, handler)
|
||||
}
|
||||
|
||||
// RPC is used to make an RPC call to the Consul servers
|
||||
// This allows the agent to implement the Consul.Interface
|
||||
func (a *Agent) RPC(method string, args interface{}, reply interface{}) error {
|
||||
a.endpointsLock.Lock()
|
||||
// fast path: only translate if there are overrides
|
||||
if len(a.endpoints) > 0 {
|
||||
p := strings.SplitN(method, ".", 2)
|
||||
if e := a.endpoints[p[0]]; e != "" {
|
||||
method = e + "." + p[1]
|
||||
}
|
||||
}
|
||||
a.endpointsLock.Unlock()
|
||||
return a.delegate.RPC(method, args, reply)
|
||||
}
|
||||
|
||||
|
@ -2255,37 +2279,6 @@ func (a *Agent) DisableNodeMaintenance() {
|
|||
a.logger.Printf("[INFO] agent: Node left maintenance mode")
|
||||
}
|
||||
|
||||
// InjectEndpoint overrides the given endpoint with a substitute one. Note
|
||||
// that not all agent methods use this mechanism, and that is should only
|
||||
// be used for testing.
|
||||
func (a *Agent) InjectEndpoint(endpoint string, handler interface{}) error {
|
||||
srv, ok := a.delegate.(*consul.Server)
|
||||
if !ok {
|
||||
return fmt.Errorf("agent must be a server")
|
||||
}
|
||||
if err := srv.InjectEndpoint(handler); err != nil {
|
||||
return err
|
||||
}
|
||||
name := reflect.Indirect(reflect.ValueOf(handler)).Type().Name()
|
||||
a.endpointsLock.Lock()
|
||||
a.endpoints[endpoint] = name
|
||||
a.endpointsLock.Unlock()
|
||||
|
||||
a.logger.Printf("[WARN] agent: endpoint injected; this should only be used for testing")
|
||||
return nil
|
||||
}
|
||||
|
||||
// getEndpoint returns the endpoint name to use for the given endpoint,
|
||||
// which may be overridden.
|
||||
func (a *Agent) getEndpoint(endpoint string) string {
|
||||
a.endpointsLock.RLock()
|
||||
defer a.endpointsLock.RUnlock()
|
||||
if override, ok := a.endpoints[endpoint]; ok {
|
||||
return override
|
||||
}
|
||||
return endpoint
|
||||
}
|
||||
|
||||
func (a *Agent) ReloadConfig(newCfg *Config) (bool, error) {
|
||||
var errs error
|
||||
|
||||
|
|
|
@ -977,10 +977,10 @@ func (s *Server) SnapshotRPC(args *structs.SnapshotRequest, in io.Reader, out io
|
|||
return nil
|
||||
}
|
||||
|
||||
// InjectEndpoint is used to substitute an endpoint for testing.
|
||||
func (s *Server) InjectEndpoint(endpoint interface{}) error {
|
||||
// RegisterEndpoint is used to substitute an endpoint for testing.
|
||||
func (s *Server) RegisterEndpoint(name string, handler interface{}) error {
|
||||
s.logger.Printf("[WARN] consul: endpoint injected; this should only be used for testing")
|
||||
return s.rpcServer.Register(endpoint)
|
||||
return s.rpcServer.RegisterName(name, handler)
|
||||
}
|
||||
|
||||
// Stats is used to return statistics for debugging and insight
|
||||
|
|
|
@ -695,10 +695,9 @@ func (d *DNSServer) preparedQueryLookup(network, datacenter, query string, req,
|
|||
// likely work in practice, like 10*maxUDPAnswerLimit which should help
|
||||
// reduce bandwidth if there are thousands of nodes available.
|
||||
|
||||
endpoint := d.agent.getEndpoint(preparedQueryEndpoint)
|
||||
var out structs.PreparedQueryExecuteResponse
|
||||
RPC:
|
||||
if err := d.agent.RPC(endpoint+".Execute", &args, &out); err != nil {
|
||||
if err := d.agent.RPC("PreparedQuery.Execute", &args, &out); err != nil {
|
||||
// If they give a bogus query name, treat that as a name error,
|
||||
// not a full on server error. We have to use a string compare
|
||||
// here since the RPC layer loses the type information.
|
||||
|
|
|
@ -3932,7 +3932,7 @@ func TestDNS_PreparedQuery_AllowStale(t *testing.T) {
|
|||
defer a.Shutdown()
|
||||
|
||||
m := MockPreparedQuery{}
|
||||
if err := a.InjectEndpoint("PreparedQuery", &m); err != nil {
|
||||
if err := a.RegisterEndpoint("PreparedQuery", &m); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
|
@ -4013,7 +4013,7 @@ func TestDNS_PreparedQuery_AgentSource(t *testing.T) {
|
|||
defer a.Shutdown()
|
||||
|
||||
m := MockPreparedQuery{}
|
||||
if err := a.InjectEndpoint("PreparedQuery", &m); err != nil {
|
||||
if err := a.RegisterEndpoint("PreparedQuery", &m); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
|
|
|
@ -37,8 +37,7 @@ func (s *HTTPServer) preparedQueryCreate(resp http.ResponseWriter, req *http.Req
|
|||
}
|
||||
|
||||
var reply string
|
||||
endpoint := s.agent.getEndpoint(preparedQueryEndpoint)
|
||||
if err := s.agent.RPC(endpoint+".Apply", &args, &reply); err != nil {
|
||||
if err := s.agent.RPC("PreparedQuery.Apply", &args, &reply); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return preparedQueryCreateResponse{reply}, nil
|
||||
|
@ -52,8 +51,7 @@ func (s *HTTPServer) preparedQueryList(resp http.ResponseWriter, req *http.Reque
|
|||
}
|
||||
|
||||
var reply structs.IndexedPreparedQueries
|
||||
endpoint := s.agent.getEndpoint(preparedQueryEndpoint)
|
||||
if err := s.agent.RPC(endpoint+".List", &args, &reply); err != nil {
|
||||
if err := s.agent.RPC("PreparedQuery.List", &args, &reply); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
@ -110,8 +108,7 @@ func (s *HTTPServer) preparedQueryExecute(id string, resp http.ResponseWriter, r
|
|||
}
|
||||
|
||||
var reply structs.PreparedQueryExecuteResponse
|
||||
endpoint := s.agent.getEndpoint(preparedQueryEndpoint)
|
||||
if err := s.agent.RPC(endpoint+".Execute", &args, &reply); err != nil {
|
||||
if err := s.agent.RPC("PreparedQuery.Execute", &args, &reply); err != nil {
|
||||
// We have to check the string since the RPC sheds
|
||||
// the specific error type.
|
||||
if err.Error() == consul.ErrQueryNotFound.Error() {
|
||||
|
@ -155,8 +152,7 @@ func (s *HTTPServer) preparedQueryExplain(id string, resp http.ResponseWriter, r
|
|||
}
|
||||
|
||||
var reply structs.PreparedQueryExplainResponse
|
||||
endpoint := s.agent.getEndpoint(preparedQueryEndpoint)
|
||||
if err := s.agent.RPC(endpoint+".Explain", &args, &reply); err != nil {
|
||||
if err := s.agent.RPC("PreparedQuery.Explain", &args, &reply); err != nil {
|
||||
// We have to check the string since the RPC sheds
|
||||
// the specific error type.
|
||||
if err.Error() == consul.ErrQueryNotFound.Error() {
|
||||
|
@ -179,8 +175,7 @@ func (s *HTTPServer) preparedQueryGet(id string, resp http.ResponseWriter, req *
|
|||
}
|
||||
|
||||
var reply structs.IndexedPreparedQueries
|
||||
endpoint := s.agent.getEndpoint(preparedQueryEndpoint)
|
||||
if err := s.agent.RPC(endpoint+".Get", &args, &reply); err != nil {
|
||||
if err := s.agent.RPC("PreparedQuery.Get", &args, &reply); err != nil {
|
||||
// We have to check the string since the RPC sheds
|
||||
// the specific error type.
|
||||
if err.Error() == consul.ErrQueryNotFound.Error() {
|
||||
|
@ -212,8 +207,7 @@ func (s *HTTPServer) preparedQueryUpdate(id string, resp http.ResponseWriter, re
|
|||
args.Query.ID = id
|
||||
|
||||
var reply string
|
||||
endpoint := s.agent.getEndpoint(preparedQueryEndpoint)
|
||||
if err := s.agent.RPC(endpoint+".Apply", &args, &reply); err != nil {
|
||||
if err := s.agent.RPC("PreparedQuery.Apply", &args, &reply); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return nil, nil
|
||||
|
@ -231,8 +225,7 @@ func (s *HTTPServer) preparedQueryDelete(id string, resp http.ResponseWriter, re
|
|||
s.parseToken(req, &args.Token)
|
||||
|
||||
var reply string
|
||||
endpoint := s.agent.getEndpoint(preparedQueryEndpoint)
|
||||
if err := s.agent.RPC(endpoint+".Apply", &args, &reply); err != nil {
|
||||
if err := s.agent.RPC("PreparedQuery.Apply", &args, &reply); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return nil, nil
|
||||
|
|
|
@ -74,7 +74,7 @@ func TestPreparedQuery_Create(t *testing.T) {
|
|||
defer a.Shutdown()
|
||||
|
||||
m := MockPreparedQuery{}
|
||||
if err := a.InjectEndpoint("PreparedQuery", &m); err != nil {
|
||||
if err := a.RegisterEndpoint("PreparedQuery", &m); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
|
@ -159,7 +159,7 @@ func TestPreparedQuery_List(t *testing.T) {
|
|||
defer a.Shutdown()
|
||||
|
||||
m := MockPreparedQuery{}
|
||||
if err := a.InjectEndpoint("PreparedQuery", &m); err != nil {
|
||||
if err := a.RegisterEndpoint("PreparedQuery", &m); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
|
@ -192,7 +192,7 @@ func TestPreparedQuery_List(t *testing.T) {
|
|||
defer a.Shutdown()
|
||||
|
||||
m := MockPreparedQuery{}
|
||||
if err := a.InjectEndpoint("PreparedQuery", &m); err != nil {
|
||||
if err := a.RegisterEndpoint("PreparedQuery", &m); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
|
@ -242,7 +242,7 @@ func TestPreparedQuery_Execute(t *testing.T) {
|
|||
defer a.Shutdown()
|
||||
|
||||
m := MockPreparedQuery{}
|
||||
if err := a.InjectEndpoint("PreparedQuery", &m); err != nil {
|
||||
if err := a.RegisterEndpoint("PreparedQuery", &m); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
|
@ -275,7 +275,7 @@ func TestPreparedQuery_Execute(t *testing.T) {
|
|||
defer a.Shutdown()
|
||||
|
||||
m := MockPreparedQuery{}
|
||||
if err := a.InjectEndpoint("PreparedQuery", &m); err != nil {
|
||||
if err := a.RegisterEndpoint("PreparedQuery", &m); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
|
@ -331,7 +331,7 @@ func TestPreparedQuery_Execute(t *testing.T) {
|
|||
defer a.Shutdown()
|
||||
|
||||
m := MockPreparedQuery{}
|
||||
if err := a.InjectEndpoint("PreparedQuery", &m); err != nil {
|
||||
if err := a.RegisterEndpoint("PreparedQuery", &m); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
|
@ -365,7 +365,7 @@ func TestPreparedQuery_Execute(t *testing.T) {
|
|||
defer a.Shutdown()
|
||||
|
||||
m := MockPreparedQuery{}
|
||||
if err := a.InjectEndpoint("PreparedQuery", &m); err != nil {
|
||||
if err := a.RegisterEndpoint("PreparedQuery", &m); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
|
@ -415,7 +415,7 @@ func TestPreparedQuery_Execute(t *testing.T) {
|
|||
defer a.Shutdown()
|
||||
|
||||
m := MockPreparedQuery{}
|
||||
if err := a.InjectEndpoint("PreparedQuery", &m); err != nil {
|
||||
if err := a.RegisterEndpoint("PreparedQuery", &m); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
|
@ -479,7 +479,7 @@ func TestPreparedQuery_Explain(t *testing.T) {
|
|||
defer a.Shutdown()
|
||||
|
||||
m := MockPreparedQuery{}
|
||||
if err := a.InjectEndpoint("PreparedQuery", &m); err != nil {
|
||||
if err := a.RegisterEndpoint("PreparedQuery", &m); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
|
@ -552,7 +552,7 @@ func TestPreparedQuery_Get(t *testing.T) {
|
|||
defer a.Shutdown()
|
||||
|
||||
m := MockPreparedQuery{}
|
||||
if err := a.InjectEndpoint("PreparedQuery", &m); err != nil {
|
||||
if err := a.RegisterEndpoint("PreparedQuery", &m); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
|
@ -617,7 +617,7 @@ func TestPreparedQuery_Update(t *testing.T) {
|
|||
defer a.Shutdown()
|
||||
|
||||
m := MockPreparedQuery{}
|
||||
if err := a.InjectEndpoint("PreparedQuery", &m); err != nil {
|
||||
if err := a.RegisterEndpoint("PreparedQuery", &m); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
|
@ -695,7 +695,7 @@ func TestPreparedQuery_Delete(t *testing.T) {
|
|||
defer a.Shutdown()
|
||||
|
||||
m := MockPreparedQuery{}
|
||||
if err := a.InjectEndpoint("PreparedQuery", &m); err != nil {
|
||||
if err := a.RegisterEndpoint("PreparedQuery", &m); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue