diff --git a/agent/consul/client.go b/agent/consul/client.go index 861d95e53c..c2e0806379 100644 --- a/agent/consul/client.go +++ b/agent/consul/client.go @@ -68,8 +68,7 @@ type Client struct { // from an agent. rpcLimiter atomic.Value - // eventCh is used to receive events from the - // serf cluster in the datacenter + // eventCh is used to receive events from the serf cluster in the datacenter eventCh chan serf.Event // Logger uses the provided LogOutput @@ -108,6 +107,9 @@ func NewClient(config *Config, options ...ConsulOption) (*Client, error) { if flat.logger == nil { return nil, fmt.Errorf("logger is required") } + if flat.router == nil { + return nil, fmt.Errorf("router is required") + } if connPool == nil { connPool = &pool.ConnPool{ @@ -156,23 +158,17 @@ func NewClient(config *Config, options ...ConsulOption) (*Client, error) { } // Initialize the LAN Serf - c.serf, err = c.setupSerf(config.SerfLANConfig, - c.eventCh, serfLANSnapshot) + c.serf, err = c.setupSerf(config.SerfLANConfig, c.eventCh, serfLANSnapshot) if err != nil { c.Shutdown() return nil, fmt.Errorf("Failed to start lan serf: %v", err) } - rpcRouter := flat.router - if rpcRouter == nil { - rpcRouter = router.NewRouter(logger, config.Datacenter, fmt.Sprintf("%s.%s", config.NodeName, config.Datacenter)) - } - - if err := rpcRouter.AddArea(types.AreaLAN, c.serf, c.connPool); err != nil { + if err := flat.router.AddArea(types.AreaLAN, c.serf, c.connPool); err != nil { c.Shutdown() return nil, fmt.Errorf("Failed to add LAN area to the RPC router: %w", err) } - c.router = rpcRouter + c.router = flat.router // Start LAN event handlers after the router is complete since the event // handlers depend on the router and the router depends on Serf. diff --git a/agent/consul/client_test.go b/agent/consul/client_test.go index e9a4c7bdc1..cbbb6c6393 100644 --- a/agent/consul/client_test.go +++ b/agent/consul/client_test.go @@ -2,12 +2,14 @@ package consul import ( "bytes" + "fmt" "net" "os" "sync" "testing" "time" + "github.com/hashicorp/consul/agent/router" "github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/sdk/freeport" "github.com/hashicorp/consul/sdk/testutil" @@ -75,7 +77,11 @@ func testClientWithConfigWithErr(t *testing.T, cb func(c *Config)) (string, *Cli t.Fatalf("err: %v", err) } - client, err := NewClient(config, WithLogger(logger), WithTLSConfigurator(tlsConf)) + r := router.NewRouter(logger, config.Datacenter, fmt.Sprintf("%s.%s", config.NodeName, config.Datacenter)) + client, err := NewClient(config, + WithLogger(logger), + WithTLSConfigurator(tlsConf), + WithRouter(r)) return dir, client, err } @@ -473,7 +479,11 @@ func newClient(t *testing.T, config *Config) *Client { Level: hclog.Debug, Output: testutil.NewLogBuffer(t), }) - client, err := NewClient(config, WithLogger(logger), WithTLSConfigurator(c)) + r := router.NewRouter(logger, config.Datacenter, fmt.Sprintf("%s.%s", config.NodeName, config.Datacenter)) + client, err := NewClient(config, + WithLogger(logger), + WithTLSConfigurator(c), + WithRouter(r)) require.NoError(t, err, "failed to create client") t.Cleanup(func() { client.Shutdown() diff --git a/agent/consul/leader_test.go b/agent/consul/leader_test.go index f8d5f91c6b..2bd30e3350 100644 --- a/agent/consul/leader_test.go +++ b/agent/consul/leader_test.go @@ -9,6 +9,7 @@ import ( "testing" "time" + "github.com/hashicorp/consul/agent/router" "github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/agent/token" "github.com/hashicorp/consul/api" @@ -1305,10 +1306,13 @@ func TestLeader_ConfigEntryBootstrap_Fail(t *testing.T) { }) tlsConf, err := tlsutil.NewConfigurator(config.ToTLSUtilConfig(), logger) require.NoError(t, err) + + rpcRouter := router.NewRouter(logger, config.Datacenter, fmt.Sprintf("%s.%s", config.NodeName, config.Datacenter)) srv, err := NewServer(config, WithLogger(logger), WithTokenStore(new(token.Store)), - WithTLSConfigurator(tlsConf)) + WithTLSConfigurator(tlsConf), + WithRouter(rpcRouter)) require.NoError(t, err) defer srv.Shutdown() diff --git a/agent/consul/server.go b/agent/consul/server.go index 2a496d9624..7aacdfb9a4 100644 --- a/agent/consul/server.go +++ b/agent/consul/server.go @@ -331,7 +331,6 @@ func NewServer(config *Config, options ...ConsulOption) (*Server, error) { tokens := flat.tokens tlsConfigurator := flat.tlsConfigurator connPool := flat.connPool - rpcRouter := flat.router if err := config.CheckProtocolVersion(); err != nil { return nil, err @@ -345,6 +344,9 @@ func NewServer(config *Config, options ...ConsulOption) (*Server, error) { if logger == nil { return nil, fmt.Errorf("logger is required") } + if flat.router == nil { + return nil, fmt.Errorf("router is required") + } // Check if TLS is enabled if config.CAFile != "" || config.CAPath != "" { @@ -388,10 +390,6 @@ func NewServer(config *Config, options ...ConsulOption) (*Server, error) { serverLogger := logger.NamedIntercept(logging.ConsulServer) loggers := newLoggerStore(serverLogger) - if rpcRouter == nil { - rpcRouter = router.NewRouter(serverLogger, config.Datacenter, fmt.Sprintf("%s.%s", config.NodeName, config.Datacenter)) - } - // Create server. s := &Server{ config: config, @@ -403,7 +401,7 @@ func NewServer(config *Config, options ...ConsulOption) (*Server, error) { loggers: loggers, leaveCh: make(chan struct{}), reconcileCh: make(chan serf.Member, reconcileChSize), - router: rpcRouter, + router: flat.router, rpcServer: rpc.NewServer(), insecureRPCServer: rpc.NewServer(), tlsConfigurator: tlsConfigurator, diff --git a/agent/consul/server_test.go b/agent/consul/server_test.go index 95069c9f70..00a7e4ea13 100644 --- a/agent/consul/server_test.go +++ b/agent/consul/server_test.go @@ -17,6 +17,7 @@ import ( "github.com/google/tcpproxy" "github.com/hashicorp/consul/agent/connect/ca" + "github.com/hashicorp/consul/agent/router" "github.com/hashicorp/consul/ipaddr" "github.com/hashicorp/memberlist" @@ -301,10 +302,13 @@ func newServer(t *testing.T, c *Config) (*Server, error) { if err != nil { return nil, err } + + rpcRouter := router.NewRouter(logger, c.Datacenter, fmt.Sprintf("%s.%s", c.NodeName, c.Datacenter)) srv, err := NewServer(c, WithLogger(logger), WithTokenStore(new(token.Store)), - WithTLSConfigurator(tlsConf)) + WithTLSConfigurator(tlsConf), + WithRouter(rpcRouter)) if err != nil { return nil, err } @@ -1491,10 +1495,13 @@ func TestServer_CALogging(t *testing.T) { c, err := tlsutil.NewConfigurator(conf1.ToTLSUtilConfig(), logger) require.NoError(t, err) + rpcRouter := router.NewRouter(logger, "dc1", fmt.Sprintf("%s.%s", "nodename", "dc1")) + s1, err := NewServer(conf1, WithLogger(logger), WithTokenStore(new(token.Store)), - WithTLSConfigurator(c)) + WithTLSConfigurator(c), + WithRouter(rpcRouter)) if err != nil { t.Fatalf("err: %v", err) }