// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1

package balancer

import (
	"context"
	"fmt"
	"math/rand"
	"net"
	"net/url"
	"sort"
	"sync"
	"sync/atomic"
	"testing"
	"time"

	"github.com/stretchr/testify/require"
	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/credentials/insecure"
	"google.golang.org/grpc/resolver"
	"google.golang.org/grpc/resolver/manual"
	"google.golang.org/grpc/stats"
	"google.golang.org/grpc/status"

	"github.com/hashicorp/go-uuid"

	"github.com/hashicorp/consul/agent/grpc-middleware/testutil/testservice"
	"github.com/hashicorp/consul/sdk/testutil"
	"github.com/hashicorp/consul/sdk/testutil/retry"
)

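// TestBalancer exercises the custom balancer's core behaviors: pinning to a
// single server, switching servers on error, switching on an explicit
// Rebalance, and reacting to resolver address updates.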
func TestBalancer(t *testing.T) {
	t.Run("remains pinned to the same server", func(t *testing.T) {
		ctx, cancel := context.WithCancel(context.Background())
		t.Cleanup(cancel)

		server1 := runServer(t, "server1")
		server2 := runServer(t, "server2")

		target, authority, _ := stubResolver(t, server1, server2)

		balancerBuilder := NewBuilder(authority, testutil.Logger(t))
		balancerBuilder.Register()
		t.Cleanup(balancerBuilder.Deregister)

		conn := dial(t, target)
		client := testservice.NewSimpleClient(conn)

		var serverName string
		for i := 0; i < 5; i++ {
			rsp, err := client.Something(ctx, &testservice.Req{})
			require.NoError(t, err)

			if i == 0 {
				serverName = rsp.ServerName
			} else {
				require.Equal(t, serverName, rsp.ServerName)
			}
		}

		var pinnedServer, otherServer *server
		switch serverName {
		case server1.name:
			pinnedServer, otherServer = server1, server2
		case server2.name:
			pinnedServer, otherServer = server2, server1
		default:
			t.Fatalf("response came from unknown server: %q", serverName)
		}
		require.Equal(t, 1,
			pinnedServer.openConnections(),
			"pinned server should have 1 connection",
		)
		require.Zero(t,
			otherServer.openConnections(),
			"other server should have no connections",
		)
	})

	t.Run("switches server on-error", func(t *testing.T) {
		ctx, cancel := context.WithCancel(context.Background())
		t.Cleanup(cancel)

		server1 := runServer(t, "server1")
		server2 := runServer(t, "server2")

		target, authority, _ := stubResolver(t, server1, server2)

		balancerBuilder := NewBuilder(authority, testutil.Logger(t))
		balancerBuilder.Register()
		t.Cleanup(balancerBuilder.Deregister)

		conn := dial(t, target)
		client := testservice.NewSimpleClient(conn)

		// Figure out which server we're talking to now, and which we should switch to.
		rsp, err := client.Something(ctx, &testservice.Req{})
		require.NoError(t, err)

		var initialServer, otherServer *server
		switch rsp.ServerName {
		case server1.name:
			initialServer, otherServer = server1, server2
		case server2.name:
			initialServer, otherServer = server2, server1
		default:
			t.Fatalf("response came from unknown server: %q", rsp.ServerName)
		}

		// Next request should fail (we don't have retries configured).
		initialServer.err = status.Error(codes.ResourceExhausted, "rate limit exceeded")
		_, err = client.Something(ctx, &testservice.Req{})
		require.Error(t, err)

		// Following request should succeed (against the other server).
		rsp, err = client.Something(ctx, &testservice.Req{})
		require.NoError(t, err)
		require.Equal(t, otherServer.name, rsp.ServerName)

		retry.Run(t, func(r *retry.R) {
			require.Zero(r,
				initialServer.openConnections(),
				"connection to previous server should have been torn down",
			)
		})
	})

	t.Run("rebalance changes the server", func(t *testing.T) {
		ctx, cancel := context.WithCancel(context.Background())
		t.Cleanup(cancel)

		server1 := runServer(t, "server1")
		server2 := runServer(t, "server2")

		target, authority, _ := stubResolver(t, server1, server2)

		balancerBuilder := NewBuilder(authority, testutil.Logger(t))
		balancerBuilder.Register()
		t.Cleanup(balancerBuilder.Deregister)

		// Provide a custom shuffler that causes Rebalance to choose whichever
		// server didn't get our first request.
		var otherServer *server
		balancerBuilder.shuffler = func(addrs []resolver.Address) {
			sort.Slice(addrs, func(a, b int) bool {
				return addrs[a].Addr == otherServer.addr
			})
		}

		conn := dial(t, target)
		client := testservice.NewSimpleClient(conn)

		// Figure out which server we're talking to now.
		rsp, err := client.Something(ctx, &testservice.Req{})
		require.NoError(t, err)

		var initialServer *server
		switch rsp.ServerName {
		case server1.name:
			initialServer, otherServer = server1, server2
		case server2.name:
			initialServer, otherServer = server2, server1
		default:
			t.Fatalf("response came from unknown server: %q", rsp.ServerName)
		}

		// Trigger a rebalance.
		targetURL, err := url.Parse(target)
		require.NoError(t, err)
		balancerBuilder.Rebalance(resolver.Target{URL: *targetURL})

		// Following request should hit the other server.
		rsp, err = client.Something(ctx, &testservice.Req{})
		require.NoError(t, err)
		require.Equal(t, otherServer.name, rsp.ServerName)

		retry.Run(t, func(r *retry.R) {
			require.Zero(r,
				initialServer.openConnections(),
				"connection to previous server should have been torn down",
			)
		})
	})

	t.Run("resolver removes the server", func(t *testing.T) {
		ctx, cancel := context.WithCancel(context.Background())
		t.Cleanup(cancel)

		server1 := runServer(t, "server1")
		server2 := runServer(t, "server2")

		target, authority, res := stubResolver(t, server1, server2)

		balancerBuilder := NewBuilder(authority, testutil.Logger(t))
		balancerBuilder.Register()
		t.Cleanup(balancerBuilder.Deregister)

		conn := dial(t, target)
		client := testservice.NewSimpleClient(conn)

		// Figure out which server we're talking to now.
		rsp, err := client.Something(ctx, &testservice.Req{})
		require.NoError(t, err)
		var initialServer, otherServer *server
		switch rsp.ServerName {
		case server1.name:
			initialServer, otherServer = server1, server2
		case server2.name:
			initialServer, otherServer = server2, server1
		default:
			t.Fatalf("response came from unknown server: %q", rsp.ServerName)
		}

		// Remove the initial server's address from the resolver state.
		res.UpdateState(resolver.State{
			Addresses: []resolver.Address{
				{Addr: otherServer.addr},
			},
		})

		// Following request should hit the other server.
		rsp, err = client.Something(ctx, &testservice.Req{})
		require.NoError(t, err)
		require.Equal(t, otherServer.name, rsp.ServerName)

		retry.Run(t, func(r *retry.R) {
			require.Zero(r,
				initialServer.openConnections(),
				"connection to previous server should have been torn down",
			)
		})

		// Remove the other server too.
		res.UpdateState(resolver.State{
			Addresses: []resolver.Address{},
		})

		_, err = client.Something(ctx, &testservice.Req{})
		require.Error(t, err)
		require.Contains(t, err.Error(), "resolver produced no addresses")

		retry.Run(t, func(r *retry.R) {
			require.Zero(r,
				otherServer.openConnections(),
				"connection to other server should have been torn down",
			)
		})
	})
}

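// stubResolver registers a manual resolver under a unique scheme seeded with
// the given servers' addresses, and unregisters it on test cleanup. It returns
// the dial target, the authority component of that target, and the resolver so
// tests can push updated address lists.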
func stubResolver(t *testing.T, servers ...*server) (string, string, *manual.Resolver) {
	t.Helper()

	addresses := make([]resolver.Address, len(servers))
	for idx, s := range servers {
		addresses[idx] = resolver.Address{Addr: s.addr}
	}

	scheme := fmt.Sprintf("consul-%d-%d", time.Now().UnixNano(), rand.Int())

	r := manual.NewBuilderWithScheme(scheme)
	r.InitialState(resolver.State{Addresses: addresses})

	resolver.Register(r)
	t.Cleanup(func() { resolver.UnregisterForTesting(scheme) })

	authority, err := uuid.GenerateUUID()
	require.NoError(t, err)

	return fmt.Sprintf("%s://%s", scheme, authority), authority, r
}

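// runServer starts a gRPC server for the Simple test service on a random
// loopback port and stops it on test cleanup. The returned server doubles as a
// stats.Handler so tests can observe how many connections are open to it.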
func runServer(t *testing.T, name string) *server {
	t.Helper()

	lis, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)

	s := &server{
		name: name,
		addr: lis.Addr().String(),
	}

	gs := grpc.NewServer(
		grpc.StatsHandler(s),
	)
	testservice.RegisterSimpleServer(gs, s)
	go func() {
		_ = gs.Serve(lis) // ignore the error; the server is stopped via t.Cleanup
	}()

	var once sync.Once
	s.shutdown = func() { once.Do(gs.Stop) }
	t.Cleanup(s.shutdown)

	return s
}

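// server is a test implementation of the Simple service that also acts as a
// gRPC stats.Handler to count open connections. Setting err causes subsequent
// Something calls to fail with that error.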
type server struct {
	name string
	addr string
	err  error

	c        int32
	shutdown func()
}

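// openConnections reports the number of connections currently open to the server.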
func (s *server) openConnections() int { return int(atomic.LoadInt32(&s.c)) }

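// The stats.Handler methods below are no-ops: only connection begin/end events
// (see HandleConn) are of interest to these tests.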
func (*server) HandleRPC(context.Context, stats.RPCStats)                         {}
func (*server) TagConn(ctx context.Context, _ *stats.ConnTagInfo) context.Context { return ctx }
func (*server) TagRPC(ctx context.Context, _ *stats.RPCTagInfo) context.Context   { return ctx }

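// HandleConn adjusts the open-connection counter as connections begin and end.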
func (s *server) HandleConn(_ context.Context, cs stats.ConnStats) {
	switch cs.(type) {
	case *stats.ConnBegin:
		atomic.AddInt32(&s.c, 1)
	case *stats.ConnEnd:
		atomic.AddInt32(&s.c, -1)
	}
}

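// Flow is a no-op implementation of the Simple service's streaming method.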
func (*server) Flow(*testservice.Req, testservice.Simple_FlowServer) error { return nil }

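// Something identifies the server that handled the request, or returns the
// configured error if one is set.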
func (s *server) Something(context.Context, *testservice.Req) (*testservice.Resp, error) {
	if s.err != nil {
		return nil, s.err
	}
	return &testservice.Resp{ServerName: s.name}, nil
}

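// dial opens a client connection to the given target, selecting the balancer
// under test via the default service config, and closes it on test cleanup.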
func dial(t *testing.T, target string) *grpc.ClientConn {
	conn, err := grpc.Dial(
		target,
		grpc.WithTransportCredentials(insecure.NewCredentials()),
		grpc.WithDefaultServiceConfig(
			fmt.Sprintf(`{"loadBalancingConfig":[{"%s":{}}]}`, BuilderName),
		),
	)
	require.NoError(t, err)
	t.Cleanup(func() {
		if err := conn.Close(); err != nil {
			t.Logf("error closing connection: %v", err)
		}
	})
	return conn
}