// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1

package serverdiscovery

import (
	"context"
	"errors"
	"io"
	"testing"
	"time"

	mock "github.com/stretchr/testify/mock"
	"github.com/stretchr/testify/require"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"

	acl "github.com/hashicorp/consul/acl"
	resolver "github.com/hashicorp/consul/acl/resolver"
	"github.com/hashicorp/consul/agent/consul/autopilotevents"
	"github.com/hashicorp/consul/agent/consul/stream"
	external "github.com/hashicorp/consul/agent/grpc-external"
	"github.com/hashicorp/consul/agent/grpc-external/testutils"
	"github.com/hashicorp/consul/agent/structs"
	"github.com/hashicorp/consul/proto-public/pbserverdiscovery"
	"github.com/hashicorp/consul/proto/private/prototest"
	"github.com/hashicorp/consul/sdk/testutil"
)

// testACLToken is the ACL token secret attached to every WatchServers
// request made by the tests in this file.
const (
	testACLToken = "eb61f1ed-65a4-4da6-8d3d-0564bd16c965"
)

// TestWatchServers_StreamLifeCycle drives a single WatchServers stream through
// an initial snapshot, a live update, a forced subscription close and a
// further live update, checking every message the client receives.
func TestWatchServers_StreamLifeCycle(t *testing.T) {
	// The flow for this test is roughly:
	//
	// 1. Open a WatchServers stream
	// 2. Observe the snapshot message is sent back through
	//    the stream.
	// 3. Publish an event that changes to 2 servers.
	// 4. See the corresponding message sent back through the stream.
	// 5. Send a NewCloseSubscriptionEvent for the token secret.
	// 6. See that a new snapshot is taken and the corresponding message
	//    gets sent back. If there were multiple subscribers for the topic
	//    then this should not happen. However with the current EventPublisher
	//    implementation, whenever the last subscriber for a topic has its
	//    subscription closed then the publisher will delete the whole topic
	//    buffer. When that happens, resubscribing will see no snapshot
	//    cache, or latest event in the buffer and force creating a new snapshot.
	// 7. Publish another event to move to 3 servers.
	// 8. Ensure that the message gets sent through the stream. Also
	//    this will validate that no other 1 or 2 server event is
	//    seen after stream reinitialization.

	srv1 := autopilotevents.ReadyServerInfo{
		ID:      "9aeb73f6-e83e-43c1-bdc9-ca5e43efe3e4",
		Address: "198.18.0.1",
		Version: "1.12.0",
	}
	srv2 := autopilotevents.ReadyServerInfo{
		ID:      "eec8721f-c42b-48da-a5a5-07565158015e",
		Address: "198.18.0.2",
		Version: "1.12.3",
	}
	srv3 := autopilotevents.ReadyServerInfo{
		ID:      "256796f2-3a38-4f80-8cef-375c3cb3aa1f",
		Address: "198.18.0.3",
		Version: "1.12.3",
	}

	oneServerEventPayload := autopilotevents.EventPayloadReadyServers{srv1}
	twoServerEventPayload := autopilotevents.EventPayloadReadyServers{srv1, srv2}
	threeServerEventPayload := autopilotevents.EventPayloadReadyServers{srv1, srv2, srv3}

	// toResponse builds the WatchServersResponse the stream is expected to
	// emit for the given ready servers, preserving their order.
	toResponse := func(servers ...autopilotevents.ReadyServerInfo) *pbserverdiscovery.WatchServersResponse {
		rsp := &pbserverdiscovery.WatchServersResponse{
			Servers: make([]*pbserverdiscovery.Server, 0, len(servers)),
		}
		for _, srv := range servers {
			rsp.Servers = append(rsp.Servers, &pbserverdiscovery.Server{
				Id:      srv.ID,
				Address: srv.Address,
				Version: srv.Version,
			})
		}
		return rsp
	}
	oneServerResponse := toResponse(srv1)
	twoServerResponse := toResponse(srv1, srv2)
	threeServerResponse := toResponse(srv1, srv2, srv3)

	// setup the event publisher and snapshot handler.
	handler, publisher := setupPublisher(t)
	// Two snapshots are expected: the initial one served when the stream is
	// opened (step 2), and a fresh one after the subscription is force-closed
	// in step 5 (step 6). Every other message should come from resuming the
	// stream rather than from a new snapshot.
	handler.expect(testACLToken, 0, 1, oneServerEventPayload)
	handler.expect(testACLToken, 2, 3, twoServerEventPayload)

	// setup the mock ACL resolver. Token resolution is expected twice,
	// presumably once per subscription: the initial one and the one
	// re-established after step 5.
	// Named aclResolver so it does not shadow the imported resolver package.
	aclResolver := newMockACLResolver(t)
	aclResolver.On("ResolveTokenAndDefaultMeta", testACLToken, mock.Anything, mock.Anything).
		Return(testutils.ACLNoPermissions(t), nil).Twice()

	// add the token to the requests context
	options := structs.QueryOptions{Token: testACLToken}
	ctx, err := external.ContextWithQueryOptions(context.Background(), options)
	require.NoError(t, err)

	// setup the server
	server := NewServer(Config{
		Publisher:   publisher,
		Logger:      testutil.Logger(t),
		ACLResolver: aclResolver,
	})

	// Run the server and get a test client for it
	client := testClient(t, server)

	// 1. Open the WatchServers stream
	serverStream, err := client.WatchServers(ctx, &pbserverdiscovery.WatchServersRequest{Wan: false})
	require.NoError(t, err)

	rspCh := handleReadyServersStream(t, serverStream)

	// 2. Observe the snapshot message is sent back through the stream.
	rsp := mustGetServers(t, rspCh)
	require.NotNil(t, rsp)
	prototest.AssertDeepEqual(t, oneServerResponse, rsp)

	// 3. Publish an event that changes to 2 servers.
	publisher.Publish([]stream.Event{
		{
			Topic:   autopilotevents.EventTopicReadyServers,
			Index:   2,
			Payload: twoServerEventPayload,
		},
	})

	// 4. See the corresponding message sent back through the stream.
	rsp = mustGetServers(t, rspCh)
	require.NotNil(t, rsp)
	prototest.AssertDeepEqual(t, twoServerResponse, rsp)

	// 5. Send a NewCloseSubscriptionEvent for the token secret.
	publisher.Publish([]stream.Event{
		stream.NewCloseSubscriptionEvent([]string{testACLToken}),
	})

	// 6. Observe another snapshot message
	rsp = mustGetServers(t, rspCh)
	require.NotNil(t, rsp)
	prototest.AssertDeepEqual(t, twoServerResponse, rsp)

	// 7. Publish another event to move to 3 servers.
	publisher.Publish([]stream.Event{
		{
			Topic:   autopilotevents.EventTopicReadyServers,
			Index:   4,
			Payload: threeServerEventPayload,
		},
	})

	// 8. Ensure that the message gets sent through the stream. Also
	//    this will validate that no other 1 or 2 server event is
	//    seen after stream reinitialization.
	rsp = mustGetServers(t, rspCh)
	require.NotNil(t, rsp)
	prototest.AssertDeepEqual(t, threeServerResponse, rsp)
}

// TestWatchServers_ACLToken_AnonymousToken verifies that a WatchServers stream
// whose token resolves to the anonymous identity fails with
// codes.Unauthenticated.
func TestWatchServers_ACLToken_AnonymousToken(t *testing.T) {
	// setup the event publisher and snapshot handler
	_, publisher := setupPublisher(t)

	// Named aclResolver so it does not shadow the imported resolver package.
	aclResolver := newMockACLResolver(t)
	aclResolver.On("ResolveTokenAndDefaultMeta", testACLToken, mock.Anything, mock.Anything).
		Return(testutils.ACLAnonymous(t), nil).Once()

	// add the token to the requests context
	options := structs.QueryOptions{Token: testACLToken}
	ctx, err := external.ContextWithQueryOptions(context.Background(), options)
	require.NoError(t, err)

	// setup the server
	server := NewServer(Config{
		Publisher:   publisher,
		Logger:      testutil.Logger(t),
		ACLResolver: aclResolver,
	})

	// Run the server and get a test client for it
	client := testClient(t, server)

	// Open the WatchServers stream
	serverStream, err := client.WatchServers(ctx, &pbserverdiscovery.WatchServersRequest{Wan: false})
	require.NoError(t, err)
	rspCh := handleReadyServersStream(t, serverStream)

	// Expect to get an Unauthenticated error immediately.
	err = mustGetError(t, rspCh)
	require.Equal(t, codes.Unauthenticated.String(), status.Code(err).String())
}

// TestWatchServers_ACLToken_Unauthenticated checks that a WatchServers stream
// whose token cannot be resolved (acl.ErrNotFound) fails with
// codes.Unauthenticated.
func TestWatchServers_ACLToken_Unauthenticated(t *testing.T) {
	// The server needs a publisher; this test never publishes to it.
	_, pub := setupPublisher(t)

	// The resolver reports the token as unknown on its single lookup.
	mockResolver := newMockACLResolver(t)
	mockResolver.
		On("ResolveTokenAndDefaultMeta", testACLToken, mock.Anything, mock.Anything).
		Return(resolver.Result{}, acl.ErrNotFound).
		Once()

	reqCtx, err := external.ContextWithQueryOptions(
		context.Background(),
		structs.QueryOptions{Token: testACLToken},
	)
	require.NoError(t, err)

	cfg := Config{
		Publisher:   pub,
		Logger:      testutil.Logger(t),
		ACLResolver: mockResolver,
	}
	client := testClient(t, NewServer(cfg))

	watch, err := client.WatchServers(reqCtx, &pbserverdiscovery.WatchServersRequest{Wan: false})
	require.NoError(t, err)

	// The first item read from the stream must be an Unauthenticated error.
	streamErr := mustGetError(t, handleReadyServersStream(t, watch))
	require.Equal(t, codes.Unauthenticated.String(), status.Code(streamErr).String())
}

// handleReadyServersStream starts a goroutine that reads from stream and
// forwards every response (or error) returned by Recv onto the returned
// unbuffered channel.
//
// The goroutine exits when:
//   - Recv returns io.EOF, context.Canceled or context.DeadlineExceeded
//     (nothing is forwarded in that case);
//   - Recv returns any other error, after that error has been forwarded once;
//   - the test finishes, so a pending send that nobody will ever receive is
//     abandoned instead of blocking forever.
func handleReadyServersStream(t *testing.T, stream pbserverdiscovery.ServerDiscoveryService_WatchServersClient) <-chan serversOrError {
	t.Helper()

	// done is closed during test cleanup. It unblocks a send when the test
	// has stopped draining rspCh.
	done := make(chan struct{})
	t.Cleanup(func() { close(done) })

	rspCh := make(chan serversOrError)
	go func() {
		for {
			rsp, err := stream.Recv()
			if errors.Is(err, io.EOF) ||
				errors.Is(err, context.Canceled) ||
				errors.Is(err, context.DeadlineExceeded) {
				return
			}

			select {
			case rspCh <- serversOrError{rsp: rsp, err: err}:
			case <-done:
				return
			}

			// Any non-nil error from Recv aborts a gRPC client stream. Calling
			// Recv again would only yield the same error and block on a send
			// that no caller is waiting for.
			if err != nil {
				return
			}
		}
	}()
	return rspCh
}

// mustGetServers waits up to one second for the next item on ch. It fails the
// test on timeout or if the item carries an error; otherwise it returns the
// received WatchServersResponse.
func mustGetServers(t *testing.T, ch <-chan serversOrError) *pbserverdiscovery.WatchServersResponse {
	t.Helper()

	timeout := time.NewTimer(time.Second)
	defer timeout.Stop()

	select {
	case <-timeout.C:
		t.Fatal("timeout waiting for WatchServersResponse")
		return nil
	case got := <-ch:
		require.NoError(t, got.err)
		return got.rsp
	}
}

// mustGetError waits up to one second for the next item on ch and returns its
// error. The test fails if nothing arrives in time, or if the item is a
// successful response rather than an error. In that case the unexpected
// response is included in the failure output.
func mustGetError(t *testing.T, ch <-chan serversOrError) error {
	t.Helper()

	select {
	case rsp := <-ch:
		require.Error(t, rsp.err, "expected an error from the WatchServers stream but got response: %v", rsp.rsp)
		return rsp.err
	case <-time.After(1 * time.Second):
		t.Fatal("timeout waiting for WatchServers error")
		return nil
	}
}

// serversOrError is one result read from a WatchServers client stream by
// handleReadyServersStream: the pair of values returned by a single Recv call.
type serversOrError struct {
	// rsp is the response returned by Recv.
	rsp *pbserverdiscovery.WatchServersResponse
	// err is the error returned by Recv.
	err error
}