// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1

package resource

import (
	"context"
	"errors"
	"io"
	"testing"
	"time"

	"github.com/hashicorp/consul/acl"
	"github.com/hashicorp/consul/agent/grpc-external/testutils"
	"github.com/hashicorp/consul/internal/resource"
	"github.com/hashicorp/consul/internal/resource/demo"
	"github.com/hashicorp/consul/proto-public/pbresource"
	"github.com/hashicorp/consul/proto/private/prototest"

	"github.com/stretchr/testify/mock"
	"github.com/stretchr/testify/require"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
	"google.golang.org/protobuf/proto"
)

// TestWatchList_InputValidation checks that WatchList reports malformed
// requests with an InvalidArgument status. Every case starts from the same
// baseline request and applies one mutation that is expected to be rejected.
func TestWatchList_InputValidation(t *testing.T) {
	server := testServer(t)
	client := testClient(t, server)

	demo.RegisterTypes(server.Registry)

	type mutateFn func(*pbresource.WatchListRequest)

	cases := map[string]mutateFn{
		"no type": func(r *pbresource.WatchListRequest) {
			r.Type = nil
		},
		"no tenancy": func(r *pbresource.WatchListRequest) {
			r.Tenancy = nil
		},
		"partitioned type provides non-empty namespace": func(r *pbresource.WatchListRequest) {
			r.Type = demo.TypeV1RecordLabel
			r.Tenancy.Namespace = "bad"
		},
	}

	for name, mutate := range cases {
		t.Run(name, func(t *testing.T) {
			// Baseline request, broken by the case-specific mutation below.
			req := &pbresource.WatchListRequest{
				Type:    demo.TypeV2Album,
				Tenancy: resource.DefaultNamespacedTenancy(),
			}
			mutate(req)

			// Opening the stream succeeds; the validation failure is expected
			// to surface on the first Recv.
			stream, err := client.WatchList(testContext(t), req)
			require.NoError(t, err)

			_, recvErr := stream.Recv()
			require.Error(t, recvErr)
			require.Equal(t, codes.InvalidArgument.String(), status.Code(recvErr).String())
		})
	}
}

// TestWatchList_TypeNotFound checks that watching a type which was never
// registered on the server yields an InvalidArgument error naming the type.
func TestWatchList_TypeNotFound(t *testing.T) {
	t.Parallel()

	// Note that no demo types are registered on this server.
	server := testServer(t)
	client := testClient(t, server)

	req := &pbresource.WatchListRequest{
		Type:       demo.TypeV2Artist,
		Tenancy:    resource.DefaultNamespacedTenancy(),
		NamePrefix: "",
	}
	stream, err := client.WatchList(context.Background(), req)
	require.NoError(t, err)
	events := handleResourceStream(t, stream)

	// The first thing delivered on the stream must be the rejection.
	watchErr := mustGetError(t, events)
	require.Equal(t, codes.InvalidArgument.String(), status.Code(watchErr).String())
	require.Contains(t, watchErr.Error(), "resource type demo.v2.Artist not registered")
}

// TestWatchList_GroupVersionMatches watches demo.TypeV2Artist and then
// creates, updates and deletes a resource of that same type, asserting that
// each backend operation produces the corresponding watch event.
func TestWatchList_GroupVersionMatches(t *testing.T) {
	t.Parallel()

	server := testServer(t)
	client := testClient(t, server)
	demo.RegisterTypes(server.Registry)
	ctx := context.Background()

	// Open the watch before any writes happen.
	stream, err := client.WatchList(ctx, &pbresource.WatchListRequest{
		Type:       demo.TypeV2Artist,
		Tenancy:    resource.DefaultNamespacedTenancy(),
		NamePrefix: "",
	})
	require.NoError(t, err)
	events := handleResourceStream(t, stream)

	artist, err := demo.GenerateV2Artist()
	require.NoError(t, err)

	// Insert: expect an upsert event carrying the stored resource.
	created, err := server.Backend.WriteCAS(ctx, artist)
	require.NoError(t, err)
	event := mustGetResource(t, events)
	require.Equal(t, pbresource.WatchEvent_OPERATION_UPSERT, event.Operation)
	prototest.AssertDeepEqual(t, created, event.Resource)

	// Update: expect another upsert event carrying the new version.
	updated, err := server.Backend.WriteCAS(ctx, modifyArtist(t, created))
	require.NoError(t, err)
	event = mustGetResource(t, events)
	require.Equal(t, pbresource.WatchEvent_OPERATION_UPSERT, event.Operation)
	prototest.AssertDeepEqual(t, updated, event.Resource)

	// Delete: expect a delete event.
	require.NoError(t, server.Backend.DeleteCAS(ctx, updated.Id, updated.Version))
	event = mustGetResource(t, events)
	require.Equal(t, pbresource.WatchEvent_OPERATION_DELETE, event.Operation)
}

// TestWatchList_Tenancy_Defaults_And_Normalization runs every case from
// wildcardTenancyCases and checks that a watch opened with the case's tenancy
// still observes a resource of the case's type written to the backend, i.e.
// that tenancy units are lowercased and defaulted correctly when empty.
func TestWatchList_Tenancy_Defaults_And_Normalization(t *testing.T) {
	for name, tc := range wildcardTenancyCases() {
		t.Run(name, func(t *testing.T) {
			ctx := context.Background()
			server := testServer(t)
			client := testClient(t, server)
			demo.RegisterTypes(server.Registry)

			// Open the watch using the tenancy under test.
			req := &pbresource.WatchListRequest{
				Type:       tc.typ,
				Tenancy:    tc.tenancy,
				NamePrefix: "",
			}
			stream, err := client.WatchList(ctx, req)
			require.NoError(t, err)
			events := handleResourceStream(t, stream)

			// Both resources are generated and written; which one the watch
			// should report depends on the type the case watches.
			recordLabel, err := demo.GenerateV1RecordLabel("LooneyTunes")
			require.NoError(t, err)
			artist, err := demo.GenerateV2Artist()
			require.NoError(t, err)

			recordLabel, err = server.Backend.WriteCAS(ctx, recordLabel)
			require.NoError(t, err)
			artist, err = server.Backend.WriteCAS(ctx, artist)
			require.NoError(t, err)

			var want *pbresource.Resource
			switch {
			case proto.Equal(tc.typ, demo.TypeV1RecordLabel):
				want = recordLabel
			case proto.Equal(tc.typ, demo.TypeV2Artist):
				want = artist
			default:
				require.Fail(t, "unsupported type", tc.typ)
			}

			// Exactly the resource of the watched type must arrive as an upsert.
			event := mustGetResource(t, events)
			require.Equal(t, pbresource.WatchEvent_OPERATION_UPSERT, event.Operation)
			prototest.AssertDeepEqual(t, want, event.Resource)
		})
	}
}

// TestWatchList_GroupVersionMismatch covers the following scenario:
//
//	Given a watch on demo.TypeV1Artist, which differs from demo.TypeV2Artist
//	only by GroupVersion,
//	When a demo.TypeV2Artist resource is created, updated and deleted,
//	Then no watch events are emitted.
func TestWatchList_GroupVersionMismatch(t *testing.T) {
	t.Parallel()

	server := testServer(t)
	demo.RegisterTypes(server.Registry)
	client := testClient(t, server)
	ctx := context.Background()

	// Watch the V1 artist type.
	req := &pbresource.WatchListRequest{
		Type:       demo.TypeV1Artist,
		Tenancy:    resource.DefaultNamespacedTenancy(),
		NamePrefix: "",
	}
	stream, err := client.WatchList(ctx, req)
	require.NoError(t, err)
	events := handleResourceStream(t, stream)

	// All writes below operate on a V2 artist.
	artist, err := demo.GenerateV2Artist()
	require.NoError(t, err)

	// Insert.
	created, err := server.Backend.WriteCAS(ctx, artist)
	require.NoError(t, err)

	// Update, by writing back a clone of the stored resource.
	updated, err := server.Backend.WriteCAS(ctx, clone(created))
	require.NoError(t, err)

	// Delete.
	require.NoError(t, server.Backend.DeleteCAS(ctx, updated.Id, updated.Version))

	// None of the operations above may have surfaced on the V1 watch.
	mustGetNoResource(t, events)
}

// TestWatchList_ACL_ListDenied checks that a watch opened with an authorizer
// that grants no permissions fails with PermissionDenied.
//
// N.B. Uses key ACLs for now. See demo.RegisterTypes()
func TestWatchList_ACL_ListDenied(t *testing.T) {
	t.Parallel()

	// Authorizer that denies everything.
	events, _ := roundTripACL(t, testutils.ACLNoPermissions(t))

	// The stream must report that key:list was denied.
	watchErr := mustGetError(t, events)
	require.Error(t, watchErr)
	require.Equal(t, codes.PermissionDenied.String(), status.Code(watchErr).String())
	require.Contains(t, watchErr.Error(), "lacks permission 'key:list'")
}

// TestWatchList_ACL_ListAllowed_ReadDenied checks that when listing is
// permitted but reading the artist is denied, the written artist is filtered
// out of the watch and no event is delivered.
//
// N.B. Uses key ACLs for now. See demo.RegisterTypes()
func TestWatchList_ACL_ListAllowed_ReadDenied(t *testing.T) {
	t.Parallel()

	// Policy: list allowed under "resource/", artist keys explicitly denied.
	authorizer := AuthorizerFrom(t, `
		key_prefix "resource/" { policy = "list" }
		key_prefix "resource/demo.v2.Artist/" { policy = "deny" }
		`)
	events, _ := roundTripACL(t, authorizer)

	// The artist is filtered out because key:read is denied, so nothing arrives.
	mustGetNoResource(t, events)
}

// TestWatchList_ACL_ListAllowed_ReadAllowed checks that when both listing and
// reading are permitted, the written artist is delivered on the watch.
//
// N.B. Uses key ACLs for now. See demo.RegisterTypes()
func TestWatchList_ACL_ListAllowed_ReadAllowed(t *testing.T) {
	t.Parallel()

	// Policy: list allowed under "resource/", artist keys readable.
	authorizer := AuthorizerFrom(t, `
		key_prefix "resource/" { policy = "list" }
		key_prefix "resource/demo.v2.Artist/" { policy = "read" }
	`)
	events, written := roundTripACL(t, authorizer)

	// The ACL must not filter the artist out of the stream.
	got := mustGetResource(t, events)
	prototest.AssertDeepEqual(t, written, got.Resource)
}

// roundTripACL starts a test server whose ACL resolver is mocked to return
// authz, opens a WatchList stream for the V2 artist type, and then writes one
// artist to the backend to induce a single watch event.
//
// It returns the channel fed by the stream and the artist as stored by the
// backend; callers make their own assertions on both.
func roundTripACL(t *testing.T, authz acl.Authorizer) (<-chan resourceOrError, *pbresource.Resource) {
	server := testServer(t)
	client := testClient(t, server)

	// Every token resolution on this server yields the supplied authorizer.
	resolver := &MockACLResolver{}
	resolver.On("ResolveTokenAndDefaultMeta", mock.Anything, mock.Anything, mock.Anything).
		Return(authz, nil)
	server.ACLResolver = resolver
	demo.RegisterTypes(server.Registry)

	artist, err := demo.GenerateV2Artist()
	require.NoError(t, err)

	// Watch exactly the type and tenancy of the artist about to be written.
	req := &pbresource.WatchListRequest{
		Type:       artist.Id.Type,
		Tenancy:    artist.Id.Tenancy,
		NamePrefix: "",
	}
	stream, err := client.WatchList(testContext(t), req)
	require.NoError(t, err)
	events := handleResourceStream(t, stream)

	// The write that should (ACLs permitting) show up on the watch.
	written, err := server.Backend.WriteCAS(context.Background(), artist)
	require.NoError(t, err)

	return events, written
}

// mustGetNoResource asserts that no watch event shows up on ch within 250ms.
// If something is received in that window, it must carry neither an error nor
// an event, otherwise the test fails.
func mustGetNoResource(t *testing.T, ch <-chan resourceOrError) {
	t.Helper()

	quiet := time.After(250 * time.Millisecond)
	select {
	case <-quiet:
		// Nothing arrived during the quiet period, which is the expected outcome.
	case got := <-ch:
		require.NoError(t, got.err)
		require.Nil(t, got.rsp, "expected nil response with no error")
	}
}

// mustGetResource waits up to one second for the next item on ch and returns
// its watch event. The test fails if the item carries an error or if nothing
// arrives before the timeout.
func mustGetResource(t *testing.T, ch <-chan resourceOrError) *pbresource.WatchEvent {
	t.Helper()

	var event *pbresource.WatchEvent
	select {
	case got := <-ch:
		require.NoError(t, got.err)
		event = got.rsp
	case <-time.After(1 * time.Second):
		t.Fatal("timeout waiting for WatchListResponse")
	}
	return event
}

// mustGetError waits up to two seconds for the next item on ch and returns
// its error. The test fails if the item carries no error or if nothing
// arrives before the timeout.
func mustGetError(t *testing.T, ch <-chan resourceOrError) error {
	t.Helper()

	var streamErr error
	select {
	case got := <-ch:
		require.Error(t, got.err)
		streamErr = got.err
	case <-time.After(2 * time.Second):
		t.Fatal("timeout waiting for WatchListResponse")
	}
	return streamErr
}

// handleResourceStream starts a goroutine that repeatedly calls stream.Recv
// and forwards each result (event or error) on the returned channel.
//
// The goroutine stops when:
//   - Recv reports io.EOF, context.Canceled or context.DeadlineExceeded
//     (these are not forwarded),
//   - a non-nil error has been forwarded (a Recv error ends the stream, so
//     calling Recv again would be pointless), or
//   - the test has finished and nobody is left to receive.
//
// The returned channel is never closed; callers are expected to read from it
// with a timeout (see mustGetResource, mustGetError and mustGetNoResource).
func handleResourceStream(t *testing.T, stream pbresource.ResourceService_WatchListClient) <-chan resourceOrError {
	t.Helper()

	// done is closed once the test completes so the goroutine below can never
	// stay blocked forever on a send that no one will receive.
	done := make(chan struct{})
	t.Cleanup(func() { close(done) })

	rspCh := make(chan resourceOrError)
	go func() {
		for {
			rsp, err := stream.Recv()
			if errors.Is(err, io.EOF) ||
				errors.Is(err, context.Canceled) ||
				errors.Is(err, context.DeadlineExceeded) {
				return
			}

			select {
			case rspCh <- resourceOrError{rsp: rsp, err: err}:
			case <-done:
				return
			}

			// Any error from Recv is terminal for the stream; stop instead of
			// looping on a dead stream.
			if err != nil {
				return
			}
		}
	}()
	return rspCh
}

// resourceOrError is the item type carried on the channel returned by
// handleResourceStream: it pairs the two results of a single stream.Recv call.
type resourceOrError struct {
	// rsp is the watch event returned by Recv.
	rsp *pbresource.WatchEvent
	// err is the error returned by Recv.
	err error
}