diff --git a/agent/grpc/client_test.go b/agent/grpc/client_test.go
index d8ea50dd8b..0cc8cb20ff 100644
--- a/agent/grpc/client_test.go
+++ b/agent/grpc/client_test.go
@@ -10,6 +10,7 @@ import (
 	"github.com/hashicorp/consul/agent/grpc/internal/testservice"
 	"github.com/hashicorp/consul/agent/grpc/resolver"
 	"github.com/hashicorp/consul/agent/metadata"
+	"github.com/hashicorp/consul/sdk/testutil/retry"
 	"github.com/stretchr/testify/require"
 )

@@ -20,8 +21,8 @@ func TestNewDialer(t *testing.T) {

 func TestClientConnPool_IntegrationWithGRPCResolver_Failover(t *testing.T) {
 	count := 4
-	cfg := resolver.Config{Datacenter: "dc1", Scheme: newScheme(t.Name())}
-	res := resolver.NewServerResolverBuilder(cfg, fakeNodes{num: count})
+	cfg := resolver.Config{Scheme: newScheme(t.Name())}
+	res := resolver.NewServerResolverBuilder(cfg)
 	resolver.RegisterWithGRPC(res)
 	pool := NewClientConnPool(res, nil)

@@ -41,6 +42,7 @@ func TestClientConnPool_IntegrationWithGRPCResolver_Failover(t *testing.T) {

 	first, err := client.Something(ctx, &testservice.Req{})
 	require.NoError(t, err)
+
 	res.RemoveServer(&metadata.Server{ID: first.ServerName, Datacenter: "dc1"})

 	resp, err := client.Something(ctx, &testservice.Req{})
@@ -54,19 +56,56 @@ func newScheme(n string) string {
 	return strings.ToLower(s)
 }

-type fakeNodes struct {
-	num int
-}
+func TestClientConnPool_IntegrationWithGRPCResolver_Rebalance(t *testing.T) {
+	count := 4
+	cfg := resolver.Config{Scheme: newScheme(t.Name())}
+	res := resolver.NewServerResolverBuilder(cfg)
+	resolver.RegisterWithGRPC(res)
+	pool := NewClientConnPool(res, nil)

-func (n fakeNodes) NumNodes() int {
-	return n.num
+	for i := 0; i < count; i++ {
+		name := fmt.Sprintf("server-%d", i)
+		srv := newTestServer(t, name, "dc1")
+		res.AddServer(srv.Metadata())
+		t.Cleanup(srv.shutdown)
+	}
+
+	conn, err := pool.ClientConn("dc1")
+	require.NoError(t, err)
+	client := testservice.NewSimpleClient(conn)
+
+	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
+	t.Cleanup(cancel)
+
+	first, err := client.Something(ctx, &testservice.Req{})
+	require.NoError(t, err)
+
+	t.Run("rebalance a different DC, does nothing", func(t *testing.T) {
+		res.NewRebalancer("dc-other")()
+
+		resp, err := client.Something(ctx, &testservice.Req{})
+		require.NoError(t, err)
+		require.Equal(t, resp.ServerName, first.ServerName)
+	})
+
+	t.Run("rebalance the dc", func(t *testing.T) {
+		// Rebalance is random, but if we repeat it a few times it should give us a
+		// new server.
+		retry.RunWith(fastRetry, t, func(r *retry.R) {
+			res.NewRebalancer("dc1")()
+
+			resp, err := client.Something(ctx, &testservice.Req{})
+			require.NoError(r, err)
+			require.NotEqual(r, resp.ServerName, first.ServerName)
+		})
+	})
 }

 func TestClientConnPool_IntegrationWithGRPCResolver_MultiDC(t *testing.T) {
 	dcs := []string{"dc1", "dc2", "dc3"}

-	cfg := resolver.Config{Datacenter: "dc1", Scheme: newScheme(t.Name())}
-	res := resolver.NewServerResolverBuilder(cfg, fakeNodes{num: 1})
+	cfg := resolver.Config{Scheme: newScheme(t.Name())}
+	res := resolver.NewServerResolverBuilder(cfg)
 	resolver.RegisterWithGRPC(res)
 	pool := NewClientConnPool(res, nil)

diff --git a/agent/grpc/resolver/resolver.go b/agent/grpc/resolver/resolver.go
index 3bf66b74c1..d35def0d21 100644
--- a/agent/grpc/resolver/resolver.go
+++ b/agent/grpc/resolver/resolver.go
@@ -1,15 +1,12 @@
 package resolver

 import (
-	"context"
 	"fmt"
 	"math/rand"
 	"strings"
 	"sync"
-	"time"

 	"github.com/hashicorp/consul/agent/metadata"
-	"github.com/hashicorp/consul/agent/router"
 	"google.golang.org/grpc/resolver"
 )

@@ -26,17 +23,9 @@ func RegisterWithGRPC(b *ServerResolverBuilder) {
 	resolver.Register(b)
 }

-// Nodes provides a count of the number of nodes in the cluster. It is very
-// likely implemented by serf to return the number of LAN members.
-type Nodes interface {
-	NumNodes() int
-}
-
 // ServerResolverBuilder tracks the current server list and keeps any
 // ServerResolvers updated when changes occur.
 type ServerResolverBuilder struct {
-	// datacenter of the local agent.
-	datacenter string
 	// scheme used to query the server. Defaults to consul. Used to support
 	// parallel testing because gRPC registers resolvers globally.
 	scheme string
@@ -46,8 +35,6 @@ type ServerResolverBuilder struct {
 	// resolvers is an index of connections to the serverResolver which manages
 	// addresses of servers for that connection.
 	resolvers map[resolver.ClientConn]*serverResolver
-	// nodes provides the number of nodes in the cluster.
-	nodes Nodes
 	// lock for servers and resolvers.
 	lock sync.RWMutex
 }
@@ -55,86 +42,45 @@ type ServerResolverBuilder struct {
 var _ resolver.Builder = (*ServerResolverBuilder)(nil)

 type Config struct {
-	// Datacenter of the local agent.
-	Datacenter string
 	// Scheme used to connect to the server. Defaults to consul.
 	Scheme string
 }

-func NewServerResolverBuilder(cfg Config, nodes Nodes) *ServerResolverBuilder {
+func NewServerResolverBuilder(cfg Config) *ServerResolverBuilder {
 	if cfg.Scheme == "" {
 		cfg.Scheme = "consul"
 	}
 	return &ServerResolverBuilder{
-		scheme:     cfg.Scheme,
-		datacenter: cfg.Datacenter,
-		nodes:      nodes,
-		servers:    make(map[string]*metadata.Server),
-		resolvers:  make(map[resolver.ClientConn]*serverResolver),
+		scheme:    cfg.Scheme,
+		servers:   make(map[string]*metadata.Server),
+		resolvers: make(map[resolver.ClientConn]*serverResolver),
 	}
 }

-// Run periodically reshuffles the order of server addresses within the
-// resolvers to ensure the load is balanced across servers.
-//
-// TODO: this looks very similar to agent/router.Manager.Start, which is the
-// only other caller of ComputeRebalanceTimer. Are the values passed to these
-// two functions different enough that we need separate goroutines to rebalance?
-// or could we have a single thing handle the timers, and call both rebalance
-// functions?
-func (s *ServerResolverBuilder) Run(ctx context.Context) {
-	// Compute the rebalance timer based on the number of local servers and nodes.
-	rebalanceDuration := router.ComputeRebalanceTimer(s.serversInDC(s.datacenter), s.nodes.NumNodes())
-	timer := time.NewTimer(rebalanceDuration)
+// NewRebalancer returns a function which shuffles the server list for resolvers in the given datacenter.
+func (s *ServerResolverBuilder) NewRebalancer(dc string) func() {
+	return func() {
+		s.lock.RLock()
+		defer s.lock.RUnlock()

-	for {
-		select {
-		case <-timer.C:
-			s.rebalanceResolvers()
-
-			// Re-compute the wait duration.
-			newTimerDuration := router.ComputeRebalanceTimer(s.serversInDC(s.datacenter), s.nodes.NumNodes())
-			timer.Reset(newTimerDuration)
-		case <-ctx.Done():
-			timer.Stop()
-			return
+		for _, resolver := range s.resolvers {
+			if resolver.datacenter != dc {
+				continue
+			}
+			// Shuffle the list of addresses using the last list given to the resolver.
+			resolver.addrLock.Lock()
+			addrs := resolver.addrs
+			// TODO: seed this rand, so it is a little more random-like
+			rand.Shuffle(len(addrs), func(i, j int) {
+				addrs[i], addrs[j] = addrs[j], addrs[i]
+			})
+			// Pass the shuffled list to the resolver.
+			resolver.updateAddrsLocked(addrs)
+			resolver.addrLock.Unlock()
 		}
 	}
 }

-// rebalanceResolvers shuffles the server list for resolvers in all datacenters.
-func (s *ServerResolverBuilder) rebalanceResolvers() {
-	s.lock.RLock()
-	defer s.lock.RUnlock()
-
-	for _, resolver := range s.resolvers {
-		// Shuffle the list of addresses using the last list given to the resolver.
-		resolver.addrLock.Lock()
-		addrs := resolver.addrs
-		rand.Shuffle(len(addrs), func(i, j int) {
-			addrs[i], addrs[j] = addrs[j], addrs[i]
-		})
-		// Pass the shuffled list to the resolver.
-		resolver.updateAddrsLocked(addrs)
-		resolver.addrLock.Unlock()
-	}
-}
-
-// serversInDC returns the number of servers in the given datacenter.
-func (s *ServerResolverBuilder) serversInDC(dc string) int {
-	s.lock.RLock()
-	defer s.lock.RUnlock()
-
-	var serverCount int
-	for _, server := range s.servers {
-		if server.Datacenter == dc {
-			serverCount++
-		}
-	}
-
-	return serverCount
-}
-
 // ServerForAddr returns server metadata for a server with the specified address.
 func (s *ServerResolverBuilder) ServerForAddr(addr string) (*metadata.Server, error) {
 	s.lock.RLock()
diff --git a/agent/router/grpc.go b/agent/router/grpc.go
index 0a50992811..c4fe96d25f 100644
--- a/agent/router/grpc.go
+++ b/agent/router/grpc.go
@@ -2,19 +2,29 @@ package router

 import "github.com/hashicorp/consul/agent/metadata"

-// ServerTracker is a wrapper around consul.ServerResolverBuilder to prevent a
-// cyclic import dependency.
+// ServerTracker is called when Router is notified of a server being added or
+// removed.
 type ServerTracker interface {
+	NewRebalancer(dc string) func()
 	AddServer(*metadata.Server)
 	RemoveServer(*metadata.Server)
 }

+// Rebalancer is called periodically to re-order the servers so that the load on the
+// servers is evenly balanced.
+type Rebalancer func()
+
 // NoOpServerTracker is a ServerTracker that does nothing. Used when gRPC is not
 // enabled.
 type NoOpServerTracker struct{}

-// AddServer implements ServerTracker
+// NewRebalancer returns a no-op Rebalancer
+func (NoOpServerTracker) NewRebalancer(string) func() {
+	return func() {}
+}
+
+// AddServer does nothing
 func (NoOpServerTracker) AddServer(*metadata.Server) {}

-// RemoveServer implements ServerTracker
+// RemoveServer does nothing
 func (NoOpServerTracker) RemoveServer(*metadata.Server) {}
diff --git a/agent/router/manager.go b/agent/router/manager.go
index 2052eb02d7..4aaab97597 100644
--- a/agent/router/manager.go
+++ b/agent/router/manager.go
@@ -98,6 +98,8 @@ type Manager struct {
 	// client.ConnPool.
 	connPoolPinger Pinger

+	rebalancer Rebalancer
+
 	// serverName has the name of the managers's server. This is used to
 	// short-circuit pinging to itself.
 	serverName string
@@ -267,7 +269,7 @@ func (m *Manager) saveServerList(l serverList) {
 }

 // New is the only way to safely create a new Manager struct.
-func New(logger hclog.Logger, shutdownCh chan struct{}, clusterInfo ManagerSerfCluster, connPoolPinger Pinger, serverName string) (m *Manager) {
+func New(logger hclog.Logger, shutdownCh chan struct{}, clusterInfo ManagerSerfCluster, connPoolPinger Pinger, serverName string, rb Rebalancer) (m *Manager) {
 	if logger == nil {
 		logger = hclog.New(&hclog.LoggerOptions{})
 	}
@@ -278,6 +280,7 @@ func New(logger hclog.Logger, shutdownCh chan struct{}, clusterInfo ManagerSerfC
 	m.connPoolPinger = connPoolPinger // can't pass *consul.ConnPool: import cycle
 	m.rebalanceTimer = time.NewTimer(clientRPCMinReuseDuration)
 	m.shutdownCh = shutdownCh
+	m.rebalancer = rb
 	m.serverName = serverName
 	atomic.StoreInt32(&m.offline, 1)

@@ -498,22 +501,17 @@ func (m *Manager) RemoveServer(s *metadata.Server) {
 func (m *Manager) refreshServerRebalanceTimer() time.Duration {
 	l := m.getServerList()
 	numServers := len(l.servers)
-	connRebalanceTimeout := ComputeRebalanceTimer(numServers, m.clusterInfo.NumNodes())
-
-	m.rebalanceTimer.Reset(connRebalanceTimeout)
-	return connRebalanceTimeout
-}
-
-// ComputeRebalanceTimer returns a time to wait before rebalancing connections given
-// a number of servers and LAN nodes.
-func ComputeRebalanceTimer(numServers, numLANMembers int) time.Duration {
 	// Limit this connection's life based on the size (and health) of the
 	// cluster. Never rebalance a connection more frequently than
 	// connReuseLowWatermarkDuration, and make sure we never exceed
 	// clusterWideRebalanceConnsPerSec operations/s across numLANMembers.
 	clusterWideRebalanceConnsPerSec := float64(numServers * newRebalanceConnsPerSecPerServer)
 	connReuseLowWatermarkDuration := clientRPCMinReuseDuration + lib.RandomStagger(clientRPCMinReuseDuration/clientRPCJitterFraction)
-	return lib.RateScaledInterval(clusterWideRebalanceConnsPerSec, connReuseLowWatermarkDuration, numLANMembers)
+	numLANMembers := m.clusterInfo.NumNodes()
+	connRebalanceTimeout := lib.RateScaledInterval(clusterWideRebalanceConnsPerSec, connReuseLowWatermarkDuration, numLANMembers)
+
+	m.rebalanceTimer.Reset(connRebalanceTimeout)
+	return connRebalanceTimeout
 }

 // ResetRebalanceTimer resets the rebalance timer. This method exists for
@@ -534,6 +532,7 @@ func (m *Manager) Start() {
 	for {
 		select {
 		case <-m.rebalanceTimer.C:
+			m.rebalancer()
 			m.RebalanceServers()
 			m.refreshServerRebalanceTimer()

diff --git a/agent/router/manager_internal_test.go b/agent/router/manager_internal_test.go
index 76d9512168..05807e2070 100644
--- a/agent/router/manager_internal_test.go
+++ b/agent/router/manager_internal_test.go
@@ -54,14 +54,16 @@ func (s *fauxSerf) NumNodes() int {
 func testManager() (m *Manager) {
 	logger := GetBufferedLogger()
 	shutdownCh := make(chan struct{})
-	m = New(logger, shutdownCh, &fauxSerf{numNodes: 16384}, &fauxConnPool{}, "")
+	m = New(logger, shutdownCh, &fauxSerf{numNodes: 16384}, &fauxConnPool{}, "", noopRebalancer)
 	return m
 }

+func noopRebalancer() {}
+
 func testManagerFailProb(failPct float64) (m *Manager) {
 	logger := GetBufferedLogger()
 	shutdownCh := make(chan struct{})
-	m = New(logger, shutdownCh, &fauxSerf{}, &fauxConnPool{failPct: failPct}, "")
+	m = New(logger, shutdownCh, &fauxSerf{}, &fauxConnPool{failPct: failPct}, "", noopRebalancer)
 	return m
 }

@@ -300,7 +302,7 @@ func TestManagerInternal_refreshServerRebalanceTimer(t *testing.T) {
 	shutdownCh := make(chan struct{})

 	for _, s := range clusters {
-		m := New(logger, shutdownCh, &fauxSerf{numNodes: s.numNodes}, &fauxConnPool{}, "")
+		m := New(logger, shutdownCh, &fauxSerf{numNodes: s.numNodes}, &fauxConnPool{}, "", noopRebalancer)
 		for i := 0; i < s.numServers; i++ {
 			nodeName := fmt.Sprintf("s%02d", i)
 			m.AddServer(&metadata.Server{Name: nodeName})
diff --git a/agent/router/manager_test.go b/agent/router/manager_test.go
index c7e1f299ca..dc3628f1bd 100644
--- a/agent/router/manager_test.go
+++ b/agent/router/manager_test.go
@@ -57,21 +57,23 @@ func (s *fauxSerf) NumNodes() int {
 func testManager(t testing.TB) (m *router.Manager) {
 	logger := testutil.Logger(t)
 	shutdownCh := make(chan struct{})
-	m = router.New(logger, shutdownCh, &fauxSerf{}, &fauxConnPool{}, "")
+	m = router.New(logger, shutdownCh, &fauxSerf{}, &fauxConnPool{}, "", noopRebalancer)
 	return m
 }

+func noopRebalancer() {}
+
 func testManagerFailProb(t testing.TB, failPct float64) (m *router.Manager) {
 	logger := testutil.Logger(t)
 	shutdownCh := make(chan struct{})
-	m = router.New(logger, shutdownCh, &fauxSerf{}, &fauxConnPool{failPct: failPct}, "")
+	m = router.New(logger, shutdownCh, &fauxSerf{}, &fauxConnPool{failPct: failPct}, "", noopRebalancer)
 	return m
 }

 func testManagerFailAddr(t testing.TB, failAddr net.Addr) (m *router.Manager) {
 	logger := testutil.Logger(t)
 	shutdownCh := make(chan struct{})
-	m = router.New(logger, shutdownCh, &fauxSerf{}, &fauxConnPool{failAddr: failAddr}, "")
+	m = router.New(logger, shutdownCh, &fauxSerf{}, &fauxConnPool{failAddr: failAddr}, "", noopRebalancer)
 	return m
 }

@@ -195,7 +197,7 @@ func TestServers_FindServer(t *testing.T) {
 func TestServers_New(t *testing.T) {
 	logger := testutil.Logger(t)
 	shutdownCh := make(chan struct{})
-	m := router.New(logger, shutdownCh, &fauxSerf{}, &fauxConnPool{}, "")
+	m := router.New(logger, shutdownCh, &fauxSerf{}, &fauxConnPool{}, "", noopRebalancer)
 	if m == nil {
 		t.Fatalf("Manager nil")
 	}
diff --git a/agent/router/router.go b/agent/router/router.go
index 9694e927db..8244745c3b 100644
--- a/agent/router/router.go
+++ b/agent/router/router.go
@@ -259,7 +259,8 @@ func (r *Router) maybeInitializeManager(area *areaInfo, dc string) *Manager {
 	}

 	shutdownCh := make(chan struct{})
-	manager := New(r.logger, shutdownCh, area.cluster, area.pinger, r.serverName)
+	rb := r.grpcServerTracker.NewRebalancer(dc)
+	manager := New(r.logger, shutdownCh, area.cluster, area.pinger, r.serverName, rb)
 	info = &managerInfo{
 		manager:    manager,
 		shutdownCh: shutdownCh,
diff --git a/agent/setup.go b/agent/setup.go
index d56419680c..454bfa510d 100644
--- a/agent/setup.go
+++ b/agent/setup.go
@@ -11,6 +11,7 @@ import (
 	"github.com/hashicorp/consul/agent/cache"
 	"github.com/hashicorp/consul/agent/config"
 	"github.com/hashicorp/consul/agent/consul"
+	"github.com/hashicorp/consul/agent/grpc/resolver"
 	"github.com/hashicorp/consul/agent/pool"
 	"github.com/hashicorp/consul/agent/router"
 	"github.com/hashicorp/consul/agent/token"
@@ -82,8 +83,10 @@ func NewBaseDeps(configLoader ConfigLoader, logOut io.Writer) (BaseDeps, error)
 	d.Cache = cache.New(cfg.Cache)
 	d.ConnPool = newConnPool(cfg, d.Logger, d.TLSConfigurator)

-	// TODO: set grpcServerTracker, requires serf to be setup before this.
-	d.Router = router.NewRouter(d.Logger, cfg.Datacenter, fmt.Sprintf("%s.%s", cfg.NodeName, cfg.Datacenter), nil)
+	// TODO(streaming): set Config.Scheme name for tests
+	builder := resolver.NewServerResolverBuilder(resolver.Config{})
+	resolver.RegisterWithGRPC(builder)
+	d.Router = router.NewRouter(d.Logger, cfg.Datacenter, fmt.Sprintf("%s.%s", cfg.NodeName, cfg.Datacenter), builder)

 	acConf := autoconf.Config{
 		DirectRPC: d.ConnPool,
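
Reviewer note: the net effect of this patch is that the gRPC resolver no longer runs its own rebalance goroutine; instead, `router.Manager` calls the injected `Rebalancer` each time its existing rebalance timer fires, so a single timer drives both the RPC server-list rebalance and the gRPC address shuffle (which is why `ComputeRebalanceTimer` is inlined and `ServerResolverBuilder.Run` is removed). The sketch below is a minimal, self-contained illustration of that hook pattern; the names `fakeResolver`, `newRebalancer`, and `runManager` are hypothetical stand-ins and do not exist in Consul, where the real equivalents are `ServerResolverBuilder.NewRebalancer` and `Manager.Start`.

```go
package main

import (
	"fmt"
	"math/rand"
	"time"
)

// Rebalancer mirrors the callback type added to agent/router: the resolver
// side hands one of these out per datacenter, and the manager invokes it on
// its own timer instead of running a second rebalance goroutine.
type Rebalancer func()

// fakeResolver stands in for the per-connection serverResolver; only the
// address list matters for this sketch.
type fakeResolver struct {
	datacenter string
	addrs      []string
}

// newRebalancer mimics the shape of ServerResolverBuilder.NewRebalancer: it
// captures the resolvers and shuffles the address list of every resolver in
// the requested datacenter, leaving other datacenters untouched.
func newRebalancer(dc string, resolvers []*fakeResolver) Rebalancer {
	return func() {
		for _, r := range resolvers {
			if r.datacenter != dc {
				continue
			}
			// Unseeded shuffle, matching the TODO in the patch about seeding.
			rand.Shuffle(len(r.addrs), func(i, j int) {
				r.addrs[i], r.addrs[j] = r.addrs[j], r.addrs[i]
			})
		}
	}
}

// runManager mimics the relevant slice of Manager.Start: on every timer tick
// it calls the injected rebalancer, then resets the timer.
func runManager(rb Rebalancer, ticks int) {
	timer := time.NewTimer(10 * time.Millisecond)
	defer timer.Stop()
	for i := 0; i < ticks; i++ {
		<-timer.C
		rb() // shuffle gRPC resolver addresses on the manager's schedule
		timer.Reset(10 * time.Millisecond)
	}
}

func main() {
	resolvers := []*fakeResolver{
		{datacenter: "dc1", addrs: []string{"s1", "s2", "s3", "s4"}},
		{datacenter: "dc2", addrs: []string{"s5", "s6"}},
	}
	runManager(newRebalancer("dc1", resolvers), 3)
	// dc1's order is shuffled; dc2's order is unchanged.
	fmt.Println(resolvers[0].addrs, resolvers[1].addrs)
}
```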