diff --git a/agent/consul/server.go b/agent/consul/server.go index 2ef6aee71d..cd31ca2807 100644 --- a/agent/consul/server.go +++ b/agent/consul/server.go @@ -982,7 +982,7 @@ func (s *Server) registerControllers(deps Deps, proxyUpdater ProxyUpdater) error } if s.useV2Resources { - catalog.RegisterControllers(s.controllerManager, catalog.DefaultControllerDependencies()) + catalog.RegisterControllers(s.controllerManager) multicluster.RegisterControllers(s.controllerManager, multicluster.DefaultControllerDependencies()) defaultAllow, err := s.config.ACLResolverSettings.IsDefaultAllow() if err != nil { diff --git a/internal/catalog/catalogtest/run_test.go b/internal/catalog/catalogtest/run_test.go index c5fa7ad39e..e2a4718aa0 100644 --- a/internal/catalog/catalogtest/run_test.go +++ b/internal/catalog/catalogtest/run_test.go @@ -7,8 +7,6 @@ import ( "testing" "github.com/hashicorp/consul/internal/catalog" - "github.com/hashicorp/consul/internal/catalog/internal/controllers" - "github.com/hashicorp/consul/internal/controller" "github.com/hashicorp/consul/internal/controller/controllertest" "github.com/hashicorp/consul/internal/resource/reaper" rtest "github.com/hashicorp/consul/internal/resource/resourcetest" @@ -19,25 +17,23 @@ var ( clientOpts = rtest.ConfigureTestCLIFlags() ) -func runInMemResourceServiceAndControllers(t *testing.T, deps controllers.Dependencies) pbresource.ResourceServiceClient { +func runInMemResourceServiceAndControllers(t *testing.T) pbresource.ResourceServiceClient { t.Helper() return controllertest.NewControllerTestBuilder(). WithResourceRegisterFns(catalog.RegisterTypes). WithControllerRegisterFns( reaper.RegisterControllers, - func(mgr *controller.Manager) { - catalog.RegisterControllers(mgr, deps) - }, + catalog.RegisterControllers, ).Run(t) } func TestControllers_Integration(t *testing.T) { - client := runInMemResourceServiceAndControllers(t, catalog.DefaultControllerDependencies()) + client := runInMemResourceServiceAndControllers(t) RunCatalogV2Beta1IntegrationTest(t, client, clientOpts.ClientOptions(t)...) } func TestControllers_Lifecycle(t *testing.T) { - client := runInMemResourceServiceAndControllers(t, catalog.DefaultControllerDependencies()) + client := runInMemResourceServiceAndControllers(t) RunCatalogV2Beta1LifecycleIntegrationTest(t, client, clientOpts.ClientOptions(t)...) } diff --git a/internal/catalog/exports.go b/internal/catalog/exports.go index 546fe30e2d..d864d7acb1 100644 --- a/internal/catalog/exports.go +++ b/internal/catalog/exports.go @@ -9,7 +9,6 @@ import ( "github.com/hashicorp/consul/internal/catalog/internal/controllers/failover" "github.com/hashicorp/consul/internal/catalog/internal/controllers/nodehealth" "github.com/hashicorp/consul/internal/catalog/internal/controllers/workloadhealth" - "github.com/hashicorp/consul/internal/catalog/internal/mappers/failovermapper" "github.com/hashicorp/consul/internal/catalog/internal/types" "github.com/hashicorp/consul/internal/controller" "github.com/hashicorp/consul/internal/resource" @@ -36,7 +35,7 @@ var ( StatusReasonWorkloadIdentitiesFound = endpoints.StatusReasonWorkloadIdentitiesFound StatusReasonNoWorkloadIdentitiesFound = endpoints.StatusReasonNoWorkloadIdentitiesFound - FailoverStatusKey = failover.StatusKey + FailoverStatusKey = failover.ControllerID FailoverStatusConditionAccepted = failover.StatusConditionAccepted FailoverStatusConditionAcceptedOKReason = failover.OKReason FailoverStatusConditionAcceptedMissingServiceReason = failover.MissingServiceReason @@ -52,18 +51,10 @@ func RegisterTypes(r resource.Registry) { types.Register(r) } -type ControllerDependencies = controllers.Dependencies - -func DefaultControllerDependencies() ControllerDependencies { - return ControllerDependencies{ - FailoverMapper: failovermapper.New(), - } -} - // RegisterControllers registers controllers for the catalog types with // the given controller Manager. -func RegisterControllers(mgr *controller.Manager, deps ControllerDependencies) { - controllers.Register(mgr, deps) +func RegisterControllers(mgr *controller.Manager) { + controllers.Register(mgr) } // SimplifyFailoverPolicy fully populates the PortConfigs map and clears the @@ -72,18 +63,6 @@ func SimplifyFailoverPolicy(svc *pbcatalog.Service, failover *pbcatalog.Failover return types.SimplifyFailoverPolicy(svc, failover) } -// FailoverPolicyMapper maintains the bidirectional tracking relationship of a -// FailoverPolicy to the Services related to it. -type FailoverPolicyMapper interface { - TrackFailover(failover *resource.DecodedResource[*pbcatalog.FailoverPolicy]) - UntrackFailover(failoverID *pbresource.ID) - FailoverIDsByService(svcID *pbresource.ID) []*pbresource.ID -} - -func NewFailoverPolicyMapper() FailoverPolicyMapper { - return failovermapper.New() -} - // ValidateLocalServiceRefNoSection ensures the following: // // - ref is non-nil diff --git a/internal/catalog/internal/controllers/failover/controller.go b/internal/catalog/internal/controllers/failover/controller.go index 4dcf4f432c..a8fdcc5d07 100644 --- a/internal/catalog/internal/controllers/failover/controller.go +++ b/internal/catalog/internal/controllers/failover/controller.go @@ -8,72 +8,71 @@ import ( "github.com/hashicorp/consul/internal/catalog/internal/types" "github.com/hashicorp/consul/internal/controller" + "github.com/hashicorp/consul/internal/controller/cache" + "github.com/hashicorp/consul/internal/controller/cache/indexers" + "github.com/hashicorp/consul/internal/controller/dependency" "github.com/hashicorp/consul/internal/resource" pbcatalog "github.com/hashicorp/consul/proto-public/pbcatalog/v2beta1" "github.com/hashicorp/consul/proto-public/pbresource" ) -// FailoverMapper tracks the relationship between a FailoverPolicy an a Service -// it references whether due to name-alignment or from a reference in a -// FailoverDestination leg. -type FailoverMapper interface { - // TrackFailover extracts all Service references from the provided - // FailoverPolicy and indexes them so that MapService can turn Service - // events into FailoverPolicy events properly. - TrackFailover(failover *resource.DecodedResource[*pbcatalog.FailoverPolicy]) +const ( + destRefsIndexName = "destination-refs" +) - // UntrackFailover forgets the links inserted by TrackFailover for the - // provided FailoverPolicyID. - UntrackFailover(failoverID *pbresource.ID) - - // MapService will take a Service resource and return controller requests - // for all FailoverPolicies associated with the Service. - MapService(ctx context.Context, rt controller.Runtime, res *pbresource.Resource) ([]controller.Request, error) +func FailoverPolicyController() *controller.Controller { + return controller.NewController( + ControllerID, + pbcatalog.FailoverPolicyType, + // We index the destination references of a failover policy so that when the + // Service watch fires we can find all FailoverPolicy resources that reference + // it to rereconcile them. + indexers.RefOrIDIndex( + destRefsIndexName, + func(res *resource.DecodedResource[*pbcatalog.FailoverPolicy]) []*pbresource.Reference { + return res.Data.GetUnderlyingDestinationRefs() + }, + )). + WithWatch( + pbcatalog.ServiceType, + dependency.MultiMapper( + // FailoverPolicy is name-aligned with the Service it controls so always + // re-reconcile the corresponding FailoverPolicy when a Service changes. + dependency.ReplaceType(pbcatalog.FailoverPolicyType), + // Also check for all FailoverPolicy resources that have this service as a + // destination and re-reconcile those to check for port mapping conflicts. + dependency.CacheListMapper(pbcatalog.FailoverPolicyType, destRefsIndexName), + ), + ). + WithReconciler(newFailoverPolicyReconciler()) } -func FailoverPolicyController(mapper FailoverMapper) *controller.Controller { - if mapper == nil { - panic("No FailoverMapper was provided to the FailoverPolicyController constructor") - } - return controller.NewController(StatusKey, pbcatalog.FailoverPolicyType). - WithWatch(pbcatalog.ServiceType, mapper.MapService). - WithReconciler(newFailoverPolicyReconciler(mapper)) -} +type failoverPolicyReconciler struct{} -type failoverPolicyReconciler struct { - mapper FailoverMapper -} - -func newFailoverPolicyReconciler(mapper FailoverMapper) *failoverPolicyReconciler { - return &failoverPolicyReconciler{ - mapper: mapper, - } +func newFailoverPolicyReconciler() *failoverPolicyReconciler { + return &failoverPolicyReconciler{} } func (r *failoverPolicyReconciler) Reconcile(ctx context.Context, rt controller.Runtime, req controller.Request) error { // The runtime is passed by value so replacing it here for the remainder of this // reconciliation request processing will not affect future invocations. - rt.Logger = rt.Logger.With("resource-id", req.ID, "controller", StatusKey) + rt.Logger = rt.Logger.With("resource-id", req.ID) rt.Logger.Trace("reconciling failover policy") failoverPolicyID := req.ID - failoverPolicy, err := getFailoverPolicy(ctx, rt, failoverPolicyID) + failoverPolicy, err := cache.GetDecoded[*pbcatalog.FailoverPolicy](rt.Cache, pbcatalog.FailoverPolicyType, "id", failoverPolicyID) if err != nil { rt.Logger.Error("error retrieving failover policy", "error", err) return err } if failoverPolicy == nil { - r.mapper.UntrackFailover(failoverPolicyID) - // Either the failover policy was deleted, or it doesn't exist but an // update to a Service came through and we can ignore it. return nil } - r.mapper.TrackFailover(failoverPolicy) - // FailoverPolicy is name-aligned with the Service it controls. serviceID := &pbresource.ID{ Type: pbcatalog.ServiceType, @@ -81,7 +80,7 @@ func (r *failoverPolicyReconciler) Reconcile(ctx context.Context, rt controller. Name: failoverPolicyID.Name, } - service, err := getService(ctx, rt, serviceID) + service, err := cache.GetDecoded[*pbcatalog.Service](rt.Cache, pbcatalog.ServiceType, "id", serviceID) if err != nil { rt.Logger.Error("error retrieving corresponding service", "error", err) return err @@ -91,7 +90,7 @@ func (r *failoverPolicyReconciler) Reconcile(ctx context.Context, rt controller. destServices[resource.NewReferenceKey(serviceID)] = service } - // Denorm the ports and stuff. After this we have no empty ports. + // Denormalize the ports and stuff. After this we have no empty ports. if service != nil { failoverPolicy.Data = types.SimplifyFailoverPolicy( service.Data, @@ -113,7 +112,7 @@ func (r *failoverPolicyReconciler) Reconcile(ctx context.Context, rt controller. destID := resource.IDFromReference(dest.Ref) - destService, err := getService(ctx, rt, destID) + destService, err := cache.GetDecoded[*pbcatalog.Service](rt.Cache, pbcatalog.ServiceType, "id", destID) if err != nil { rt.Logger.Error("error retrieving destination service", "service", key, "error", err) return err @@ -126,7 +125,7 @@ func (r *failoverPolicyReconciler) Reconcile(ctx context.Context, rt controller. newStatus := computeNewStatus(failoverPolicy, service, destServices) - if resource.EqualStatus(failoverPolicy.Resource.Status[StatusKey], newStatus, false) { + if resource.EqualStatus(failoverPolicy.Resource.Status[ControllerID], newStatus, false) { rt.Logger.Trace("resource's failover policy status is unchanged", "conditions", newStatus.Conditions) return nil @@ -134,7 +133,7 @@ func (r *failoverPolicyReconciler) Reconcile(ctx context.Context, rt controller. _, err = rt.Client.WriteStatus(ctx, &pbresource.WriteStatusRequest{ Id: failoverPolicy.Resource.Id, - Key: StatusKey, + Key: ControllerID, Status: newStatus, }) @@ -148,14 +147,6 @@ func (r *failoverPolicyReconciler) Reconcile(ctx context.Context, rt controller. return nil } -func getFailoverPolicy(ctx context.Context, rt controller.Runtime, id *pbresource.ID) (*resource.DecodedResource[*pbcatalog.FailoverPolicy], error) { - return resource.GetDecodedResource[*pbcatalog.FailoverPolicy](ctx, rt.Client, id) -} - -func getService(ctx context.Context, rt controller.Runtime, id *pbresource.ID) (*resource.DecodedResource[*pbcatalog.Service], error) { - return resource.GetDecodedResource[*pbcatalog.Service](ctx, rt.Client, id) -} - func computeNewStatus( failoverPolicy *resource.DecodedResource[*pbcatalog.FailoverPolicy], service *resource.DecodedResource[*pbcatalog.Service], diff --git a/internal/catalog/internal/controllers/failover/controller_test.go b/internal/catalog/internal/controllers/failover/controller_test.go index 92eca80af9..7f3f83e96d 100644 --- a/internal/catalog/internal/controllers/failover/controller_test.go +++ b/internal/catalog/internal/controllers/failover/controller_test.go @@ -4,336 +4,302 @@ package failover import ( - "context" "fmt" "testing" - "github.com/stretchr/testify/suite" - - svctest "github.com/hashicorp/consul/agent/grpc-external/services/resource/testing" - "github.com/hashicorp/consul/internal/catalog/internal/mappers/failovermapper" "github.com/hashicorp/consul/internal/catalog/internal/types" "github.com/hashicorp/consul/internal/controller" + "github.com/hashicorp/consul/internal/controller/controllertest" "github.com/hashicorp/consul/internal/resource" + "github.com/hashicorp/consul/internal/resource/resourcetest" rtest "github.com/hashicorp/consul/internal/resource/resourcetest" pbcatalog "github.com/hashicorp/consul/proto-public/pbcatalog/v2beta1" "github.com/hashicorp/consul/proto-public/pbresource" - "github.com/hashicorp/consul/sdk/testutil" ) -type controllerSuite struct { - suite.Suite - - ctx context.Context - client *rtest.Client - rt controller.Runtime - - failoverMapper FailoverMapper - - ctl failoverPolicyReconciler - - tenancies []*pbresource.Tenancy -} - -func (suite *controllerSuite) SetupTest() { - suite.tenancies = rtest.TestTenancies() - client := svctest.NewResourceServiceBuilder(). - WithRegisterFns(types.Register). - WithTenancies(suite.tenancies...). - Run(suite.T()) - - suite.rt = controller.Runtime{ - Client: client, - Logger: testutil.Logger(suite.T()), - } - suite.client = rtest.NewClient(client) - suite.failoverMapper = failovermapper.New() - suite.ctx = testutil.TestContext(suite.T()) -} - -func (suite *controllerSuite) TestController() { +func TestController(t *testing.T) { // This test's purpose is to exercise the controller in a halfway realistic // way, verifying the event triggers work in the live code. - // Run the controller manager - mgr := controller.NewManager(suite.client, suite.rt.Logger) - mgr.Register(FailoverPolicyController(suite.failoverMapper)) - mgr.SetRaftLeader(true) - go mgr.Run(suite.ctx) + clientRaw := controllertest.NewControllerTestBuilder(). + WithTenancies(resourcetest.TestTenancies()...). + WithResourceRegisterFns(types.Register). + WithControllerRegisterFns(func(mgr *controller.Manager) { + mgr.Register(FailoverPolicyController()) + }). + Run(t) - suite.runTestCaseWithTenancies(func(tenancy *pbresource.Tenancy) { - // Create an advance pointer to some services. - apiServiceRef := resource.Reference(rtest.Resource(pbcatalog.ServiceType, "api").WithTenancy(tenancy).ID(), "") - otherServiceRef := resource.Reference(rtest.Resource(pbcatalog.ServiceType, "other").WithTenancy(tenancy).ID(), "") + client := rtest.NewClient(clientRaw) - // create a failover without any services - failoverData := &pbcatalog.FailoverPolicy{ - Config: &pbcatalog.FailoverConfig{ - Destinations: []*pbcatalog.FailoverDestination{{ - Ref: apiServiceRef, + for _, tenancy := range resourcetest.TestTenancies() { + t.Run(tenancySubTestName(tenancy), func(t *testing.T) { + tenancy := tenancy + + // Create an advance pointer to some services. + apiServiceRef := resource.Reference(rtest.Resource(pbcatalog.ServiceType, "api").WithTenancy(tenancy).ID(), "") + otherServiceRef := resource.Reference(rtest.Resource(pbcatalog.ServiceType, "other").WithTenancy(tenancy).ID(), "") + + // create a failover without any services + failoverData := &pbcatalog.FailoverPolicy{ + Config: &pbcatalog.FailoverConfig{ + Destinations: []*pbcatalog.FailoverDestination{{ + Ref: apiServiceRef, + }}, + }, + } + failover := rtest.Resource(pbcatalog.FailoverPolicyType, "api"). + WithData(t, failoverData). + WithTenancy(tenancy). + Write(t, client) + + t.Cleanup(func() { client.MustDelete(t, failover.Id) }) + + client.WaitForStatusCondition(t, failover.Id, ControllerID, ConditionMissingService) + t.Logf("reconciled to missing service status") + + // Provide the service. + apiServiceData := &pbcatalog.Service{ + Workloads: &pbcatalog.WorkloadSelector{Prefixes: []string{"api-"}}, + Ports: []*pbcatalog.ServicePort{{ + TargetPort: "http", + Protocol: pbcatalog.Protocol_PROTOCOL_HTTP, }}, - }, - } - failover := rtest.Resource(pbcatalog.FailoverPolicyType, "api"). - WithData(suite.T(), failoverData). - WithTenancy(tenancy). - Write(suite.T(), suite.client) + } + svc := rtest.Resource(pbcatalog.ServiceType, "api"). + WithData(t, apiServiceData). + WithTenancy(tenancy). + Write(t, client) - suite.T().Cleanup(suite.deleteResourceFunc(failover.Id)) + t.Cleanup(func() { client.MustDelete(t, svc.Id) }) - suite.client.WaitForStatusCondition(suite.T(), failover.Id, StatusKey, ConditionMissingService) + client.WaitForStatusCondition(t, failover.Id, ControllerID, ConditionOK) + t.Logf("reconciled to accepted") - // Provide the service. - apiServiceData := &pbcatalog.Service{ - Workloads: &pbcatalog.WorkloadSelector{Prefixes: []string{"api-"}}, - Ports: []*pbcatalog.ServicePort{{ - TargetPort: "http", - Protocol: pbcatalog.Protocol_PROTOCOL_HTTP, - }}, - } - svc := rtest.Resource(pbcatalog.ServiceType, "api"). - WithData(suite.T(), apiServiceData). - WithTenancy(tenancy). - Write(suite.T(), suite.client) - - suite.T().Cleanup(suite.deleteResourceFunc(svc.Id)) - - suite.client.WaitForStatusCondition(suite.T(), failover.Id, StatusKey, ConditionOK) - - // Update the failover to reference an unknown port - failoverData = &pbcatalog.FailoverPolicy{ - PortConfigs: map[string]*pbcatalog.FailoverConfig{ - "http": { - Destinations: []*pbcatalog.FailoverDestination{{ - Ref: apiServiceRef, - Port: "http", - }}, + // Update the failover to reference an unknown port + failoverData = &pbcatalog.FailoverPolicy{ + PortConfigs: map[string]*pbcatalog.FailoverConfig{ + "http": { + Destinations: []*pbcatalog.FailoverDestination{{ + Ref: apiServiceRef, + Port: "http", + }}, + }, + "admin": { + Destinations: []*pbcatalog.FailoverDestination{{ + Ref: apiServiceRef, + Port: "admin", + }}, + }, }, - "admin": { - Destinations: []*pbcatalog.FailoverDestination{{ - Ref: apiServiceRef, - Port: "admin", - }}, + } + svc = rtest.Resource(pbcatalog.FailoverPolicyType, "api"). + WithData(t, failoverData). + WithTenancy(tenancy). + Write(t, client) + + t.Cleanup(func() { client.MustDelete(t, svc.Id) }) + + client.WaitForStatusCondition(t, failover.Id, ControllerID, ConditionUnknownPort("admin")) + t.Logf("reconciled to unknown admin port") + + // update the service to fix the stray reference, but point to a mesh port + apiServiceData = &pbcatalog.Service{ + Workloads: &pbcatalog.WorkloadSelector{Prefixes: []string{"api-"}}, + Ports: []*pbcatalog.ServicePort{ + { + TargetPort: "http", + Protocol: pbcatalog.Protocol_PROTOCOL_HTTP, + }, + { + TargetPort: "admin", + Protocol: pbcatalog.Protocol_PROTOCOL_MESH, + }, }, - }, - } - svc = rtest.Resource(pbcatalog.FailoverPolicyType, "api"). - WithData(suite.T(), failoverData). - WithTenancy(tenancy). - Write(suite.T(), suite.client) + } + svc = rtest.Resource(pbcatalog.ServiceType, "api"). + WithData(t, apiServiceData). + WithTenancy(tenancy). + Write(t, client) - suite.T().Cleanup(suite.deleteResourceFunc(svc.Id)) + t.Cleanup(func() { client.MustDelete(t, svc.Id) }) - suite.client.WaitForStatusCondition(suite.T(), failover.Id, StatusKey, ConditionUnknownPort("admin")) + client.WaitForStatusCondition(t, failover.Id, ControllerID, ConditionUsingMeshDestinationPort(apiServiceRef, "admin")) + t.Logf("reconciled to using mesh destination port") - // update the service to fix the stray reference, but point to a mesh port - apiServiceData = &pbcatalog.Service{ - Workloads: &pbcatalog.WorkloadSelector{Prefixes: []string{"api-"}}, - Ports: []*pbcatalog.ServicePort{ - { + // update the service to fix the stray reference to not be a mesh port + apiServiceData = &pbcatalog.Service{ + Workloads: &pbcatalog.WorkloadSelector{Prefixes: []string{"api-"}}, + Ports: []*pbcatalog.ServicePort{ + { + TargetPort: "http", + Protocol: pbcatalog.Protocol_PROTOCOL_HTTP, + }, + { + TargetPort: "admin", + Protocol: pbcatalog.Protocol_PROTOCOL_HTTP, + }, + }, + } + svc = rtest.Resource(pbcatalog.ServiceType, "api"). + WithData(t, apiServiceData). + WithTenancy(tenancy). + Write(t, client) + + t.Cleanup(func() { client.MustDelete(t, svc.Id) }) + + client.WaitForStatusCondition(t, failover.Id, ControllerID, ConditionOK) + t.Logf("reconciled to accepted") + + // change failover leg to point to missing service + failoverData = &pbcatalog.FailoverPolicy{ + PortConfigs: map[string]*pbcatalog.FailoverConfig{ + "http": { + Destinations: []*pbcatalog.FailoverDestination{{ + Ref: apiServiceRef, + Port: "http", + }}, + }, + "admin": { + Destinations: []*pbcatalog.FailoverDestination{{ + Ref: otherServiceRef, + Port: "admin", + }}, + }, + }, + } + svc = rtest.Resource(pbcatalog.FailoverPolicyType, "api"). + WithData(t, failoverData). + WithTenancy(tenancy). + Write(t, client) + + t.Cleanup(func() { client.MustDelete(t, svc.Id) }) + + client.WaitForStatusCondition(t, failover.Id, ControllerID, ConditionMissingDestinationService(otherServiceRef)) + t.Logf("reconciled to missing dest service: other") + + // Create the missing service, but forget the port. + otherServiceData := &pbcatalog.Service{ + Workloads: &pbcatalog.WorkloadSelector{Prefixes: []string{"other-"}}, + Ports: []*pbcatalog.ServicePort{{ TargetPort: "http", Protocol: pbcatalog.Protocol_PROTOCOL_HTTP, - }, - { - TargetPort: "admin", - Protocol: pbcatalog.Protocol_PROTOCOL_MESH, - }, - }, - } - svc = rtest.Resource(pbcatalog.ServiceType, "api"). - WithData(suite.T(), apiServiceData). - WithTenancy(tenancy). - Write(suite.T(), suite.client) - - suite.T().Cleanup(suite.deleteResourceFunc(svc.Id)) - - suite.client.WaitForStatusCondition(suite.T(), failover.Id, StatusKey, ConditionUsingMeshDestinationPort(apiServiceRef, "admin")) - - // update the service to fix the stray reference to not be a mesh port - apiServiceData = &pbcatalog.Service{ - Workloads: &pbcatalog.WorkloadSelector{Prefixes: []string{"api-"}}, - Ports: []*pbcatalog.ServicePort{ - { - TargetPort: "http", - Protocol: pbcatalog.Protocol_PROTOCOL_HTTP, - }, - { - TargetPort: "admin", - Protocol: pbcatalog.Protocol_PROTOCOL_HTTP, - }, - }, - } - svc = rtest.Resource(pbcatalog.ServiceType, "api"). - WithData(suite.T(), apiServiceData). - WithTenancy(tenancy). - Write(suite.T(), suite.client) - - suite.T().Cleanup(suite.deleteResourceFunc(svc.Id)) - - suite.client.WaitForStatusCondition(suite.T(), failover.Id, StatusKey, ConditionOK) - - // change failover leg to point to missing service - failoverData = &pbcatalog.FailoverPolicy{ - PortConfigs: map[string]*pbcatalog.FailoverConfig{ - "http": { - Destinations: []*pbcatalog.FailoverDestination{{ - Ref: apiServiceRef, - Port: "http", - }}, - }, - "admin": { - Destinations: []*pbcatalog.FailoverDestination{{ - Ref: otherServiceRef, - Port: "admin", - }}, - }, - }, - } - svc = rtest.Resource(pbcatalog.FailoverPolicyType, "api"). - WithData(suite.T(), failoverData). - WithTenancy(tenancy). - Write(suite.T(), suite.client) - - suite.T().Cleanup(suite.deleteResourceFunc(svc.Id)) - - suite.client.WaitForStatusCondition(suite.T(), failover.Id, StatusKey, ConditionMissingDestinationService(otherServiceRef)) - - // Create the missing service, but forget the port. - otherServiceData := &pbcatalog.Service{ - Workloads: &pbcatalog.WorkloadSelector{Prefixes: []string{"other-"}}, - Ports: []*pbcatalog.ServicePort{{ - TargetPort: "http", - Protocol: pbcatalog.Protocol_PROTOCOL_HTTP, - }}, - } - svc = rtest.Resource(pbcatalog.ServiceType, "other"). - WithData(suite.T(), otherServiceData). - WithTenancy(tenancy). - Write(suite.T(), suite.client) - - suite.T().Cleanup(suite.deleteResourceFunc(svc.Id)) - - suite.client.WaitForStatusCondition(suite.T(), failover.Id, StatusKey, ConditionUnknownDestinationPort(otherServiceRef, "admin")) - - // fix the destination leg's port - otherServiceData = &pbcatalog.Service{ - Workloads: &pbcatalog.WorkloadSelector{Prefixes: []string{"other-"}}, - Ports: []*pbcatalog.ServicePort{ - { - TargetPort: "http", - Protocol: pbcatalog.Protocol_PROTOCOL_HTTP, - }, - { - TargetPort: "admin", - Protocol: pbcatalog.Protocol_PROTOCOL_HTTP, - }, - }, - } - svc = rtest.Resource(pbcatalog.ServiceType, "other"). - WithData(suite.T(), otherServiceData). - WithTenancy(tenancy). - Write(suite.T(), suite.client) - - suite.T().Cleanup(suite.deleteResourceFunc(svc.Id)) - - suite.client.WaitForStatusCondition(suite.T(), failover.Id, StatusKey, ConditionOK) - - // Update the two services to use differnet port names so the easy path doesn't work - apiServiceData = &pbcatalog.Service{ - Workloads: &pbcatalog.WorkloadSelector{Prefixes: []string{"api-"}}, - Ports: []*pbcatalog.ServicePort{ - { - TargetPort: "foo", - Protocol: pbcatalog.Protocol_PROTOCOL_HTTP, - }, - { - TargetPort: "bar", - Protocol: pbcatalog.Protocol_PROTOCOL_HTTP, - }, - }, - } - svc = rtest.Resource(pbcatalog.ServiceType, "api"). - WithData(suite.T(), apiServiceData). - WithTenancy(tenancy). - Write(suite.T(), suite.client) - - suite.T().Cleanup(suite.deleteResourceFunc(svc.Id)) - - otherServiceData = &pbcatalog.Service{ - Workloads: &pbcatalog.WorkloadSelector{Prefixes: []string{"other-"}}, - Ports: []*pbcatalog.ServicePort{ - { - TargetPort: "foo", - Protocol: pbcatalog.Protocol_PROTOCOL_HTTP, - }, - { - TargetPort: "baz", - Protocol: pbcatalog.Protocol_PROTOCOL_HTTP, - }, - }, - } - svc = rtest.Resource(pbcatalog.ServiceType, "other"). - WithData(suite.T(), otherServiceData). - WithTenancy(tenancy). - Write(suite.T(), suite.client) - - suite.T().Cleanup(suite.deleteResourceFunc(svc.Id)) - - failoverData = &pbcatalog.FailoverPolicy{ - Config: &pbcatalog.FailoverConfig{ - Destinations: []*pbcatalog.FailoverDestination{{ - Ref: otherServiceRef, }}, - }, - } - failover = rtest.Resource(pbcatalog.FailoverPolicyType, "api"). - WithData(suite.T(), failoverData). - WithTenancy(tenancy). - Write(suite.T(), suite.client) + } + svc = rtest.Resource(pbcatalog.ServiceType, "other"). + WithData(t, otherServiceData). + WithTenancy(tenancy). + Write(t, client) - suite.T().Cleanup(suite.deleteResourceFunc(failover.Id)) + t.Cleanup(func() { client.MustDelete(t, svc.Id) }) - suite.client.WaitForStatusCondition(suite.T(), failover.Id, StatusKey, ConditionUnknownDestinationPort(otherServiceRef, "bar")) + client.WaitForStatusCondition(t, failover.Id, ControllerID, ConditionUnknownDestinationPort(otherServiceRef, "admin")) + t.Logf("reconciled to missing dest port other:admin") - // and fix it the silly way by removing it from api+failover - apiServiceData = &pbcatalog.Service{ - Workloads: &pbcatalog.WorkloadSelector{Prefixes: []string{"api-"}}, - Ports: []*pbcatalog.ServicePort{ - { - TargetPort: "foo", - Protocol: pbcatalog.Protocol_PROTOCOL_HTTP, + // fix the destination leg's port + otherServiceData = &pbcatalog.Service{ + Workloads: &pbcatalog.WorkloadSelector{Prefixes: []string{"other-"}}, + Ports: []*pbcatalog.ServicePort{ + { + TargetPort: "http", + Protocol: pbcatalog.Protocol_PROTOCOL_HTTP, + }, + { + TargetPort: "admin", + Protocol: pbcatalog.Protocol_PROTOCOL_HTTP, + }, }, - }, - } - svc = rtest.Resource(pbcatalog.ServiceType, "api"). - WithData(suite.T(), apiServiceData). - WithTenancy(tenancy). - Write(suite.T(), suite.client) + } + svc = rtest.Resource(pbcatalog.ServiceType, "other"). + WithData(t, otherServiceData). + WithTenancy(tenancy). + Write(t, client) - suite.T().Cleanup(suite.deleteResourceFunc(svc.Id)) + t.Cleanup(func() { client.MustDelete(t, svc.Id) }) - suite.client.WaitForStatusCondition(suite.T(), failover.Id, StatusKey, ConditionOK) - }) -} + client.WaitForStatusCondition(t, failover.Id, ControllerID, ConditionOK) + t.Logf("reconciled to accepted") -func TestFailoverController(t *testing.T) { - suite.Run(t, new(controllerSuite)) -} + // Update the two services to use differnet port names so the easy path doesn't work + apiServiceData = &pbcatalog.Service{ + Workloads: &pbcatalog.WorkloadSelector{Prefixes: []string{"api-"}}, + Ports: []*pbcatalog.ServicePort{ + { + TargetPort: "foo", + Protocol: pbcatalog.Protocol_PROTOCOL_HTTP, + }, + { + TargetPort: "bar", + Protocol: pbcatalog.Protocol_PROTOCOL_HTTP, + }, + }, + } + svc = rtest.Resource(pbcatalog.ServiceType, "api"). + WithData(t, apiServiceData). + WithTenancy(tenancy). + Write(t, client) -func (suite *controllerSuite) runTestCaseWithTenancies(testCase func(tenancy *pbresource.Tenancy)) { - for _, tenancy := range suite.tenancies { - suite.Run(suite.appendTenancyInfo(tenancy), func() { - testCase(tenancy) + t.Cleanup(func() { client.MustDelete(t, svc.Id) }) + + otherServiceData = &pbcatalog.Service{ + Workloads: &pbcatalog.WorkloadSelector{Prefixes: []string{"other-"}}, + Ports: []*pbcatalog.ServicePort{ + { + TargetPort: "foo", + Protocol: pbcatalog.Protocol_PROTOCOL_HTTP, + }, + { + TargetPort: "baz", + Protocol: pbcatalog.Protocol_PROTOCOL_HTTP, + }, + }, + } + svc = rtest.Resource(pbcatalog.ServiceType, "other"). + WithData(t, otherServiceData). + WithTenancy(tenancy). + Write(t, client) + + t.Cleanup(func() { client.MustDelete(t, svc.Id) }) + + failoverData = &pbcatalog.FailoverPolicy{ + Config: &pbcatalog.FailoverConfig{ + Destinations: []*pbcatalog.FailoverDestination{{ + Ref: otherServiceRef, + }}, + }, + } + failover = rtest.Resource(pbcatalog.FailoverPolicyType, "api"). + WithData(t, failoverData). + WithTenancy(tenancy). + Write(t, client) + + t.Cleanup(func() { client.MustDelete(t, failover.Id) }) + + client.WaitForStatusCondition(t, failover.Id, ControllerID, ConditionUnknownDestinationPort(otherServiceRef, "bar")) + t.Logf("reconciled to missing dest port other:bar") + + // and fix it the silly way by removing it from api+failover + apiServiceData = &pbcatalog.Service{ + Workloads: &pbcatalog.WorkloadSelector{Prefixes: []string{"api-"}}, + Ports: []*pbcatalog.ServicePort{ + { + TargetPort: "foo", + Protocol: pbcatalog.Protocol_PROTOCOL_HTTP, + }, + }, + } + svc = rtest.Resource(pbcatalog.ServiceType, "api"). + WithData(t, apiServiceData). + WithTenancy(tenancy). + Write(t, client) + + t.Cleanup(func() { client.MustDelete(t, svc.Id) }) + + client.WaitForStatusCondition(t, failover.Id, ControllerID, ConditionOK) + t.Logf("reconciled to accepted") }) } } -func (suite *controllerSuite) appendTenancyInfo(tenancy *pbresource.Tenancy) string { +func tenancySubTestName(tenancy *pbresource.Tenancy) string { return fmt.Sprintf("%s_Namespace_%s_Partition", tenancy.Namespace, tenancy.Partition) } - -func (suite *controllerSuite) deleteResourceFunc(id *pbresource.ID) func() { - return func() { - suite.client.MustDelete(suite.T(), id) - } -} diff --git a/internal/catalog/internal/controllers/failover/status.go b/internal/catalog/internal/controllers/failover/status.go index b2801c41ed..d54918bc5a 100644 --- a/internal/catalog/internal/controllers/failover/status.go +++ b/internal/catalog/internal/controllers/failover/status.go @@ -9,7 +9,7 @@ import ( ) const ( - StatusKey = "consul.io/failover-policy" + ControllerID = "consul.io/failover-policy" StatusConditionAccepted = "accepted" OKReason = "Ok" diff --git a/internal/catalog/internal/controllers/register.go b/internal/catalog/internal/controllers/register.go index f3352e6de6..255c66a2a9 100644 --- a/internal/catalog/internal/controllers/register.go +++ b/internal/catalog/internal/controllers/register.go @@ -11,13 +11,9 @@ import ( "github.com/hashicorp/consul/internal/controller" ) -type Dependencies struct { - FailoverMapper failover.FailoverMapper -} - -func Register(mgr *controller.Manager, deps Dependencies) { +func Register(mgr *controller.Manager) { mgr.Register(nodehealth.NodeHealthController()) mgr.Register(workloadhealth.WorkloadHealthController()) mgr.Register(endpoints.ServiceEndpointsController()) - mgr.Register(failover.FailoverPolicyController(deps.FailoverMapper)) + mgr.Register(failover.FailoverPolicyController()) } diff --git a/internal/catalog/internal/mappers/failovermapper/failover_mapper.go b/internal/catalog/internal/mappers/failovermapper/failover_mapper.go deleted file mode 100644 index 7cd40b47b4..0000000000 --- a/internal/catalog/internal/mappers/failovermapper/failover_mapper.go +++ /dev/null @@ -1,63 +0,0 @@ -// Copyright (c) HashiCorp, Inc. -// SPDX-License-Identifier: BUSL-1.1 - -package failovermapper - -import ( - "context" - - "github.com/hashicorp/consul/internal/controller" - "github.com/hashicorp/consul/internal/resource" - "github.com/hashicorp/consul/internal/resource/mappers/bimapper" - pbcatalog "github.com/hashicorp/consul/proto-public/pbcatalog/v2beta1" - "github.com/hashicorp/consul/proto-public/pbresource" -) - -// Mapper tracks the relationship between a FailoverPolicy an a Service it -// references whether due to name-alignment or from a reference in a -// FailoverDestination leg. -type Mapper struct { - b *bimapper.Mapper -} - -// New creates a new Mapper. -func New() *Mapper { - return &Mapper{ - b: bimapper.New(pbcatalog.FailoverPolicyType, pbcatalog.ServiceType), - } -} - -// TrackFailover extracts all Service references from the provided -// FailoverPolicy and indexes them so that MapService can turn Service events -// into FailoverPolicy events properly. -func (m *Mapper) TrackFailover(failover *resource.DecodedResource[*pbcatalog.FailoverPolicy]) { - destRefs := failover.Data.GetUnderlyingDestinationRefs() - destRefs = append(destRefs, &pbresource.Reference{ - Type: pbcatalog.ServiceType, - Tenancy: failover.Resource.Id.Tenancy, - Name: failover.Resource.Id.Name, - }) - m.trackFailover(failover.Resource.Id, destRefs) -} - -func (m *Mapper) trackFailover(failover *pbresource.ID, services []*pbresource.Reference) { - var servicesAsIDsOrRefs []resource.ReferenceOrID - for _, s := range services { - servicesAsIDsOrRefs = append(servicesAsIDsOrRefs, s) - } - m.b.TrackItem(failover, servicesAsIDsOrRefs) -} - -// UntrackFailover forgets the links inserted by TrackFailover for the provided -// FailoverPolicyID. -func (m *Mapper) UntrackFailover(failoverID *pbresource.ID) { - m.b.UntrackItem(failoverID) -} - -func (m *Mapper) MapService(ctx context.Context, rt controller.Runtime, res *pbresource.Resource) ([]controller.Request, error) { - return m.b.MapLink(ctx, rt, res) -} - -func (m *Mapper) FailoverIDsByService(svcID *pbresource.ID) []*pbresource.ID { - return m.b.ItemsForLink(svcID) -} diff --git a/internal/catalog/internal/mappers/failovermapper/failover_mapper_test.go b/internal/catalog/internal/mappers/failovermapper/failover_mapper_test.go deleted file mode 100644 index 6149fddbf2..0000000000 --- a/internal/catalog/internal/mappers/failovermapper/failover_mapper_test.go +++ /dev/null @@ -1,192 +0,0 @@ -// Copyright (c) HashiCorp, Inc. -// SPDX-License-Identifier: BUSL-1.1 - -package failovermapper - -import ( - "context" - "testing" - - "github.com/stretchr/testify/require" - - "github.com/hashicorp/consul/internal/catalog/internal/types" - "github.com/hashicorp/consul/internal/controller" - "github.com/hashicorp/consul/internal/resource" - rtest "github.com/hashicorp/consul/internal/resource/resourcetest" - pbcatalog "github.com/hashicorp/consul/proto-public/pbcatalog/v2beta1" - "github.com/hashicorp/consul/proto-public/pbresource" - "github.com/hashicorp/consul/proto/private/prototest" -) - -func TestMapper_Tracking(t *testing.T) { - registry := resource.NewRegistry() - types.Register(registry) - - // Create an advance pointer to some services. - randoSvc := rtest.Resource(pbcatalog.ServiceType, "rando"). - WithTenancy(resource.DefaultNamespacedTenancy()). - WithData(t, &pbcatalog.Service{}). - Build() - rtest.ValidateAndNormalize(t, registry, randoSvc) - - apiSvc := rtest.Resource(pbcatalog.ServiceType, "api"). - WithTenancy(resource.DefaultNamespacedTenancy()). - WithData(t, &pbcatalog.Service{}). - Build() - rtest.ValidateAndNormalize(t, registry, apiSvc) - - fooSvc := rtest.Resource(pbcatalog.ServiceType, "foo"). - WithTenancy(resource.DefaultNamespacedTenancy()). - WithData(t, &pbcatalog.Service{}). - Build() - rtest.ValidateAndNormalize(t, registry, fooSvc) - - barSvc := rtest.Resource(pbcatalog.ServiceType, "bar"). - WithTenancy(resource.DefaultNamespacedTenancy()). - WithData(t, &pbcatalog.Service{}). - Build() - rtest.ValidateAndNormalize(t, registry, barSvc) - - wwwSvc := rtest.Resource(pbcatalog.ServiceType, "www"). - WithTenancy(resource.DefaultNamespacedTenancy()). - WithData(t, &pbcatalog.Service{}). - Build() - rtest.ValidateAndNormalize(t, registry, wwwSvc) - - fail1 := rtest.Resource(pbcatalog.FailoverPolicyType, "api"). - WithTenancy(resource.DefaultNamespacedTenancy()). - WithData(t, &pbcatalog.FailoverPolicy{ - Config: &pbcatalog.FailoverConfig{ - Destinations: []*pbcatalog.FailoverDestination{ - {Ref: newRef(pbcatalog.ServiceType, "foo")}, - {Ref: newRef(pbcatalog.ServiceType, "bar")}, - }, - }, - }). - Build() - rtest.ValidateAndNormalize(t, registry, fail1) - failDec1 := rtest.MustDecode[*pbcatalog.FailoverPolicy](t, fail1) - - fail2 := rtest.Resource(pbcatalog.FailoverPolicyType, "www"). - WithTenancy(resource.DefaultNamespacedTenancy()). - WithData(t, &pbcatalog.FailoverPolicy{ - Config: &pbcatalog.FailoverConfig{ - Destinations: []*pbcatalog.FailoverDestination{ - {Ref: newRef(pbcatalog.ServiceType, "www"), Datacenter: "dc2"}, - {Ref: newRef(pbcatalog.ServiceType, "foo")}, - }, - }, - }). - Build() - rtest.ValidateAndNormalize(t, registry, fail2) - failDec2 := rtest.MustDecode[*pbcatalog.FailoverPolicy](t, fail2) - - fail1_updated := rtest.Resource(pbcatalog.FailoverPolicyType, "api"). - WithTenancy(resource.DefaultNamespacedTenancy()). - WithData(t, &pbcatalog.FailoverPolicy{ - Config: &pbcatalog.FailoverConfig{ - Destinations: []*pbcatalog.FailoverDestination{ - {Ref: newRef(pbcatalog.ServiceType, "bar")}, - }, - }, - }). - Build() - rtest.ValidateAndNormalize(t, registry, fail1_updated) - failDec1_updated := rtest.MustDecode[*pbcatalog.FailoverPolicy](t, fail1_updated) - - m := New() - - // Nothing tracked yet so we assume nothing. - requireServicesTracked(t, m, randoSvc) - requireServicesTracked(t, m, apiSvc) - requireServicesTracked(t, m, fooSvc) - requireServicesTracked(t, m, barSvc) - requireServicesTracked(t, m, wwwSvc) - - // no-ops - m.UntrackFailover(fail1.Id) - - // still nothing - requireServicesTracked(t, m, randoSvc) - requireServicesTracked(t, m, apiSvc) - requireServicesTracked(t, m, fooSvc) - requireServicesTracked(t, m, barSvc) - requireServicesTracked(t, m, wwwSvc) - - // Actually insert some data. - m.TrackFailover(failDec1) - - requireServicesTracked(t, m, randoSvc) - requireServicesTracked(t, m, apiSvc, fail1.Id) - requireServicesTracked(t, m, fooSvc, fail1.Id) - requireServicesTracked(t, m, barSvc, fail1.Id) - requireServicesTracked(t, m, wwwSvc) - - // track it again, no change - m.TrackFailover(failDec1) - - requireServicesTracked(t, m, randoSvc) - requireServicesTracked(t, m, apiSvc, fail1.Id) - requireServicesTracked(t, m, fooSvc, fail1.Id) - requireServicesTracked(t, m, barSvc, fail1.Id) - requireServicesTracked(t, m, wwwSvc) - - // track new one that overlaps slightly - m.TrackFailover(failDec2) - - requireServicesTracked(t, m, randoSvc) - requireServicesTracked(t, m, apiSvc, fail1.Id) - requireServicesTracked(t, m, fooSvc, fail1.Id, fail2.Id) - requireServicesTracked(t, m, barSvc, fail1.Id) - requireServicesTracked(t, m, wwwSvc, fail2.Id) - - // update the original to change it - m.TrackFailover(failDec1_updated) - - requireServicesTracked(t, m, randoSvc) - requireServicesTracked(t, m, apiSvc, fail1.Id) - requireServicesTracked(t, m, fooSvc, fail2.Id) - requireServicesTracked(t, m, barSvc, fail1.Id) - requireServicesTracked(t, m, wwwSvc, fail2.Id) - - // delete the original - m.UntrackFailover(fail1.Id) - - requireServicesTracked(t, m, randoSvc) - requireServicesTracked(t, m, apiSvc) - requireServicesTracked(t, m, fooSvc, fail2.Id) - requireServicesTracked(t, m, barSvc) - requireServicesTracked(t, m, wwwSvc, fail2.Id) - - // delete the other one - m.UntrackFailover(fail2.Id) - - requireServicesTracked(t, m, randoSvc) - requireServicesTracked(t, m, apiSvc) - requireServicesTracked(t, m, fooSvc) - requireServicesTracked(t, m, barSvc) - requireServicesTracked(t, m, wwwSvc) -} - -func requireServicesTracked(t *testing.T, mapper *Mapper, svc *pbresource.Resource, failovers ...*pbresource.ID) { - t.Helper() - - reqs, err := mapper.MapService( - context.Background(), - controller.Runtime{}, - svc, - ) - require.NoError(t, err) - - require.Len(t, reqs, len(failovers)) - - for _, failover := range failovers { - prototest.AssertContainsElement(t, reqs, controller.Request{ID: failover}) - } -} - -func newRef(typ *pbresource.Type, name string) *pbresource.Reference { - return rtest.Resource(typ, name). - WithTenancy(resource.DefaultNamespacedTenancy()). - Reference("") -} diff --git a/internal/mesh/internal/controllers/routes/controller.go b/internal/mesh/internal/controllers/routes/controller.go index 9a9fabc076..7500129604 100644 --- a/internal/mesh/internal/controllers/routes/controller.go +++ b/internal/mesh/internal/controllers/routes/controller.go @@ -11,6 +11,7 @@ import ( "google.golang.org/protobuf/types/known/anypb" "github.com/hashicorp/consul/internal/controller" + "github.com/hashicorp/consul/internal/controller/cache/indexers" "github.com/hashicorp/consul/internal/mesh/internal/controllers/routes/loader" "github.com/hashicorp/consul/internal/mesh/internal/controllers/routes/xroutemapper" "github.com/hashicorp/consul/internal/mesh/internal/types" @@ -20,8 +21,30 @@ import ( "github.com/hashicorp/consul/proto-public/pbresource" ) +const ( + failoverDestRefsIndexName = "destination-refs" +) + +func resolveFailoverDestRefs(_ context.Context, rt controller.Runtime, id *pbresource.ID) ([]*pbresource.ID, error) { + iter, err := rt.Cache.ListIterator(pbcatalog.FailoverPolicyType, failoverDestRefsIndexName, id) + if err != nil { + return nil, err + } + + var resolved []*pbresource.ID + for res := iter.Next(); res != nil; res = iter.Next() { + resolved = append(resolved, resource.ReplaceType(pbcatalog.ServiceType, res.Id)) + } + + return resolved, nil +} + func Controller() *controller.Controller { - mapper := xroutemapper.New() + failoverDestRefsIndex := indexers.RefOrIDIndex(failoverDestRefsIndexName, func(dec *resource.DecodedResource[*pbcatalog.FailoverPolicy]) []*pbresource.Reference { + return dec.Data.GetUnderlyingDestinationRefs() + }) + + mapper := xroutemapper.New(resolveFailoverDestRefs) r := &routesReconciler{ mapper: mapper, @@ -30,8 +53,8 @@ func Controller() *controller.Controller { WithWatch(pbmesh.HTTPRouteType, mapper.MapHTTPRoute). WithWatch(pbmesh.GRPCRouteType, mapper.MapGRPCRoute). WithWatch(pbmesh.TCPRouteType, mapper.MapTCPRoute). - WithWatch(pbmesh.DestinationPolicyType, mapper.MapDestinationPolicy). - WithWatch(pbcatalog.FailoverPolicyType, mapper.MapFailoverPolicy). + WithWatch(pbmesh.DestinationPolicyType, mapper.MapServiceNameAligned). + WithWatch(pbcatalog.FailoverPolicyType, mapper.MapServiceNameAligned, failoverDestRefsIndex). WithWatch(pbcatalog.ServiceType, mapper.MapService). WithReconciler(r) } diff --git a/internal/mesh/internal/controllers/routes/controller_test.go b/internal/mesh/internal/controllers/routes/controller_test.go index 72d2dce283..908f2ac8e0 100644 --- a/internal/mesh/internal/controllers/routes/controller_test.go +++ b/internal/mesh/internal/controllers/routes/controller_test.go @@ -12,9 +12,9 @@ import ( "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" - svctest "github.com/hashicorp/consul/agent/grpc-external/services/resource/testing" "github.com/hashicorp/consul/internal/catalog" "github.com/hashicorp/consul/internal/controller" + "github.com/hashicorp/consul/internal/controller/controllertest" "github.com/hashicorp/consul/internal/mesh/internal/types" "github.com/hashicorp/consul/internal/resource" rtest "github.com/hashicorp/consul/internal/resource/resourcetest" @@ -31,7 +31,6 @@ type controllerSuite struct { ctx context.Context client *rtest.Client - rt controller.Runtime tenancies []*pbresource.Tenancy refs *testResourceRef @@ -46,30 +45,20 @@ type testResourceRef struct { func (suite *controllerSuite) SetupTest() { suite.ctx = testutil.TestContext(suite.T()) suite.tenancies = rtest.TestTenancies() - client := svctest.NewResourceServiceBuilder(). - WithRegisterFns(types.Register, catalog.RegisterTypes). + + client := controllertest.NewControllerTestBuilder(). WithTenancies(suite.tenancies...). + WithResourceRegisterFns(types.Register, catalog.RegisterTypes). + WithControllerRegisterFns(func(mgr *controller.Manager) { + mgr.Register(Controller()) + }). Run(suite.T()) - suite.rt = controller.Runtime{ - Client: client, - Logger: testutil.Logger(suite.T()), - } suite.client = rtest.NewClient(client) } func (suite *controllerSuite) TestController() { - mgr := controller.NewManager(suite.client, suite.rt.Logger) - mgr.Register(Controller()) - mgr.SetRaftLeader(true) - go mgr.Run(suite.ctx) - suite.runTestCaseWithTenancies(func(refs *testResourceRef) { - - backendName := func(name, port string, tenancy *pbresource.Tenancy) string { - return fmt.Sprintf("catalog.v2beta1.Service/%s.local.%s/%s?port=%s", tenancy.Partition, tenancy.Namespace, name, port) - } - var ( apiServiceRef = refs.apiServiceRef fooServiceRef = refs.fooServiceRef @@ -136,7 +125,6 @@ func (suite *controllerSuite) TestController() { }) // Let the default http/http2/grpc routes get created. - apiServiceData = &pbcatalog.Service{ Workloads: &pbcatalog.WorkloadSelector{ Prefixes: []string{"api-"}, @@ -1428,3 +1416,7 @@ func (suite *controllerSuite) runTestCaseWithTenancies(testFunc func(ref *testRe func (suite *controllerSuite) appendTenancyInfo(tenancy *pbresource.Tenancy) string { return fmt.Sprintf("%s_Namespace_%s_Partition", tenancy.Namespace, tenancy.Partition) } + +func backendName(name, port string, tenancy *pbresource.Tenancy) string { + return fmt.Sprintf("catalog.v2beta1.Service/%s.local.%s/%s?port=%s", tenancy.Partition, tenancy.Namespace, name, port) +} diff --git a/internal/mesh/internal/controllers/routes/loader/loader.go b/internal/mesh/internal/controllers/routes/loader/loader.go index 4e7be5e649..721137b5ae 100644 --- a/internal/mesh/internal/controllers/routes/loader/loader.go +++ b/internal/mesh/internal/controllers/routes/loader/loader.go @@ -224,7 +224,6 @@ func (l *loader) loadUpstreamService( return err } if failoverPolicy != nil { - l.mapper.TrackFailoverPolicy(failoverPolicy) l.out.AddFailoverPolicy(failoverPolicy) destRefs := failoverPolicy.Data.GetUnderlyingDestinationRefs() @@ -245,8 +244,6 @@ func (l *loader) loadUpstreamService( } } } - } else { - l.mapper.UntrackFailoverPolicy(failoverPolicyID) } if err := l.loadDestConfig(ctx, logger, svcID); err != nil { diff --git a/internal/mesh/internal/controllers/routes/loader/loader_test.go b/internal/mesh/internal/controllers/routes/loader/loader_test.go index 0db050fee6..6341e98f43 100644 --- a/internal/mesh/internal/controllers/routes/loader/loader_test.go +++ b/internal/mesh/internal/controllers/routes/loader/loader_test.go @@ -4,6 +4,7 @@ package loader import ( + "context" "testing" "time" @@ -15,6 +16,8 @@ import ( svctest "github.com/hashicorp/consul/agent/grpc-external/services/resource/testing" "github.com/hashicorp/consul/internal/catalog" "github.com/hashicorp/consul/internal/controller" + "github.com/hashicorp/consul/internal/controller/cache" + "github.com/hashicorp/consul/internal/controller/cache/indexers" "github.com/hashicorp/consul/internal/mesh/internal/controllers/routes/xroutemapper" "github.com/hashicorp/consul/internal/mesh/internal/types" "github.com/hashicorp/consul/internal/resource" @@ -27,21 +30,41 @@ import ( ) func TestLoadResourcesForComputedRoutes(t *testing.T) { + // temporarily creating the cache here until we can get rid of the xroutemapper object entirely. Its not super clean to hack together a cache for usage in this func + // but its better than alternatives and this should be relatively short lived. + testCache := cache.New() + testCache.AddIndex(pbcatalog.FailoverPolicyType, indexers.RefOrIDIndex("dest-refs", func(res *resource.DecodedResource[*pbcatalog.FailoverPolicy]) []*pbresource.Reference { + return res.Data.GetUnderlyingDestinationRefs() + })) + ctx := testutil.TestContext(t) rclient := svctest.NewResourceServiceBuilder(). WithRegisterFns(types.Register, catalog.RegisterTypes). Run(t) rt := controller.Runtime{ - Client: rclient, + Client: cache.NewCachedClient(testCache, rclient), Logger: testutil.Logger(t), } - client := rtest.NewClient(rclient) + + client := rtest.NewClient(rt.Client) loggerFor := func(id *pbresource.ID) hclog.Logger { return rt.Logger.With("resource-id", id) } - mapper := xroutemapper.New() + mapper := xroutemapper.New(func(_ context.Context, rt controller.Runtime, id *pbresource.ID) ([]*pbresource.ID, error) { + iter, err := rt.Cache.ListIterator(pbcatalog.FailoverPolicyType, "dest-refs", id) + if err != nil { + return nil, err + } + + var resolved []*pbresource.ID + for res := iter.Next(); res != nil; res = iter.Next() { + resolved = append(resolved, resource.ReplaceType(pbcatalog.ServiceType, res.Id)) + } + + return resolved, nil + }) deleteRes := func(id *pbresource.ID, untrack bool) { client.MustDelete(t, id) @@ -49,8 +72,6 @@ func TestLoadResourcesForComputedRoutes(t *testing.T) { switch { case types.IsRouteType(id.Type): mapper.UntrackXRoute(id) - case types.IsFailoverPolicyType(id.Type): - mapper.UntrackFailoverPolicy(id) } } } diff --git a/internal/mesh/internal/controllers/routes/xroutemapper/.mockery.yaml b/internal/mesh/internal/controllers/routes/xroutemapper/.mockery.yaml new file mode 100644 index 0000000000..d4e2a2df70 --- /dev/null +++ b/internal/mesh/internal/controllers/routes/xroutemapper/.mockery.yaml @@ -0,0 +1,15 @@ +# Copyright (c) HashiCorp, Inc. +# SPDX-License-Identifier: BUSL-1.1 + +with-expecter: true +recursive: false +all: true +# We don't want the mocks within proto-public so as to force a dependency +# of the testify library on the modules usage. The mocks are only for +# internal testing purposes. Other consumers can generated the mocks into +# their own code base. +dir: "{{.PackageName}}mock" +outpkg: "{{.PackageName}}mock" +mockname: "{{.InterfaceName}}" +packages: + github.com/hashicorp/consul/internal/mesh/internal/controllers/routes/xroutemapper: diff --git a/internal/mesh/internal/controllers/routes/xroutemapper/xroutemapper.go b/internal/mesh/internal/controllers/routes/xroutemapper/xroutemapper.go index 2c240ee718..12e4ec8b35 100644 --- a/internal/mesh/internal/controllers/routes/xroutemapper/xroutemapper.go +++ b/internal/mesh/internal/controllers/routes/xroutemapper/xroutemapper.go @@ -7,7 +7,6 @@ import ( "context" "fmt" - "github.com/hashicorp/consul/internal/catalog" "github.com/hashicorp/consul/internal/controller" "github.com/hashicorp/consul/internal/mesh/internal/types" "github.com/hashicorp/consul/internal/resource" @@ -17,6 +16,8 @@ import ( "github.com/hashicorp/consul/proto-public/pbresource" ) +type ResolveFailoverServiceDestinations func(context.Context, controller.Runtime, *pbresource.ID) ([]*pbresource.ID, error) + // Mapper tracks the following relationships: // // - xRoute <-> ParentRef Service @@ -39,11 +40,14 @@ type Mapper struct { grpcRouteBackendMapper *bimapper.Mapper tcpRouteBackendMapper *bimapper.Mapper - failMapper catalog.FailoverPolicyMapper + resolveFailoverServiceDestinations ResolveFailoverServiceDestinations } // New creates a new Mapper. -func New() *Mapper { +func New(resolver ResolveFailoverServiceDestinations) *Mapper { + if resolver == nil { + panic("must specify a ResolveFailoverServiceDestinations callback") + } return &Mapper{ boundRefMapper: bimapper.NewWithWildcardLinkType(pbmesh.ComputedRoutesType), @@ -55,7 +59,7 @@ func New() *Mapper { grpcRouteBackendMapper: bimapper.New(pbmesh.GRPCRouteType, pbcatalog.ServiceType), tcpRouteBackendMapper: bimapper.New(pbmesh.TCPRouteType, pbcatalog.ServiceType), - failMapper: catalog.NewFailoverPolicyMapper(), + resolveFailoverServiceDestinations: resolver, } } @@ -207,48 +211,11 @@ func mapXRouteToComputedRoutes[T types.XRouteData](res *pbresource.Resource, m * return controller.MakeRequests(pbmesh.ComputedRoutesType, refs), nil } -func (m *Mapper) MapFailoverPolicy( +func (m *Mapper) MapServiceNameAligned( _ context.Context, _ controller.Runtime, res *pbresource.Resource, ) ([]controller.Request, error) { - if !types.IsFailoverPolicyType(res.Id.Type) { - return nil, fmt.Errorf("type is not a failover policy type: %s", res.Id.Type) - } - - dec, err := resource.Decode[*pbcatalog.FailoverPolicy](res) - if err != nil { - return nil, fmt.Errorf("error unmarshalling failover policy: %w", err) - } - - m.failMapper.TrackFailover(dec) - - // Since this is name-aligned, just switch the type and find routes that - // will route any traffic to this destination service. - svcID := resource.ReplaceType(pbcatalog.ServiceType, res.Id) - - return m.mapXRouteDirectServiceRefToComputedRoutesByID(svcID) -} - -func (m *Mapper) TrackFailoverPolicy(failover *types.DecodedFailoverPolicy) { - if failover != nil { - m.failMapper.TrackFailover(failover) - } -} - -func (m *Mapper) UntrackFailoverPolicy(failoverPolicyID *pbresource.ID) { - m.failMapper.UntrackFailover(failoverPolicyID) -} - -func (m *Mapper) MapDestinationPolicy( - _ context.Context, - _ controller.Runtime, - res *pbresource.Resource, -) ([]controller.Request, error) { - if !types.IsDestinationPolicyType(res.Id.Type) { - return nil, fmt.Errorf("type is not a destination policy type: %s", res.Id.Type) - } - // Since this is name-aligned, just switch the type and find routes that // will route any traffic to this destination service. svcID := resource.ReplaceType(pbcatalog.ServiceType, res.Id) @@ -257,8 +224,8 @@ func (m *Mapper) MapDestinationPolicy( } func (m *Mapper) MapService( - _ context.Context, - _ controller.Runtime, + ctx context.Context, + rt controller.Runtime, res *pbresource.Resource, ) ([]controller.Request, error) { // Ultimately we want to wake up a ComputedRoutes if either of the @@ -268,8 +235,10 @@ func (m *Mapper) MapService( // 2. xRoute[parentRef=OUTPUT_EVENT; backendRef=SOMETHING], FailoverPolicy[name=SOMETHING, destRef=INPUT_EVENT] // (case 2) First find all failover policies that have a reference to our input service. - failPolicyIDs := m.failMapper.FailoverIDsByService(res.Id) - effectiveServiceIDs := sliceReplaceType(failPolicyIDs, pbcatalog.ServiceType) + effectiveServiceIDs, err := m.resolveFailoverServiceDestinations(ctx, rt, res.Id) + if err != nil { + return nil, err + } // (case 1) Do the direct mapping also. effectiveServiceIDs = append(effectiveServiceIDs, res.Id) diff --git a/internal/mesh/internal/controllers/routes/xroutemapper/xroutemapper_test.go b/internal/mesh/internal/controllers/routes/xroutemapper/xroutemapper_test.go index e74b049b7d..6b7bab9cf1 100644 --- a/internal/mesh/internal/controllers/routes/xroutemapper/xroutemapper_test.go +++ b/internal/mesh/internal/controllers/routes/xroutemapper/xroutemapper_test.go @@ -16,6 +16,8 @@ import ( "github.com/hashicorp/consul/internal/catalog" "github.com/hashicorp/consul/internal/controller" + "github.com/hashicorp/consul/internal/controller/cache" + "github.com/hashicorp/consul/internal/controller/cache/indexers" "github.com/hashicorp/consul/internal/mesh/internal/types" "github.com/hashicorp/consul/internal/resource" rtest "github.com/hashicorp/consul/internal/resource/resourcetest" @@ -78,6 +80,27 @@ func testMapper_Tracking(t *testing.T, typ *pbresource.Type, newRoute func(t *te types.Register(registry) catalog.RegisterTypes(registry) + // temporarily creating the cache here until we can get rid of this xroutemapper object entirely. Its not super clean to hack together a cache for usage in this func + // but its better than alternatives and this should be relatively short lived. + testCache := cache.New() + testCache.AddIndex(pbcatalog.FailoverPolicyType, indexers.RefOrIDIndex("dest-refs", func(res *resource.DecodedResource[*pbcatalog.FailoverPolicy]) []*pbresource.Reference { + return res.Data.GetUnderlyingDestinationRefs() + })) + + m := New(func(_ context.Context, rt controller.Runtime, id *pbresource.ID) ([]*pbresource.ID, error) { + iter, err := rt.Cache.ListIterator(pbcatalog.FailoverPolicyType, "dest-refs", id) + if err != nil { + return nil, err + } + + var resolved []*pbresource.ID + for res := iter.Next(); res != nil; res = iter.Next() { + resolved = append(resolved, resource.ReplaceType(pbcatalog.ServiceType, res.Id)) + } + + return resolved, nil + }) + newService := func(name string) *pbresource.Resource { svc := rtest.Resource(pbcatalog.ServiceType, name). WithTenancy(resource.DefaultNamespacedTenancy()). @@ -126,8 +149,6 @@ func testMapper_Tracking(t *testing.T, typ *pbresource.Type, newRoute func(t *te zimComputedRoutes := newID(pbmesh.ComputedRoutesType, "zim") girComputedRoutes := newID(pbmesh.ComputedRoutesType, "gir") - m := New() - var ( apiSvc = newService("api") wwwSvc = newService("www") @@ -153,20 +174,20 @@ func testMapper_Tracking(t *testing.T, typ *pbresource.Type, newRoute func(t *te ) testutil.RunStep(t, "only name aligned defaults", func(t *testing.T) { - requireTracking(t, m, apiSvc, apiComputedRoutes) - requireTracking(t, m, wwwSvc, wwwComputedRoutes) - requireTracking(t, m, barSvc, barComputedRoutes) - requireTracking(t, m, fooSvc, fooComputedRoutes) - requireTracking(t, m, zimSvc, zimComputedRoutes) - requireTracking(t, m, girSvc, girComputedRoutes) + requireTracking(t, m, testCache, apiSvc, apiComputedRoutes) + requireTracking(t, m, testCache, wwwSvc, wwwComputedRoutes) + requireTracking(t, m, testCache, barSvc, barComputedRoutes) + requireTracking(t, m, testCache, fooSvc, fooComputedRoutes) + requireTracking(t, m, testCache, zimSvc, zimComputedRoutes) + requireTracking(t, m, testCache, girSvc, girComputedRoutes) - requireTracking(t, m, apiDest, apiComputedRoutes) - requireTracking(t, m, wwwDest, wwwComputedRoutes) + requireTracking(t, m, testCache, apiDest, apiComputedRoutes) + requireTracking(t, m, testCache, wwwDest, wwwComputedRoutes) // This will track the failover policies. - requireTracking(t, m, apiFail, apiComputedRoutes) - requireTracking(t, m, wwwFail, wwwComputedRoutes) - requireTracking(t, m, barFail, barComputedRoutes) + requireTracking(t, m, testCache, apiFail, apiComputedRoutes) + requireTracking(t, m, testCache, wwwFail, wwwComputedRoutes) + requireTracking(t, m, testCache, barFail, barComputedRoutes) // verify other helper methods for _, ref := range []*pbresource.Reference{apiSvcRef, wwwSvcRef, barSvcRef, fooSvcRef, zimSvcRef, girSvcRef} { @@ -192,22 +213,22 @@ func testMapper_Tracking(t *testing.T, typ *pbresource.Type, newRoute func(t *te )).Build() rtest.ValidateAndNormalize(t, registry, route1) - requireTracking(t, m, route1, apiComputedRoutes) + requireTracking(t, m, testCache, route1, apiComputedRoutes) // Now 'api' references should trigger more, but be duplicate-suppressed. - requireTracking(t, m, apiSvc, apiComputedRoutes) - requireTracking(t, m, wwwSvc, wwwComputedRoutes) - requireTracking(t, m, barSvc, barComputedRoutes) - requireTracking(t, m, fooSvc, fooComputedRoutes) - requireTracking(t, m, zimSvc, zimComputedRoutes) - requireTracking(t, m, girSvc, girComputedRoutes) + requireTracking(t, m, testCache, apiSvc, apiComputedRoutes) + requireTracking(t, m, testCache, wwwSvc, wwwComputedRoutes) + requireTracking(t, m, testCache, barSvc, barComputedRoutes) + requireTracking(t, m, testCache, fooSvc, fooComputedRoutes) + requireTracking(t, m, testCache, zimSvc, zimComputedRoutes) + requireTracking(t, m, testCache, girSvc, girComputedRoutes) - requireTracking(t, m, apiDest, apiComputedRoutes) - requireTracking(t, m, wwwDest, wwwComputedRoutes) + requireTracking(t, m, testCache, apiDest, apiComputedRoutes) + requireTracking(t, m, testCache, wwwDest, wwwComputedRoutes) - requireTracking(t, m, apiFail, apiComputedRoutes) - requireTracking(t, m, wwwFail, wwwComputedRoutes) - requireTracking(t, m, barFail, barComputedRoutes) + requireTracking(t, m, testCache, apiFail, apiComputedRoutes) + requireTracking(t, m, testCache, wwwFail, wwwComputedRoutes) + requireTracking(t, m, testCache, barFail, barComputedRoutes) // verify other helper methods prototest.AssertElementsMatch(t, []*pbresource.Reference{apiSvcRef}, m.BackendServiceRefsByRouteID(route1.Id)) @@ -236,22 +257,22 @@ func testMapper_Tracking(t *testing.T, typ *pbresource.Type, newRoute func(t *te rtest.ValidateAndNormalize(t, registry, route1) // Now witness the update. - requireTracking(t, m, route1, apiComputedRoutes) + requireTracking(t, m, testCache, route1, apiComputedRoutes) // Now 'api' references should trigger different things. - requireTracking(t, m, apiSvc, apiComputedRoutes) - requireTracking(t, m, wwwSvc, wwwComputedRoutes, apiComputedRoutes) - requireTracking(t, m, barSvc, barComputedRoutes) - requireTracking(t, m, fooSvc, fooComputedRoutes) - requireTracking(t, m, zimSvc, zimComputedRoutes) - requireTracking(t, m, girSvc, girComputedRoutes) + requireTracking(t, m, testCache, apiSvc, apiComputedRoutes) + requireTracking(t, m, testCache, wwwSvc, wwwComputedRoutes, apiComputedRoutes) + requireTracking(t, m, testCache, barSvc, barComputedRoutes) + requireTracking(t, m, testCache, fooSvc, fooComputedRoutes) + requireTracking(t, m, testCache, zimSvc, zimComputedRoutes) + requireTracking(t, m, testCache, girSvc, girComputedRoutes) - requireTracking(t, m, apiDest, apiComputedRoutes) - requireTracking(t, m, wwwDest, wwwComputedRoutes, apiComputedRoutes) + requireTracking(t, m, testCache, apiDest, apiComputedRoutes) + requireTracking(t, m, testCache, wwwDest, wwwComputedRoutes, apiComputedRoutes) - requireTracking(t, m, apiFail, apiComputedRoutes) - requireTracking(t, m, wwwFail, wwwComputedRoutes, apiComputedRoutes) - requireTracking(t, m, barFail, barComputedRoutes) + requireTracking(t, m, testCache, apiFail, apiComputedRoutes) + requireTracking(t, m, testCache, wwwFail, wwwComputedRoutes, apiComputedRoutes) + requireTracking(t, m, testCache, barFail, barComputedRoutes) // verify other helper methods prototest.AssertElementsMatch(t, []*pbresource.Reference{wwwSvcRef}, m.BackendServiceRefsByRouteID(route1.Id)) @@ -287,26 +308,26 @@ func testMapper_Tracking(t *testing.T, typ *pbresource.Type, newRoute func(t *te rtest.ValidateAndNormalize(t, registry, route1) // Now witness a route with multiple parents, overlapping the other route. - requireTracking(t, m, route2, apiComputedRoutes, fooComputedRoutes) + requireTracking(t, m, testCache, route2, apiComputedRoutes, fooComputedRoutes) - requireTracking(t, m, apiSvc, apiComputedRoutes) - requireTracking(t, m, wwwSvc, wwwComputedRoutes, apiComputedRoutes) - requireTracking(t, m, barSvc, barComputedRoutes, apiComputedRoutes, fooComputedRoutes) + requireTracking(t, m, testCache, apiSvc, apiComputedRoutes) + requireTracking(t, m, testCache, wwwSvc, wwwComputedRoutes, apiComputedRoutes) + requireTracking(t, m, testCache, barSvc, barComputedRoutes, apiComputedRoutes, fooComputedRoutes) - requireTracking(t, m, fooSvc, fooComputedRoutes) - requireTracking(t, m, zimSvc, zimComputedRoutes) - requireTracking(t, m, girSvc, girComputedRoutes) + requireTracking(t, m, testCache, fooSvc, fooComputedRoutes) + requireTracking(t, m, testCache, zimSvc, zimComputedRoutes) + requireTracking(t, m, testCache, girSvc, girComputedRoutes) - requireTracking(t, m, apiDest, apiComputedRoutes) - requireTracking(t, m, wwwDest, wwwComputedRoutes, apiComputedRoutes) + requireTracking(t, m, testCache, apiDest, apiComputedRoutes) + requireTracking(t, m, testCache, wwwDest, wwwComputedRoutes, apiComputedRoutes) - requireTracking(t, m, apiFail, apiComputedRoutes) - requireTracking(t, m, wwwFail, wwwComputedRoutes, apiComputedRoutes) - requireTracking(t, m, barFail, barComputedRoutes, apiComputedRoutes, fooComputedRoutes) + requireTracking(t, m, testCache, apiFail, apiComputedRoutes) + requireTracking(t, m, testCache, wwwFail, wwwComputedRoutes, apiComputedRoutes) + requireTracking(t, m, testCache, barFail, barComputedRoutes, apiComputedRoutes, fooComputedRoutes) - requireTracking(t, m, route1, apiComputedRoutes) + requireTracking(t, m, testCache, route1, apiComputedRoutes) // skip re-verifying route2 - // requireTracking(t, m, route2, apiComputedRoutes, fooComputedRoutes) + // requireTracking(t, m, rt, route2, apiComputedRoutes, fooComputedRoutes) // verify other helper methods prototest.AssertElementsMatch(t, []*pbresource.Reference{wwwSvcRef}, m.BackendServiceRefsByRouteID(route1.Id)) @@ -337,26 +358,26 @@ func testMapper_Tracking(t *testing.T, typ *pbresource.Type, newRoute func(t *te apiFail = newFailPolicy("api", newRef(pbcatalog.ServiceType, "foo"), newRef(pbcatalog.ServiceType, "zim")) - requireTracking(t, m, apiFail, apiComputedRoutes) + requireTracking(t, m, testCache, apiFail, apiComputedRoutes) - requireTracking(t, m, apiSvc, apiComputedRoutes) - requireTracking(t, m, wwwSvc, wwwComputedRoutes, apiComputedRoutes) - requireTracking(t, m, barSvc, barComputedRoutes, apiComputedRoutes, fooComputedRoutes) + requireTracking(t, m, testCache, apiSvc, apiComputedRoutes) + requireTracking(t, m, testCache, wwwSvc, wwwComputedRoutes, apiComputedRoutes) + requireTracking(t, m, testCache, barSvc, barComputedRoutes, apiComputedRoutes, fooComputedRoutes) - requireTracking(t, m, fooSvc, fooComputedRoutes, apiComputedRoutes) - requireTracking(t, m, zimSvc, zimComputedRoutes, apiComputedRoutes) - requireTracking(t, m, girSvc, girComputedRoutes) + requireTracking(t, m, testCache, fooSvc, fooComputedRoutes, apiComputedRoutes) + requireTracking(t, m, testCache, zimSvc, zimComputedRoutes, apiComputedRoutes) + requireTracking(t, m, testCache, girSvc, girComputedRoutes) - requireTracking(t, m, apiDest, apiComputedRoutes) - requireTracking(t, m, wwwDest, wwwComputedRoutes, apiComputedRoutes) + requireTracking(t, m, testCache, apiDest, apiComputedRoutes) + requireTracking(t, m, testCache, wwwDest, wwwComputedRoutes, apiComputedRoutes) // skipping verification of apiFail b/c it happened above already - // requireTracking(t, m, apiFail, apiComputedRoutes) - requireTracking(t, m, wwwFail, wwwComputedRoutes, apiComputedRoutes) - requireTracking(t, m, barFail, barComputedRoutes, apiComputedRoutes, fooComputedRoutes) + // requireTracking(t, m, rt, apiFail, apiComputedRoutes) + requireTracking(t, m, testCache, wwwFail, wwwComputedRoutes, apiComputedRoutes) + requireTracking(t, m, testCache, barFail, barComputedRoutes, apiComputedRoutes, fooComputedRoutes) - requireTracking(t, m, route1, apiComputedRoutes) - requireTracking(t, m, route2, apiComputedRoutes, fooComputedRoutes) + requireTracking(t, m, testCache, route1, apiComputedRoutes) + requireTracking(t, m, testCache, route2, apiComputedRoutes, fooComputedRoutes) // verify other helper methods prototest.AssertElementsMatch(t, []*pbresource.Reference{wwwSvcRef}, m.BackendServiceRefsByRouteID(route1.Id)) @@ -386,26 +407,26 @@ func testMapper_Tracking(t *testing.T, typ *pbresource.Type, newRoute func(t *te testutil.RunStep(t, "set a new failover policy for a service in route2", func(t *testing.T) { barFail = newFailPolicy("bar", newRef(pbcatalog.ServiceType, "gir")) - requireTracking(t, m, barFail, barComputedRoutes, apiComputedRoutes, fooComputedRoutes) + requireTracking(t, m, testCache, barFail, barComputedRoutes, apiComputedRoutes, fooComputedRoutes) - requireTracking(t, m, apiSvc, apiComputedRoutes) - requireTracking(t, m, wwwSvc, wwwComputedRoutes, apiComputedRoutes) - requireTracking(t, m, barSvc, barComputedRoutes, apiComputedRoutes, fooComputedRoutes) + requireTracking(t, m, testCache, apiSvc, apiComputedRoutes) + requireTracking(t, m, testCache, wwwSvc, wwwComputedRoutes, apiComputedRoutes) + requireTracking(t, m, testCache, barSvc, barComputedRoutes, apiComputedRoutes, fooComputedRoutes) - requireTracking(t, m, fooSvc, fooComputedRoutes, apiComputedRoutes) - requireTracking(t, m, zimSvc, zimComputedRoutes, apiComputedRoutes) - requireTracking(t, m, girSvc, girComputedRoutes, barComputedRoutes, apiComputedRoutes, fooComputedRoutes) + requireTracking(t, m, testCache, fooSvc, fooComputedRoutes, apiComputedRoutes) + requireTracking(t, m, testCache, zimSvc, zimComputedRoutes, apiComputedRoutes) + requireTracking(t, m, testCache, girSvc, girComputedRoutes, barComputedRoutes, apiComputedRoutes, fooComputedRoutes) - requireTracking(t, m, apiDest, apiComputedRoutes) - requireTracking(t, m, wwwDest, wwwComputedRoutes, apiComputedRoutes) + requireTracking(t, m, testCache, apiDest, apiComputedRoutes) + requireTracking(t, m, testCache, wwwDest, wwwComputedRoutes, apiComputedRoutes) - requireTracking(t, m, apiFail, apiComputedRoutes) - requireTracking(t, m, wwwFail, wwwComputedRoutes, apiComputedRoutes) + requireTracking(t, m, testCache, apiFail, apiComputedRoutes) + requireTracking(t, m, testCache, wwwFail, wwwComputedRoutes, apiComputedRoutes) // skipping verification of barFail b/c it happened above already - // requireTracking(t, m, barFail, barComputedRoutes, apiComputedRoutes, fooComputedRoutes) + // requireTracking(t, m, rt, barFail, barComputedRoutes, apiComputedRoutes, fooComputedRoutes) - requireTracking(t, m, route1, apiComputedRoutes) - requireTracking(t, m, route2, apiComputedRoutes, fooComputedRoutes) + requireTracking(t, m, testCache, route1, apiComputedRoutes) + requireTracking(t, m, testCache, route2, apiComputedRoutes, fooComputedRoutes) // verify other helper methods prototest.AssertElementsMatch(t, []*pbresource.Reference{wwwSvcRef}, m.BackendServiceRefsByRouteID(route1.Id)) @@ -436,22 +457,22 @@ func testMapper_Tracking(t *testing.T, typ *pbresource.Type, newRoute func(t *te m.UntrackXRoute(route1.Id) route1 = nil - requireTracking(t, m, apiSvc, apiComputedRoutes) - requireTracking(t, m, wwwSvc, wwwComputedRoutes) - requireTracking(t, m, barSvc, barComputedRoutes, apiComputedRoutes, fooComputedRoutes) + requireTracking(t, m, testCache, apiSvc, apiComputedRoutes) + requireTracking(t, m, testCache, wwwSvc, wwwComputedRoutes) + requireTracking(t, m, testCache, barSvc, barComputedRoutes, apiComputedRoutes, fooComputedRoutes) - requireTracking(t, m, fooSvc, fooComputedRoutes, apiComputedRoutes) - requireTracking(t, m, zimSvc, zimComputedRoutes, apiComputedRoutes) - requireTracking(t, m, girSvc, girComputedRoutes, barComputedRoutes, apiComputedRoutes, fooComputedRoutes) + requireTracking(t, m, testCache, fooSvc, fooComputedRoutes, apiComputedRoutes) + requireTracking(t, m, testCache, zimSvc, zimComputedRoutes, apiComputedRoutes) + requireTracking(t, m, testCache, girSvc, girComputedRoutes, barComputedRoutes, apiComputedRoutes, fooComputedRoutes) - requireTracking(t, m, apiDest, apiComputedRoutes) - requireTracking(t, m, wwwDest, wwwComputedRoutes) + requireTracking(t, m, testCache, apiDest, apiComputedRoutes) + requireTracking(t, m, testCache, wwwDest, wwwComputedRoutes) - requireTracking(t, m, apiFail, apiComputedRoutes) - requireTracking(t, m, wwwFail, wwwComputedRoutes) - requireTracking(t, m, barFail, barComputedRoutes, apiComputedRoutes, fooComputedRoutes) + requireTracking(t, m, testCache, apiFail, apiComputedRoutes) + requireTracking(t, m, testCache, wwwFail, wwwComputedRoutes) + requireTracking(t, m, testCache, barFail, barComputedRoutes, apiComputedRoutes, fooComputedRoutes) - requireTracking(t, m, route2, apiComputedRoutes, fooComputedRoutes) + requireTracking(t, m, testCache, route2, apiComputedRoutes, fooComputedRoutes) // verify other helper methods prototest.AssertElementsMatch(t, []*pbresource.Reference{barSvcRef}, m.BackendServiceRefsByRouteID(route2.Id)) @@ -473,26 +494,26 @@ func testMapper_Tracking(t *testing.T, typ *pbresource.Type, newRoute func(t *te }) testutil.RunStep(t, "delete all failover", func(t *testing.T) { - m.UntrackFailoverPolicy(apiFail.Id) - m.UntrackFailoverPolicy(wwwFail.Id) - m.UntrackFailoverPolicy(barFail.Id) + testCache.Delete(apiFail) + testCache.Delete(wwwFail) + testCache.Delete(barFail) apiFail = nil wwwFail = nil barFail = nil - requireTracking(t, m, apiSvc, apiComputedRoutes) - requireTracking(t, m, wwwSvc, wwwComputedRoutes) - requireTracking(t, m, barSvc, barComputedRoutes, apiComputedRoutes, fooComputedRoutes) + requireTracking(t, m, testCache, apiSvc, apiComputedRoutes) + requireTracking(t, m, testCache, wwwSvc, wwwComputedRoutes) + requireTracking(t, m, testCache, barSvc, barComputedRoutes, apiComputedRoutes, fooComputedRoutes) - requireTracking(t, m, fooSvc, fooComputedRoutes) - requireTracking(t, m, zimSvc, zimComputedRoutes) - requireTracking(t, m, girSvc, girComputedRoutes) + requireTracking(t, m, testCache, fooSvc, fooComputedRoutes) + requireTracking(t, m, testCache, zimSvc, zimComputedRoutes) + requireTracking(t, m, testCache, girSvc, girComputedRoutes) - requireTracking(t, m, apiDest, apiComputedRoutes) - requireTracking(t, m, wwwDest, wwwComputedRoutes) + requireTracking(t, m, testCache, apiDest, apiComputedRoutes) + requireTracking(t, m, testCache, wwwDest, wwwComputedRoutes) - requireTracking(t, m, route2, apiComputedRoutes, fooComputedRoutes) + requireTracking(t, m, testCache, route2, apiComputedRoutes, fooComputedRoutes) // verify other helper methods prototest.AssertElementsMatch(t, []*pbresource.Reference{barSvcRef}, m.BackendServiceRefsByRouteID(route2.Id)) @@ -517,16 +538,16 @@ func testMapper_Tracking(t *testing.T, typ *pbresource.Type, newRoute func(t *te m.UntrackXRoute(route2.Id) route2 = nil - requireTracking(t, m, apiSvc, apiComputedRoutes) - requireTracking(t, m, wwwSvc, wwwComputedRoutes) - requireTracking(t, m, barSvc, barComputedRoutes) + requireTracking(t, m, testCache, apiSvc, apiComputedRoutes) + requireTracking(t, m, testCache, wwwSvc, wwwComputedRoutes) + requireTracking(t, m, testCache, barSvc, barComputedRoutes) - requireTracking(t, m, fooSvc, fooComputedRoutes) - requireTracking(t, m, zimSvc, zimComputedRoutes) - requireTracking(t, m, girSvc, girComputedRoutes) + requireTracking(t, m, testCache, fooSvc, fooComputedRoutes) + requireTracking(t, m, testCache, zimSvc, zimComputedRoutes) + requireTracking(t, m, testCache, girSvc, girComputedRoutes) - requireTracking(t, m, apiDest, apiComputedRoutes) - requireTracking(t, m, wwwDest, wwwComputedRoutes) + requireTracking(t, m, testCache, apiDest, apiComputedRoutes) + requireTracking(t, m, testCache, wwwDest, wwwComputedRoutes) // verify other helper methods for _, ref := range []*pbresource.Reference{apiSvcRef, wwwSvcRef, barSvcRef, fooSvcRef, zimSvcRef, girSvcRef} { @@ -549,7 +570,7 @@ func testMapper_Tracking(t *testing.T, typ *pbresource.Type, newRoute func(t *te )).Build() rtest.ValidateAndNormalize(t, registry, route1) - requireTracking(t, m, route1, barComputedRoutes, fooComputedRoutes) + requireTracking(t, m, testCache, route1, barComputedRoutes, fooComputedRoutes) // Simulate a Reconcile that would update the mapper. // @@ -577,7 +598,7 @@ func testMapper_Tracking(t *testing.T, typ *pbresource.Type, newRoute func(t *te m.TrackComputedRoutes(rtest.MustDecode[*pbmesh.ComputedRoutes](t, barCR)) // Still has the same tracking. - requireTracking(t, m, route1, barComputedRoutes, fooComputedRoutes) + requireTracking(t, m, testCache, route1, barComputedRoutes, fooComputedRoutes) // Now change the route to remove "bar" @@ -594,7 +615,7 @@ func testMapper_Tracking(t *testing.T, typ *pbresource.Type, newRoute func(t *te rtest.ValidateAndNormalize(t, registry, route1) // Now we see that it still emits the event for bar, so we get a chance to update it. - requireTracking(t, m, route1, barComputedRoutes, fooComputedRoutes) + requireTracking(t, m, testCache, route1, barComputedRoutes, fooComputedRoutes) // Update the bound references on 'bar' to remove the route barCR = rtest.ResourceID(barComputedRoutes). @@ -608,13 +629,14 @@ func testMapper_Tracking(t *testing.T, typ *pbresource.Type, newRoute func(t *te m.TrackComputedRoutes(rtest.MustDecode[*pbmesh.ComputedRoutes](t, barCR)) // Now 'bar' no longer has a link to the route. - requireTracking(t, m, route1, fooComputedRoutes) + requireTracking(t, m, testCache, route1, fooComputedRoutes) }) } func requireTracking( t *testing.T, mapper *Mapper, + c cache.Cache, res *pbresource.Resource, computedRoutesIDs ...*pbresource.ID, ) { @@ -626,19 +648,24 @@ func requireTracking( reqs []controller.Request err error ) + + rt := controller.Runtime{ + Cache: c, + } switch { case resource.EqualType(pbmesh.HTTPRouteType, res.Id.Type): - reqs, err = mapper.MapHTTPRoute(context.Background(), controller.Runtime{}, res) + reqs, err = mapper.MapHTTPRoute(context.Background(), rt, res) case resource.EqualType(pbmesh.GRPCRouteType, res.Id.Type): - reqs, err = mapper.MapGRPCRoute(context.Background(), controller.Runtime{}, res) + reqs, err = mapper.MapGRPCRoute(context.Background(), rt, res) case resource.EqualType(pbmesh.TCPRouteType, res.Id.Type): - reqs, err = mapper.MapTCPRoute(context.Background(), controller.Runtime{}, res) + reqs, err = mapper.MapTCPRoute(context.Background(), rt, res) case resource.EqualType(pbmesh.DestinationPolicyType, res.Id.Type): - reqs, err = mapper.MapDestinationPolicy(context.Background(), controller.Runtime{}, res) + reqs, err = mapper.MapServiceNameAligned(context.Background(), rt, res) case resource.EqualType(pbcatalog.FailoverPolicyType, res.Id.Type): - reqs, err = mapper.MapFailoverPolicy(context.Background(), controller.Runtime{}, res) + c.Insert(res) + reqs, err = mapper.MapServiceNameAligned(context.Background(), rt, res) case resource.EqualType(pbcatalog.ServiceType, res.Id.Type): - reqs, err = mapper.MapService(context.Background(), controller.Runtime{}, res) + reqs, err = mapper.MapService(context.Background(), rt, res) default: t.Fatalf("unhandled resource type: %s", resource.TypeToString(res.Id.Type)) } diff --git a/internal/mesh/internal/controllers/routes/xroutemapper/xroutemappermock/mock_ResolveFailoverServiceDestinations.go b/internal/mesh/internal/controllers/routes/xroutemapper/xroutemappermock/mock_ResolveFailoverServiceDestinations.go new file mode 100644 index 0000000000..1f0c3f41ac --- /dev/null +++ b/internal/mesh/internal/controllers/routes/xroutemapper/xroutemappermock/mock_ResolveFailoverServiceDestinations.go @@ -0,0 +1,95 @@ +// Code generated by mockery v2.37.1. DO NOT EDIT. + +package xroutemappermock + +import ( + context "context" + + controller "github.com/hashicorp/consul/internal/controller" + mock "github.com/stretchr/testify/mock" + + pbresource "github.com/hashicorp/consul/proto-public/pbresource" +) + +// ResolveFailoverServiceDestinations is an autogenerated mock type for the ResolveFailoverServiceDestinations type +type ResolveFailoverServiceDestinations struct { + mock.Mock +} + +type ResolveFailoverServiceDestinations_Expecter struct { + mock *mock.Mock +} + +func (_m *ResolveFailoverServiceDestinations) EXPECT() *ResolveFailoverServiceDestinations_Expecter { + return &ResolveFailoverServiceDestinations_Expecter{mock: &_m.Mock} +} + +// Execute provides a mock function with given fields: _a0, _a1, _a2 +func (_m *ResolveFailoverServiceDestinations) Execute(_a0 context.Context, _a1 controller.Runtime, _a2 *pbresource.ID) ([]*pbresource.ID, error) { + ret := _m.Called(_a0, _a1, _a2) + + var r0 []*pbresource.ID + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, controller.Runtime, *pbresource.ID) ([]*pbresource.ID, error)); ok { + return rf(_a0, _a1, _a2) + } + if rf, ok := ret.Get(0).(func(context.Context, controller.Runtime, *pbresource.ID) []*pbresource.ID); ok { + r0 = rf(_a0, _a1, _a2) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*pbresource.ID) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, controller.Runtime, *pbresource.ID) error); ok { + r1 = rf(_a0, _a1, _a2) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// ResolveFailoverServiceDestinations_Execute_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Execute' +type ResolveFailoverServiceDestinations_Execute_Call struct { + *mock.Call +} + +// Execute is a helper method to define mock.On call +// - _a0 context.Context +// - _a1 controller.Runtime +// - _a2 *pbresource.ID +func (_e *ResolveFailoverServiceDestinations_Expecter) Execute(_a0 interface{}, _a1 interface{}, _a2 interface{}) *ResolveFailoverServiceDestinations_Execute_Call { + return &ResolveFailoverServiceDestinations_Execute_Call{Call: _e.mock.On("Execute", _a0, _a1, _a2)} +} + +func (_c *ResolveFailoverServiceDestinations_Execute_Call) Run(run func(_a0 context.Context, _a1 controller.Runtime, _a2 *pbresource.ID)) *ResolveFailoverServiceDestinations_Execute_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(controller.Runtime), args[2].(*pbresource.ID)) + }) + return _c +} + +func (_c *ResolveFailoverServiceDestinations_Execute_Call) Return(_a0 []*pbresource.ID, _a1 error) *ResolveFailoverServiceDestinations_Execute_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *ResolveFailoverServiceDestinations_Execute_Call) RunAndReturn(run func(context.Context, controller.Runtime, *pbresource.ID) ([]*pbresource.ID, error)) *ResolveFailoverServiceDestinations_Execute_Call { + _c.Call.Return(run) + return _c +} + +// NewResolveFailoverServiceDestinations creates a new instance of ResolveFailoverServiceDestinations. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewResolveFailoverServiceDestinations(t interface { + mock.TestingT + Cleanup(func()) +}) *ResolveFailoverServiceDestinations { + mock := &ResolveFailoverServiceDestinations{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +}