package connectca

import (
	"context"
	"errors"
	"io"
	"testing"
	"time"

	"github.com/stretchr/testify/mock"
	"github.com/stretchr/testify/require"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
	"google.golang.org/protobuf/types/known/emptypb"

	"github.com/hashicorp/go-uuid"

	"github.com/hashicorp/consul/acl"
	"github.com/hashicorp/consul/agent/connect"
	"github.com/hashicorp/consul/agent/grpc/public"
	"github.com/hashicorp/consul/agent/grpc/public/testutils"
	"github.com/hashicorp/consul/agent/structs"
	"github.com/hashicorp/consul/proto-public/pbconnectca"
	"github.com/hashicorp/consul/sdk/testutil"
)

// testACLToken is the token placed on the request context (via
// public.ContextWithToken) and expected by the mocked ACL resolver.
const testACLToken = "acl-token"

// TestWatchRoots_ConnectDisabled checks that WatchRoots fails with
// FailedPrecondition, mentioning "Connect", when the server is configured with
// ConnectEnabled: false.
func TestWatchRoots_ConnectDisabled(t *testing.T) {
	server := NewServer(Config{ConnectEnabled: false})

	// Begin the stream.
	client := testClient(t, server)
	stream, err := client.WatchRoots(context.Background(), &emptypb.Empty{})
	require.NoError(t, err)
	rspCh := handleRootsStream(t, stream)

	err = mustGetError(t, rspCh)
	require.Equal(t, codes.FailedPrecondition.String(), status.Code(err).String())
	require.Contains(t, status.Convert(err).Message(), "Connect")
}

// TestWatchRoots_Success checks that the stream first delivers the current
// roots and then delivers a new message after the roots are rotated.
func TestWatchRoots_Success(t *testing.T) {
	fsm, publisher := setupFSMAndPublisher(t)

	// Set the initial roots and CA configuration.
	rootA := connect.TestCA(t, nil)
	_, err := fsm.GetStore().CARootSetCAS(1, 0, structs.CARoots{rootA})
	require.NoError(t, err)

	err = fsm.GetStore().CASetConfig(0, &structs.CAConfiguration{ClusterID: "cluster-id"})
	require.NoError(t, err)

	// Mock the ACL Resolver to return an authorizer with `service:write`.
	aclResolver := &MockACLResolver{}
	aclResolver.On("ResolveTokenAndDefaultMeta", testACLToken, mock.Anything, mock.Anything).
		Return(testutils.TestAuthorizer(t), nil)

	ctx := public.ContextWithToken(context.Background(), testACLToken)

	server := NewServer(Config{
		Publisher:      publisher,
		GetStore:       func() StateStore { return fsm.GetStore() },
		Logger:         testutil.Logger(t),
		ACLResolver:    aclResolver,
		ConnectEnabled: true,
	})

	// Begin the stream.
	client := testClient(t, server)
	stream, err := client.WatchRoots(ctx, &emptypb.Empty{})
	require.NoError(t, err)
	rspCh := handleRootsStream(t, stream)

	// Expect an initial message containing current roots (provided by the snapshot).
	roots := mustGetRoots(t, rspCh)
	require.Equal(t, "cluster-id.consul", roots.TrustDomain)
	require.Equal(t, rootA.ID, roots.ActiveRootId)
	require.Len(t, roots.Roots, 1)
	require.Equal(t, rootA.ID, roots.Roots[0].Id)

	// Rotate the roots.
	rootB := connect.TestCA(t, nil)
	_, err = fsm.GetStore().CARootSetCAS(2, 1, structs.CARoots{rootB})
	require.NoError(t, err)

	// Expect another event containing the new roots.
	roots = mustGetRoots(t, rspCh)
	require.Equal(t, "cluster-id.consul", roots.TrustDomain)
	require.Equal(t, rootB.ID, roots.ActiveRootId)
	require.Len(t, roots.Roots, 1)
	require.Equal(t, rootB.ID, roots.Roots[0].Id)
}

// TestWatchRoots_InvalidACLToken checks that the stream fails with
// Unauthenticated when the ACL resolver reports acl.ErrNotFound.
func TestWatchRoots_InvalidACLToken(t *testing.T) {
	fsm, publisher := setupFSMAndPublisher(t)

	// Set the initial CA configuration.
	err := fsm.GetStore().CASetConfig(0, &structs.CAConfiguration{ClusterID: "cluster-id"})
	require.NoError(t, err)

	// Mock the ACL resolver to return ErrNotFound.
	aclResolver := &MockACLResolver{}
	aclResolver.On("ResolveTokenAndDefaultMeta", mock.Anything, mock.Anything, mock.Anything).
		Return(nil, acl.ErrNotFound)

	ctx := public.ContextWithToken(context.Background(), testACLToken)

	server := NewServer(Config{
		Publisher:      publisher,
		GetStore:       func() StateStore { return fsm.GetStore() },
		Logger:         testutil.Logger(t),
		ACLResolver:    aclResolver,
		ConnectEnabled: true,
	})

	// Start the stream.
	client := testClient(t, server)
	stream, err := client.WatchRoots(ctx, &emptypb.Empty{})
	require.NoError(t, err)
	rspCh := handleRootsStream(t, stream)

	// Expect to get an Unauthenticated error immediately.
	err = mustGetError(t, rspCh)
	require.Equal(t, codes.Unauthenticated.String(), status.Code(err).String())
}

// TestWatchRoots_ACLTokenInvalidated checks two cases after the token is
// updated in the store. The stream stays open while the resolver still grants
// `service:write`. It fails with PermissionDenied once the resolver returns
// acl.DenyAll().
func TestWatchRoots_ACLTokenInvalidated(t *testing.T) {
	fsm, publisher := setupFSMAndPublisher(t)

	// Set the initial roots and CA configuration.
	rootA := connect.TestCA(t, nil)
	_, err := fsm.GetStore().CARootSetCAS(1, 0, structs.CARoots{rootA})
	require.NoError(t, err)

	err = fsm.GetStore().CASetConfig(2, &structs.CAConfiguration{ClusterID: "cluster-id"})
	require.NoError(t, err)

	// Mock the ACL Resolver to return an authorizer with `service:write` the
	// first two times it is called (initial connect and first re-auth).
	aclResolver := &MockACLResolver{}
	aclResolver.On("ResolveTokenAndDefaultMeta", testACLToken, mock.Anything, mock.Anything).
		Return(testutils.TestAuthorizer(t), nil).Twice()

	ctx := public.ContextWithToken(context.Background(), testACLToken)

	server := NewServer(Config{
		Publisher:      publisher,
		GetStore:       func() StateStore { return fsm.GetStore() },
		Logger:         testutil.Logger(t),
		ACLResolver:    aclResolver,
		ConnectEnabled: true,
	})

	// Start the stream.
	client := testClient(t, server)
	stream, err := client.WatchRoots(ctx, &emptypb.Empty{})
	require.NoError(t, err)
	rspCh := handleRootsStream(t, stream)

	// Consume the initial response.
	mustGetRoots(t, rspCh)

	// Update the ACL token to cause the subscription to be force-closed.
	accessorID, err := uuid.GenerateUUID()
	require.NoError(t, err)
	err = fsm.GetStore().ACLTokenSet(1, &structs.ACLToken{
		AccessorID: accessorID,
		SecretID:   testACLToken,
	})
	require.NoError(t, err)

	// Update the roots.
	rootB := connect.TestCA(t, nil)
	_, err = fsm.GetStore().CARootSetCAS(3, 1, structs.CARoots{rootB})
	require.NoError(t, err)

	// Expect the stream to remain open and to receive the new roots.
	mustGetRoots(t, rspCh)

	// Simulate removing the `service:write` permission.
	aclResolver.On("ResolveTokenAndDefaultMeta", testACLToken, mock.Anything, mock.Anything).
		Return(acl.DenyAll(), nil)

	// Update the ACL token to cause the subscription to be force-closed.
	err = fsm.GetStore().ACLTokenSet(1, &structs.ACLToken{
		AccessorID: accessorID,
		SecretID:   testACLToken,
	})
	require.NoError(t, err)

	// Expect the stream to be terminated.
	err = mustGetError(t, rspCh)
	require.Equal(t, codes.PermissionDenied.String(), status.Code(err).String())
}

// TestWatchRoots_StateStoreAbandoned checks that after fsm.ReplaceStore swaps
// in a new state store (simulating a snapshot restore), the stream delivers
// the new store's roots and trust domain.
func TestWatchRoots_StateStoreAbandoned(t *testing.T) {
	fsm, publisher := setupFSMAndPublisher(t)

	// Set the initial roots and CA configuration.
	rootA := connect.TestCA(t, nil)
	_, err := fsm.GetStore().CARootSetCAS(1, 0, structs.CARoots{rootA})
	require.NoError(t, err)

	err = fsm.GetStore().CASetConfig(0, &structs.CAConfiguration{ClusterID: "cluster-a"})
	require.NoError(t, err)

	// Mock the ACL Resolver to return an authorizer with `service:write`.
	aclResolver := &MockACLResolver{}
	aclResolver.On("ResolveTokenAndDefaultMeta", testACLToken, mock.Anything, mock.Anything).
		Return(testutils.TestAuthorizer(t), nil)

	ctx := public.ContextWithToken(context.Background(), testACLToken)

	server := NewServer(Config{
		Publisher:      publisher,
		GetStore:       func() StateStore { return fsm.GetStore() },
		Logger:         testutil.Logger(t),
		ACLResolver:    aclResolver,
		ConnectEnabled: true,
	})

	// Begin the stream.
	client := testClient(t, server)
	stream, err := client.WatchRoots(ctx, &emptypb.Empty{})
	require.NoError(t, err)
	rspCh := handleRootsStream(t, stream)

	// Consume the initial roots.
	mustGetRoots(t, rspCh)

	// Simulate a snapshot restore.
	storeB := testStateStore(t, publisher)

	rootB := connect.TestCA(t, nil)
	_, err = storeB.CARootSetCAS(1, 0, structs.CARoots{rootB})
	require.NoError(t, err)

	err = storeB.CASetConfig(0, &structs.CAConfiguration{ClusterID: "cluster-b"})
	require.NoError(t, err)

	fsm.ReplaceStore(storeB)

	// Expect to get the new store's roots.
	newRoots := mustGetRoots(t, rspCh)
	require.Equal(t, "cluster-b.consul", newRoots.TrustDomain)
	require.Len(t, newRoots.Roots, 1)
	require.Equal(t, rootB.ID, newRoots.ActiveRootId)
}

// mustGetRoots waits up to one second for the next item on ch. It returns the
// response, and fails the test if the item carries an error, if ch was closed
// (the stream ended), or on timeout.
func mustGetRoots(t *testing.T, ch <-chan rootsOrError) *pbconnectca.WatchRootsResponse {
	t.Helper()

	select {
	case rsp, ok := <-ch:
		// A closed channel yields a zero value; without this check callers
		// would dereference a nil response.
		require.True(t, ok, "stream closed while waiting for WatchRootsResponse")
		require.NoError(t, rsp.err)
		return rsp.rsp
	case <-time.After(1 * time.Second):
		t.Fatal("timeout waiting for WatchRootsResponse")
		return nil
	}
}

// mustGetError waits up to one second for the next item on ch. It returns the
// item's error, and fails the test if the item carries no error, if ch was
// closed (the stream ended), or on timeout.
func mustGetError(t *testing.T, ch <-chan rootsOrError) error {
	t.Helper()

	select {
	case rsp, ok := <-ch:
		require.True(t, ok, "stream closed while waiting for WatchRootsResponse")
		require.Error(t, rsp.err)
		return rsp.err
	case <-time.After(1 * time.Second):
		t.Fatal("timeout waiting for WatchRootsResponse")
		return nil
	}
}

// handleRootsStream starts a goroutine that forwards every stream.Recv result
// to the returned channel.
//
// The goroutine exits in the following cases:
//   - Recv returns an error matching io.EOF, context.Canceled or
//     context.DeadlineExceeded (nothing is sent for these).
//   - It has forwarded any other non-nil error. A gRPC client stream is
//     aborted after such an error, so reading further is pointless.
//   - The test finishes. Blocked sends are abandoned via a t.Cleanup hook, so
//     the goroutine never leaks waiting for a reader that is gone.
//
// The returned channel is closed when the goroutine exits.
func handleRootsStream(t *testing.T, stream pbconnectca.ConnectCAService_WatchRootsClient) <-chan rootsOrError {
	t.Helper()

	// done is closed when the test (and its subtests) complete, which unblocks
	// any pending send below.
	done := make(chan struct{})
	t.Cleanup(func() { close(done) })

	rspCh := make(chan rootsOrError)
	go func() {
		defer close(rspCh)

		for {
			rsp, err := stream.Recv()
			if errors.Is(err, io.EOF) ||
				errors.Is(err, context.Canceled) ||
				errors.Is(err, context.DeadlineExceeded) {
				return
			}

			select {
			case rspCh <- rootsOrError{rsp: rsp, err: err}:
			case <-done:
				return
			}

			// Any other non-nil error terminates the stream; stop reading
			// rather than looping on Recv and blocking on a send nobody reads.
			if err != nil {
				return
			}
		}
	}()
	return rspCh
}

// rootsOrError carries a single stream.Recv result (response or error) from
// the goroutine started by handleRootsStream to the test.
type rootsOrError struct {
	rsp *pbconnectca.WatchRootsResponse
	err error
}