// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1

package balancer

import (
	"context"
	"fmt"
	"math/rand"
	"net"
	"net/url"
	"sort"
	"sync"
	"sync/atomic"
	"testing"
	"time"

	"github.com/stretchr/testify/require"
	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/credentials/insecure"
	"google.golang.org/grpc/resolver"
	"google.golang.org/grpc/resolver/manual"
	"google.golang.org/grpc/stats"
	"google.golang.org/grpc/status"

	"github.com/hashicorp/go-uuid"

	"github.com/hashicorp/consul/agent/grpc-middleware/testutil/testservice"
	"github.com/hashicorp/consul/sdk/testutil"
	"github.com/hashicorp/consul/sdk/testutil/retry"
)

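// TestBalancer exercises the custom balancer end-to-end against two
// in-process gRPC servers: it checks that connections stay pinned to one
// server, fail over on error, move when Rebalance is called, and react to
// resolver address updates.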
func TestBalancer(t *testing.T) {
	t.Run("remains pinned to the same server", func(t *testing.T) {
		ctx, cancel := context.WithCancel(context.Background())
		t.Cleanup(cancel)

		server1 := runServer(t, "server1")
		server2 := runServer(t, "server2")

		target, authority, _ := stubResolver(t, server1, server2)

		balancerBuilder := NewBuilder(authority, testutil.Logger(t))
		balancerBuilder.Register()
		t.Cleanup(balancerBuilder.Deregister)

		conn := dial(t, target)
		client := testservice.NewSimpleClient(conn)

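		// Issue several requests; the balancer should route every one of
		// them to the server that handled the first.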
		var serverName string
		for i := 0; i < 5; i++ {
			rsp, err := client.Something(ctx, &testservice.Req{})
			require.NoError(t, err)

			if i == 0 {
				serverName = rsp.ServerName
			} else {
				require.Equal(t, serverName, rsp.ServerName)
			}
		}

		var pinnedServer, otherServer *server
		switch serverName {
		case server1.name:
			pinnedServer, otherServer = server1, server2
		case server2.name:
			pinnedServer, otherServer = server2, server1
		}
		require.Equal(t, 1,
			pinnedServer.openConnections(),
			"pinned server should have 1 connection",
		)
		require.Zero(t,
			otherServer.openConnections(),
			"other server should have no connections",
		)
	})

t.Run("switches server on-error", func(t *testing.T) { |
|
ctx, cancel := context.WithCancel(context.Background()) |
|
t.Cleanup(cancel) |
|
|
|
server1 := runServer(t, "server1") |
|
server2 := runServer(t, "server2") |
|
|
|
target, authority, _ := stubResolver(t, server1, server2) |
|
|
|
balancerBuilder := NewBuilder(authority, testutil.Logger(t)) |
|
balancerBuilder.Register() |
|
t.Cleanup(balancerBuilder.Deregister) |
|
|
|
conn := dial(t, target) |
|
client := testservice.NewSimpleClient(conn) |
|
|
|
// Figure out which server we're talking to now, and which we should switch to. |
|
rsp, err := client.Something(ctx, &testservice.Req{}) |
|
require.NoError(t, err) |
|
|
|
var initialServer, otherServer *server |
|
switch rsp.ServerName { |
|
case server1.name: |
|
initialServer, otherServer = server1, server2 |
|
case server2.name: |
|
initialServer, otherServer = server2, server1 |
|
} |
|
|
|
// Next request should fail (we don't have retries configured). |
|
initialServer.err = status.Error(codes.ResourceExhausted, "rate limit exceeded") |
|
_, err = client.Something(ctx, &testservice.Req{}) |
|
require.Error(t, err) |
|
|
|
// Following request should succeed (against the other server). |
|
rsp, err = client.Something(ctx, &testservice.Req{}) |
|
require.NoError(t, err) |
|
require.Equal(t, otherServer.name, rsp.ServerName) |
|
|
|
retry.Run(t, func(r *retry.R) { |
|
require.Zero(r, |
|
initialServer.openConnections(), |
|
"connection to previous server should have been torn down", |
|
) |
|
}) |
|
}) |
|
|
|
t.Run("rebalance changes the server", func(t *testing.T) { |
|
ctx, cancel := context.WithCancel(context.Background()) |
|
t.Cleanup(cancel) |
|
|
|
server1 := runServer(t, "server1") |
|
server2 := runServer(t, "server2") |
|
|
|
target, authority, _ := stubResolver(t, server1, server2) |
|
|
|
balancerBuilder := NewBuilder(authority, testutil.Logger(t)) |
|
balancerBuilder.Register() |
|
t.Cleanup(balancerBuilder.Deregister) |
|
|
|
// Provide a custom prioritizer that causes Rebalance to choose whichever |
|
// server didn't get our first request. |
|
var otherServer *server |
|
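		// sort.Slice's less function returns true only when the first address
		// belongs to otherServer, so that address always sorts to the front.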
		balancerBuilder.shuffler = func(addrs []resolver.Address) {
			sort.Slice(addrs, func(a, b int) bool {
				return addrs[a].Addr == otherServer.addr
			})
		}

		conn := dial(t, target)
		client := testservice.NewSimpleClient(conn)

		// Figure out which server we're talking to now.
		rsp, err := client.Something(ctx, &testservice.Req{})
		require.NoError(t, err)

		var initialServer *server
		switch rsp.ServerName {
		case server1.name:
			initialServer, otherServer = server1, server2
		case server2.name:
			initialServer, otherServer = server2, server1
		}

		// Trigger a rebalance.
		targetURL, err := url.Parse(target)
		require.NoError(t, err)
		balancerBuilder.Rebalance(resolver.Target{URL: *targetURL})

		// Following request should hit the other server.
		rsp, err = client.Something(ctx, &testservice.Req{})
		require.NoError(t, err)
		require.Equal(t, otherServer.name, rsp.ServerName)

		retry.Run(t, func(r *retry.R) {
			require.Zero(r,
				initialServer.openConnections(),
				"connection to previous server should have been torn down",
			)
		})
	})

t.Run("resolver removes the server", func(t *testing.T) { |
|
ctx, cancel := context.WithCancel(context.Background()) |
|
t.Cleanup(cancel) |
|
|
|
server1 := runServer(t, "server1") |
|
server2 := runServer(t, "server2") |
|
|
|
target, authority, res := stubResolver(t, server1, server2) |
|
|
|
balancerBuilder := NewBuilder(authority, testutil.Logger(t)) |
|
balancerBuilder.Register() |
|
t.Cleanup(balancerBuilder.Deregister) |
|
|
|
conn := dial(t, target) |
|
client := testservice.NewSimpleClient(conn) |
|
|
|
// Figure out which server we're talking to now. |
|
rsp, err := client.Something(ctx, &testservice.Req{}) |
|
require.NoError(t, err) |
|
var initialServer, otherServer *server |
|
switch rsp.ServerName { |
|
case server1.name: |
|
initialServer, otherServer = server1, server2 |
|
case server2.name: |
|
initialServer, otherServer = server2, server1 |
|
} |
|
|
|
// Remove the server's address. |
|
res.UpdateState(resolver.State{ |
|
Addresses: []resolver.Address{ |
|
{Addr: otherServer.addr}, |
|
}, |
|
}) |
|
|
|
// Following request should hit the other server. |
|
rsp, err = client.Something(ctx, &testservice.Req{}) |
|
require.NoError(t, err) |
|
require.Equal(t, otherServer.name, rsp.ServerName) |
|
|
|
retry.Run(t, func(r *retry.R) { |
|
require.Zero(r, |
|
initialServer.openConnections(), |
|
"connection to previous server should have been torn down", |
|
) |
|
}) |
|
|
|
// Remove the other server too. |
|
res.UpdateState(resolver.State{ |
|
Addresses: []resolver.Address{}, |
|
}) |
|
|
|
_, err = client.Something(ctx, &testservice.Req{}) |
|
require.Error(t, err) |
|
require.Contains(t, err.Error(), "resolver produced no addresses") |
|
|
|
retry.Run(t, func(r *retry.R) { |
|
require.Zero(r, |
|
otherServer.openConnections(), |
|
"connection to other server should have been torn down", |
|
) |
|
}) |
|
}) |
|
} |
|
|
|
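// stubResolver registers a manual gRPC resolver under a test-unique scheme,
// seeded with the given servers' addresses. It returns a target string for
// dialing, the authority, and the resolver itself so tests can push address
// updates.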
func stubResolver(t *testing.T, servers ...*server) (string, string, *manual.Resolver) {
	t.Helper()

	addresses := make([]resolver.Address, len(servers))
	for idx, s := range servers {
		addresses[idx] = resolver.Address{Addr: s.addr}
	}

	scheme := fmt.Sprintf("consul-%d-%d", time.Now().UnixNano(), rand.Int())

	r := manual.NewBuilderWithScheme(scheme)
	r.InitialState(resolver.State{Addresses: addresses})

	resolver.Register(r)

	authority, err := uuid.GenerateUUID()
	require.NoError(t, err)

	return fmt.Sprintf("%s://%s", scheme, authority), authority, r
}

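// runServer starts an in-process gRPC server on a random loopback port and
// arranges for it to be stopped when the test finishes.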
func runServer(t *testing.T, name string) *server {
	t.Helper()

	lis, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)

	s := &server{
		name: name,
		addr: lis.Addr().String(),
	}

	gs := grpc.NewServer(
		grpc.StatsHandler(s),
	)
	testservice.RegisterSimpleServer(gs, s)
	go gs.Serve(lis)

	var once sync.Once
	s.shutdown = func() { once.Do(gs.Stop) }
	t.Cleanup(s.shutdown)

	return s
}

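// server is a minimal implementation of the Simple test service. It also
// acts as its own stats.Handler, counting open connections in c so tests can
// observe them; setting err makes Something fail with that error.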
type server struct {
	name string
	addr string
	err  error

	c        int32
	shutdown func()
}

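// openConnections returns the number of connections currently open against
// this server, as tracked by the stats.Handler callbacks below.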
func (s *server) openConnections() int { return int(atomic.LoadInt32(&s.c)) }

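// HandleRPC, TagConn, and TagRPC are no-ops that exist only to satisfy the
// stats.Handler interface.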
func (*server) HandleRPC(context.Context, stats.RPCStats) {}
func (*server) TagConn(ctx context.Context, _ *stats.ConnTagInfo) context.Context { return ctx }
func (*server) TagRPC(ctx context.Context, _ *stats.RPCTagInfo) context.Context { return ctx }

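// HandleConn increments or decrements the open-connection counter as
// connections begin and end.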
func (s *server) HandleConn(_ context.Context, cs stats.ConnStats) {
	switch cs.(type) {
	case *stats.ConnBegin:
		atomic.AddInt32(&s.c, 1)
	case *stats.ConnEnd:
		atomic.AddInt32(&s.c, -1)
	}
}

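// Flow implements the Simple service's streaming method as a no-op.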
func (*server) Flow(*testservice.Req, testservice.Simple_FlowServer) error { return nil }

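// Something responds with the server's name, so tests can tell which backend
// served the request, or fails with the injected err if one is set.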
func (s *server) Something(context.Context, *testservice.Req) (*testservice.Resp, error) {
	if s.err != nil {
		return nil, s.err
	}
	return &testservice.Resp{ServerName: s.name}, nil
}

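// dial opens a client connection to target, selecting the custom balancer
// via the default service config.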
func dial(t *testing.T, target string) *grpc.ClientConn {
	conn, err := grpc.Dial(
		target,
		grpc.WithTransportCredentials(insecure.NewCredentials()),
		grpc.WithDefaultServiceConfig(
			fmt.Sprintf(`{"loadBalancingConfig":[{"%s":{}}]}`, BuilderName),
		),
	)
	// Check the error before registering the cleanup, so we never call Close
	// on a nil connection if Dial fails.
	require.NoError(t, err)
	t.Cleanup(func() {
		if err := conn.Close(); err != nil {
			t.Logf("error closing connection: %v", err)
		}
	})
	return conn
}