diff --git a/agent/consul/server.go b/agent/consul/server.go index 3ec392daae..73df01e858 100644 --- a/agent/consul/server.go +++ b/agent/consul/server.go @@ -507,7 +507,6 @@ func NewServer(config *Config, flat Deps, externalGRPCServer *grpc.Server, incom incomingRPCLimiter: incomingRPCLimiter, routineManager: routine.NewManager(logger.Named(logging.ConsulServer)), typeRegistry: resource.NewRegistry(), - controllerManager: controller.NewManager(logger.Named(logging.ControllerRuntime)), } incomingRPCLimiter.Register(s) @@ -783,6 +782,17 @@ func NewServer(config *Config, flat Deps, externalGRPCServer *grpc.Server, incom // to enable RPC forwarding. s.grpcHandler = newGRPCHandlerFromConfig(flat, config, s) s.grpcLeaderForwarder = flat.LeaderForwarder + + if err := s.setupInternalResourceService(logger); err != nil { + return nil, err + } + s.controllerManager = controller.NewManager( + s.internalResourceServiceClient, + logger.Named(logging.ControllerRuntime), + ) + s.registerResources() + go s.controllerManager.Run(&lib.StopChannelContext{StopCh: shutdownCh}) + go s.trackLeaderChanges() s.xdsCapacityController = xdscapacity.NewController(xdscapacity.Config{ @@ -792,10 +802,6 @@ func NewServer(config *Config, flat Deps, externalGRPCServer *grpc.Server, incom }) go s.xdsCapacityController.Run(&lib.StopChannelContext{StopCh: s.shutdownCh}) - if err := s.setupInternalResourceService(logger); err != nil { - return nil, err - } - // Initialize Autopilot. This must happen before starting leadership monitoring // as establishing leadership could attempt to use autopilot and cause a panic. s.initAutopilot(config) @@ -832,9 +838,6 @@ func NewServer(config *Config, flat Deps, externalGRPCServer *grpc.Server, incom return nil, err } - s.registerResources() - go s.controllerManager.Run(&lib.StopChannelContext{StopCh: shutdownCh}) - return s, nil } diff --git a/internal/controller/api.go b/internal/controller/api.go index 8545d339a7..d258eb40d6 100644 --- a/internal/controller/api.go +++ b/internal/controller/api.go @@ -1,6 +1,17 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + package controller import ( + "context" + "fmt" + "strings" + "time" + + "github.com/hashicorp/go-hclog" + + "github.com/hashicorp/consul/internal/resource" "github.com/hashicorp/consul/proto-public/pbresource" ) @@ -9,6 +20,86 @@ func ForType(managedType *pbresource.Type) Controller { return Controller{managedType: managedType} } +// WithReconciler changes the controller's reconciler. +func (c Controller) WithReconciler(reconciler Reconciler) Controller { + if reconciler == nil { + panic("reconciler must not be nil") + } + + c.reconciler = reconciler + return c +} + +// WithWatch adds a watch on the given type/dependency to the controller. mapper +// will be called to determine which resources must be reconciled as a result of +// a watched resource changing. +func (c Controller) WithWatch(watchedType *pbresource.Type, mapper DependencyMapper) Controller { + if watchedType == nil { + panic("watchedType must not be nil") + } + + if mapper == nil { + panic("mapper must not be nil") + } + + c.watches = append(c.watches, watch{watchedType, mapper}) + return c +} + +// WithLogger changes the controller's logger. +func (c Controller) WithLogger(logger hclog.Logger) Controller { + if logger == nil { + panic("logger must not be nil") + } + + c.logger = logger + return c +} + +// WithBackoff changes the base and maximum backoff values for the controller's +// retry rate limiter. +func (c Controller) WithBackoff(base, max time.Duration) Controller { + c.baseBackoff = base + c.maxBackoff = max + return c +} + +// WithPlacement changes where and how many replicas of the controller will run. +// In the majority of cases, the default placement (one leader elected instance +// per cluster) is the most appropriate and you shouldn't need to override it. +func (c Controller) WithPlacement(placement Placement) Controller { + c.placement = placement + return c +} + +// String returns a textual description of the controller, useful for debugging. +func (c Controller) String() string { + watchedTypes := make([]string, len(c.watches)) + for idx, w := range c.watches { + watchedTypes[idx] = fmt.Sprintf("%q", resource.ToGVK(w.watchedType)) + } + base, max := c.backoff() + return fmt.Sprintf( + ", placement=%q>", + resource.ToGVK(c.managedType), + strings.Join(watchedTypes, ", "), + base, max, + c.placement, + ) +} + +func (c Controller) backoff() (time.Duration, time.Duration) { + base := c.baseBackoff + if base == 0 { + base = 5 * time.Millisecond + } + max := c.maxBackoff + if max == 0 { + max = 1000 * time.Second + } + return base, max +} + // Controller runs a reconciliation loop to respond to changes in resources and // their dependencies. It is heavily inspired by Kubernetes' controller pattern: // https://kubernetes.io/docs/concepts/architecture/controller/ @@ -17,4 +108,101 @@ func ForType(managedType *pbresource.Type) Controller { // a controller, and then pass it to a Manager to be executed. type Controller struct { managedType *pbresource.Type + reconciler Reconciler + logger hclog.Logger + watches []watch + baseBackoff time.Duration + maxBackoff time.Duration + placement Placement +} + +type watch struct { + watchedType *pbresource.Type + mapper DependencyMapper +} + +// Request represents a request to reconcile the resource with the given ID. +type Request struct { + // ID of the resource that needs to be reconciled. + ID *pbresource.ID +} + +// Runtime contains the dependencies required by reconcilers. +type Runtime struct { + Client pbresource.ResourceServiceClient + Logger hclog.Logger +} + +// Reconciler implements the business logic of a controller. +type Reconciler interface { + // Reconcile the resource identified by req.ID. + Reconcile(ctx context.Context, rt Runtime, req Request) error +} + +// DependencyMapper is called when a dependency watched via WithWatch is changed +// to determine which of the controller's managed resources need to be reconciled. +type DependencyMapper func( + ctx context.Context, + rt Runtime, + res *pbresource.Resource, +) ([]Request, error) + +// MapOwner implements a DependencyMapper that returns the updated resource's owner. +func MapOwner(_ context.Context, _ Runtime, res *pbresource.Resource) ([]Request, error) { + var reqs []Request + if res.Owner != nil { + reqs = append(reqs, Request{ID: res.Owner}) + } + return reqs, nil +} + +// Placement determines where and how many replicas of the controller will run. +type Placement int + +const ( + // PlacementSingleton ensures there is a single, leader-elected, instance of + // the controller running in the cluster at any time. It's the default and is + // suitable for most use-cases. + PlacementSingleton Placement = iota + + // PlacementEachServer ensures there is a replica of the controller running on + // each server in the cluster. It is useful for cases where the controller is + // responsible for applying some configuration resource to the server whenever + // it changes (e.g. rate-limit configuration). Generally, controllers in this + // placement mode should not modify resources. + PlacementEachServer +) + +// String satisfies the fmt.Stringer interface. +func (p Placement) String() string { + switch p { + case PlacementSingleton: + return "singleton" + case PlacementEachServer: + return "each-server" + } + panic(fmt.Sprintf("unknown placement %d", p)) +} + +// RequeueAfterError is an error that allows a Reconciler to override the +// exponential backoff behavior of the Controller, rather than applying +// the backoff algorithm, returning a RequeueAfterError will cause the +// Controller to reschedule the Request at a given time in the future. +type RequeueAfterError time.Duration + +// Error implements the error interface. +func (r RequeueAfterError) Error() string { + return fmt.Sprintf("requeue at %s", time.Duration(r)) +} + +// RequeueAfter constructs a RequeueAfterError with the given duration +// setting. +func RequeueAfter(after time.Duration) error { + return RequeueAfterError(after) +} + +// RequeueNow constructs a RequeueAfterError that reschedules the Request +// immediately. +func RequeueNow() error { + return RequeueAfterError(0) } diff --git a/internal/controller/api_test.go b/internal/controller/api_test.go new file mode 100644 index 0000000000..2006664b20 --- /dev/null +++ b/internal/controller/api_test.go @@ -0,0 +1,268 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package controller_test + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/stretchr/testify/require" + + svctest "github.com/hashicorp/consul/agent/grpc-external/services/resource/testing" + "github.com/hashicorp/consul/internal/controller" + "github.com/hashicorp/consul/internal/resource/demo" + "github.com/hashicorp/consul/proto-public/pbresource" + "github.com/hashicorp/consul/proto/private/prototest" + "github.com/hashicorp/consul/sdk/testutil" +) + +func TestController_API(t *testing.T) { + t.Parallel() + + rec := newTestReconciler() + client := svctest.RunResourceService(t, demo.RegisterTypes) + + ctrl := controller. + ForType(demo.TypeV2Artist). + WithWatch(demo.TypeV2Album, controller.MapOwner). + WithBackoff(10*time.Millisecond, 100*time.Millisecond). + WithReconciler(rec) + + mgr := controller.NewManager(client, testutil.Logger(t)) + mgr.Register(ctrl) + mgr.SetRaftLeader(true) + go mgr.Run(testContext(t)) + + t.Run("managed resource type", func(t *testing.T) { + res, err := demo.GenerateV2Artist() + require.NoError(t, err) + + rsp, err := client.Write(testContext(t), &pbresource.WriteRequest{Resource: res}) + require.NoError(t, err) + + req := rec.wait(t) + prototest.AssertDeepEqual(t, rsp.Resource.Id, req.ID) + }) + + t.Run("watched resource type", func(t *testing.T) { + res, err := demo.GenerateV2Artist() + require.NoError(t, err) + + rsp, err := client.Write(testContext(t), &pbresource.WriteRequest{Resource: res}) + require.NoError(t, err) + + req := rec.wait(t) + prototest.AssertDeepEqual(t, rsp.Resource.Id, req.ID) + + rec.expectNoRequest(t, 500*time.Millisecond) + + album, err := demo.GenerateV2Album(rsp.Resource.Id) + require.NoError(t, err) + + _, err = client.Write(testContext(t), &pbresource.WriteRequest{Resource: album}) + require.NoError(t, err) + + req = rec.wait(t) + prototest.AssertDeepEqual(t, rsp.Resource.Id, req.ID) + }) + + t.Run("error retries", func(t *testing.T) { + rec.failNext(errors.New("KABOOM")) + + res, err := demo.GenerateV2Artist() + require.NoError(t, err) + + rsp, err := client.Write(testContext(t), &pbresource.WriteRequest{Resource: res}) + require.NoError(t, err) + + req := rec.wait(t) + prototest.AssertDeepEqual(t, rsp.Resource.Id, req.ID) + + // Reconciler should be called with the same request again. + req = rec.wait(t) + prototest.AssertDeepEqual(t, rsp.Resource.Id, req.ID) + }) + + t.Run("panic retries", func(t *testing.T) { + rec.panicNext("KABOOM") + + res, err := demo.GenerateV2Artist() + require.NoError(t, err) + + rsp, err := client.Write(testContext(t), &pbresource.WriteRequest{Resource: res}) + require.NoError(t, err) + + req := rec.wait(t) + prototest.AssertDeepEqual(t, rsp.Resource.Id, req.ID) + + // Reconciler should be called with the same request again. + req = rec.wait(t) + prototest.AssertDeepEqual(t, rsp.Resource.Id, req.ID) + }) + + t.Run("defer", func(t *testing.T) { + rec.failNext(controller.RequeueAfter(1 * time.Second)) + + res, err := demo.GenerateV2Artist() + require.NoError(t, err) + + rsp, err := client.Write(testContext(t), &pbresource.WriteRequest{Resource: res}) + require.NoError(t, err) + + req := rec.wait(t) + prototest.AssertDeepEqual(t, rsp.Resource.Id, req.ID) + + rec.expectNoRequest(t, 750*time.Millisecond) + + req = rec.wait(t) + prototest.AssertDeepEqual(t, rsp.Resource.Id, req.ID) + }) +} + +func TestController_Placement(t *testing.T) { + t.Parallel() + + t.Run("singleton", func(t *testing.T) { + rec := newTestReconciler() + client := svctest.RunResourceService(t, demo.RegisterTypes) + + ctrl := controller. + ForType(demo.TypeV2Artist). + WithWatch(demo.TypeV2Album, controller.MapOwner). + WithPlacement(controller.PlacementSingleton). + WithReconciler(rec) + + mgr := controller.NewManager(client, testutil.Logger(t)) + mgr.Register(ctrl) + go mgr.Run(testContext(t)) + + res, err := demo.GenerateV2Artist() + require.NoError(t, err) + + // Reconciler should not be called until we're the Raft leader. + _, err = client.Write(testContext(t), &pbresource.WriteRequest{Resource: res}) + require.NoError(t, err) + rec.expectNoRequest(t, 500*time.Millisecond) + + // Become the leader and check the reconciler is called. + mgr.SetRaftLeader(true) + _ = rec.wait(t) + + // Should not be called after losing leadership. + mgr.SetRaftLeader(false) + _, err = client.Write(testContext(t), &pbresource.WriteRequest{Resource: res}) + require.NoError(t, err) + rec.expectNoRequest(t, 500*time.Millisecond) + }) + + t.Run("each server", func(t *testing.T) { + rec := newTestReconciler() + client := svctest.RunResourceService(t, demo.RegisterTypes) + + ctrl := controller. + ForType(demo.TypeV2Artist). + WithWatch(demo.TypeV2Album, controller.MapOwner). + WithPlacement(controller.PlacementEachServer). + WithReconciler(rec) + + mgr := controller.NewManager(client, testutil.Logger(t)) + mgr.Register(ctrl) + go mgr.Run(testContext(t)) + + res, err := demo.GenerateV2Artist() + require.NoError(t, err) + + // Reconciler should be called even though we're not the Raft leader. + _, err = client.Write(testContext(t), &pbresource.WriteRequest{Resource: res}) + require.NoError(t, err) + _ = rec.wait(t) + }) +} + +func TestController_String(t *testing.T) { + ctrl := controller. + ForType(demo.TypeV2Artist). + WithWatch(demo.TypeV2Album, controller.MapOwner). + WithBackoff(5*time.Second, 1*time.Hour). + WithPlacement(controller.PlacementEachServer) + + require.Equal(t, + `, placement="each-server">`, + ctrl.String(), + ) +} + +func TestController_NoReconciler(t *testing.T) { + client := svctest.RunResourceService(t, demo.RegisterTypes) + mgr := controller.NewManager(client, testutil.Logger(t)) + + ctrl := controller.ForType(demo.TypeV2Artist) + require.PanicsWithValue(t, + `cannot register controller without a reconciler , placement="singleton">`, + func() { mgr.Register(ctrl) }) +} + +func newTestReconciler() *testReconciler { + return &testReconciler{ + calls: make(chan controller.Request), + errors: make(chan error, 1), + panics: make(chan any, 1), + } +} + +type testReconciler struct { + calls chan controller.Request + errors chan error + panics chan any +} + +func (r *testReconciler) Reconcile(_ context.Context, _ controller.Runtime, req controller.Request) error { + r.calls <- req + + select { + case err := <-r.errors: + return err + case p := <-r.panics: + panic(p) + default: + return nil + } +} + +func (r *testReconciler) failNext(err error) { r.errors <- err } +func (r *testReconciler) panicNext(p any) { r.panics <- p } + +func (r *testReconciler) expectNoRequest(t *testing.T, duration time.Duration) { + t.Helper() + + started := time.Now() + select { + case req := <-r.calls: + t.Fatalf("expected no request for %s, but got: %s after %s", duration, req.ID, time.Since(started)) + case <-time.After(duration): + } +} + +func (r *testReconciler) wait(t *testing.T) controller.Request { + t.Helper() + + var req controller.Request + select { + case req = <-r.calls: + case <-time.After(500 * time.Millisecond): + t.Fatal("Reconcile was not called after 500ms") + } + return req +} + +func testContext(t *testing.T) context.Context { + t.Helper() + + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + return ctx +} diff --git a/internal/controller/controller.go b/internal/controller/controller.go index 11933b39fa..d99ca26f0d 100644 --- a/internal/controller/controller.go +++ b/internal/controller/controller.go @@ -1,15 +1,29 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + package controller import ( "context" + "errors" + "fmt" + "time" "github.com/hashicorp/go-hclog" + "golang.org/x/sync/errgroup" + "google.golang.org/protobuf/proto" + + "github.com/hashicorp/consul/agent/consul/controller/queue" + "github.com/hashicorp/consul/internal/resource" + "github.com/hashicorp/consul/internal/storage" + "github.com/hashicorp/consul/proto-public/pbresource" ) // controllerRunner contains the actual implementation of running a controller // including creating watches, calling the reconciler, handling retries, etc. type controllerRunner struct { ctrl Controller + client pbresource.ResourceServiceClient logger hclog.Logger } @@ -17,6 +31,155 @@ func (c *controllerRunner) run(ctx context.Context) error { c.logger.Debug("controller running") defer c.logger.Debug("controller stopping") - <-ctx.Done() - return ctx.Err() + group, groupCtx := errgroup.WithContext(ctx) + recQueue := runQueue[Request](groupCtx, c.ctrl) + + // Managed Type Events → Reconciliation Queue + group.Go(func() error { + return c.watch(groupCtx, c.ctrl.managedType, func(res *pbresource.Resource) { + recQueue.Add(Request{ID: res.Id}) + }) + }) + + for _, watch := range c.ctrl.watches { + watch := watch + mapQueue := runQueue[*pbresource.Resource](groupCtx, c.ctrl) + + // Watched Type Events → Mapper Queue + group.Go(func() error { + return c.watch(groupCtx, watch.watchedType, mapQueue.Add) + }) + + // Mapper Queue → Mapper → Reconciliation Queue + group.Go(func() error { + return c.runMapper(groupCtx, watch, mapQueue, recQueue) + }) + } + + // Reconciliation Queue → Reconciler + group.Go(func() error { + return c.runReconciler(groupCtx, recQueue) + }) + + return group.Wait() +} + +func runQueue[T queue.ItemType](ctx context.Context, ctrl Controller) queue.WorkQueue[T] { + base, max := ctrl.backoff() + return queue.RunWorkQueue[T](ctx, base, max) +} + +func (c *controllerRunner) watch(ctx context.Context, typ *pbresource.Type, add func(*pbresource.Resource)) error { + watch, err := c.client.WatchList(ctx, &pbresource.WatchListRequest{ + Type: typ, + Tenancy: &pbresource.Tenancy{ + Partition: storage.Wildcard, + PeerName: storage.Wildcard, + Namespace: storage.Wildcard, + }, + }) + if err != nil { + c.logger.Error("failed to create watch", "error", err) + return err + } + + for { + event, err := watch.Recv() + if err != nil { + c.logger.Warn("error received from watch", "error", err) + return err + } + add(event.Resource) + } +} + +func (c *controllerRunner) runMapper( + ctx context.Context, + w watch, + from queue.WorkQueue[*pbresource.Resource], + to queue.WorkQueue[Request], +) error { + logger := c.logger.With("watched_resource_type", resource.ToGVK(w.watchedType)) + + for { + res, shutdown := from.Get() + if shutdown { + return nil + } + + var reqs []Request + err := c.handlePanic(func() error { + var err error + reqs, err = w.mapper(ctx, c.runtime(), res) + return err + }) + if err != nil { + from.AddRateLimited(res) + from.Done(res) + continue + } + + for _, r := range reqs { + if !proto.Equal(r.ID.Type, c.ctrl.managedType) { + logger.Error("dependency mapper returned request for a resource of the wrong type", + "type_expected", resource.ToGVK(c.ctrl.managedType), + "type_got", resource.ToGVK(r.ID.Type), + ) + continue + } + to.Add(r) + } + + from.Forget(res) + from.Done(res) + } +} + +func (c *controllerRunner) runReconciler(ctx context.Context, queue queue.WorkQueue[Request]) error { + for { + req, shutdown := queue.Get() + if shutdown { + return nil + } + + c.logger.Trace("handling request", "request", req) + err := c.handlePanic(func() error { + return c.ctrl.reconciler.Reconcile(ctx, c.runtime(), req) + }) + if err == nil { + queue.Forget(req) + } else { + var requeueAfter RequeueAfterError + if errors.As(err, &requeueAfter) { + queue.Forget(req) + queue.AddAfter(req, time.Duration(requeueAfter)) + } else { + queue.AddRateLimited(req) + } + } + queue.Done(req) + } +} + +func (c *controllerRunner) handlePanic(fn func() error) (err error) { + defer func() { + if r := recover(); r != nil { + stack := hclog.Stacktrace() + c.logger.Error("controller panic", + "panic", r, + "stack", stack, + ) + err = fmt.Errorf("panic [recovered]: %v", r) + return + } + }() + + return fn() +} + +func (c *controllerRunner) runtime() Runtime { + return Runtime{ + Client: c.client, + Logger: c.logger, + } } diff --git a/internal/controller/doc.go b/internal/controller/doc.go new file mode 100644 index 0000000000..2895379152 --- /dev/null +++ b/internal/controller/doc.go @@ -0,0 +1,10 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +// Package controller provides an API for implementing control loops on top of +// Consul resources. It is heavily inspired by [Kubebuilder] and the Kubernetes +// [controller runtime]. +// +// [Kubebuilder]: https://github.com/kubernetes-sigs/kubebuilder +// [controller runtime]: https://github.com/kubernetes-sigs/controller-runtime +package controller diff --git a/internal/controller/lease.go b/internal/controller/lease.go index 33e284a69c..2cb00d1330 100644 --- a/internal/controller/lease.go +++ b/internal/controller/lease.go @@ -22,3 +22,8 @@ type raftLease struct { func (l *raftLease) Held() bool { return l.m.raftLeader.Load() } func (l *raftLease) Changed() <-chan struct{} { return l.ch } + +type eternalLease struct{} + +func (eternalLease) Held() bool { return true } +func (eternalLease) Changed() <-chan struct{} { return nil } diff --git a/internal/controller/manager.go b/internal/controller/manager.go index 90b9f2994b..92c5829c58 100644 --- a/internal/controller/manager.go +++ b/internal/controller/manager.go @@ -2,16 +2,19 @@ package controller import ( "context" + "fmt" "sync" "sync/atomic" "github.com/hashicorp/go-hclog" "github.com/hashicorp/consul/internal/resource" + "github.com/hashicorp/consul/proto-public/pbresource" ) // Manager is responsible for scheduling the execution of controllers. type Manager struct { + client pbresource.ResourceServiceClient logger hclog.Logger raftLeader atomic.Bool @@ -24,8 +27,11 @@ type Manager struct { // NewManager creates a Manager. logger will be used by the Manager, and as the // base logger for controllers when one is not specified using WithLogger. -func NewManager(logger hclog.Logger) *Manager { - return &Manager{logger: logger} +func NewManager(client pbresource.ResourceServiceClient, logger hclog.Logger) *Manager { + return &Manager{ + client: client, + logger: logger, + } } // Register the given controller to be executed by the Manager. Cannot be called @@ -38,6 +44,10 @@ func (m *Manager) Register(ctrl Controller) { panic("cannot register additional controllers after calling Run") } + if ctrl.reconciler == nil { + panic(fmt.Sprintf("cannot register controller without a reconciler %s", ctrl)) + } + m.controllers = append(m.controllers, ctrl) } @@ -53,11 +63,17 @@ func (m *Manager) Run(ctx context.Context) { m.running = true for _, desc := range m.controllers { + logger := desc.logger + if logger == nil { + logger = m.logger.With("managed_type", resource.ToGVK(desc.managedType)) + } + runner := &controllerRunner{ ctrl: desc, - logger: m.logger.With("managed_type", resource.ToGVK(desc.managedType)), + client: m.client, + logger: logger, } - go newSupervisor(runner.run, m.newLeaseLocked()).run(ctx) + go newSupervisor(runner.run, m.newLeaseLocked(desc)).run(ctx) } } @@ -82,7 +98,11 @@ func (m *Manager) SetRaftLeader(leader bool) { } } -func (m *Manager) newLeaseLocked() Lease { +func (m *Manager) newLeaseLocked(ctrl Controller) Lease { + if ctrl.placement == PlacementEachServer { + return eternalLease{} + } + ch := make(chan struct{}, 1) m.leaseChans = append(m.leaseChans, ch) return &raftLease{m: m, ch: ch} diff --git a/internal/resource/demo/controller.go b/internal/resource/demo/controller.go index f2172f0f85..11de1c5057 100644 --- a/internal/resource/demo/controller.go +++ b/internal/resource/demo/controller.go @@ -1,6 +1,25 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + package demo -import "github.com/hashicorp/consul/internal/controller" +import ( + "context" + "fmt" + "math/rand" + + "github.com/oklog/ulid/v2" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/proto" + + "github.com/hashicorp/consul/internal/controller" + "github.com/hashicorp/consul/internal/resource" + "github.com/hashicorp/consul/proto-public/pbresource" + pbdemov2 "github.com/hashicorp/consul/proto/private/pbdemo/v2" +) + +const statusKeyArtistController = "consul.io/artist-controller" // RegisterControllers registers controllers for the demo types. Should only be // called in dev mode. @@ -9,5 +28,158 @@ func RegisterControllers(mgr *controller.Manager) { } func artistController() controller.Controller { - return controller.ForType(TypeV2Artist) + return controller.ForType(TypeV2Artist). + WithWatch(TypeV2Album, controller.MapOwner). + WithReconciler(&artistReconciler{}) +} + +type artistReconciler struct{} + +func (r *artistReconciler) Reconcile(ctx context.Context, rt controller.Runtime, req controller.Request) error { + rsp, err := rt.Client.Read(ctx, &pbresource.ReadRequest{Id: req.ID}) + switch { + case status.Code(err) == codes.NotFound: + return nil + case err != nil: + return err + } + res := rsp.Resource + + var artist pbdemov2.Artist + if err := res.Data.UnmarshalTo(&artist); err != nil { + return err + } + conditions := []*pbresource.Condition{ + { + Type: "Accepted", + State: pbresource.Condition_STATE_TRUE, + Reason: "Accepted", + Message: fmt.Sprintf("Artist '%s' accepted", artist.Name), + }, + } + + numAlbums := 3 + if artist.Genre == pbdemov2.Genre_GENRE_BLUES { + numAlbums = 10 + } + + desiredAlbums, err := generateV2AlbumsDeterministic(res.Id, numAlbums) + if err != nil { + return err + } + + actualAlbums, err := rt.Client.List(ctx, &pbresource.ListRequest{ + Type: TypeV2Album, + Tenancy: res.Id.Tenancy, + NamePrefix: fmt.Sprintf("%s/", res.Id.Name), + }) + if err != nil { + return err + } + + writes, deletions, err := diffAlbums(desiredAlbums, actualAlbums.Resources) + if err != nil { + return err + } + for _, w := range writes { + if _, err := rt.Client.Write(ctx, &pbresource.WriteRequest{Resource: w}); err != nil { + return err + } + } + for _, d := range deletions { + if _, err := rt.Client.Delete(ctx, &pbresource.DeleteRequest{Id: d}); err != nil { + return err + } + } + + for _, want := range desiredAlbums { + var album pbdemov2.Album + if err := want.Data.UnmarshalTo(&album); err != nil { + return err + } + conditions = append(conditions, &pbresource.Condition{ + Type: "AlbumCreated", + State: pbresource.Condition_STATE_TRUE, + Reason: "AlbumCreated", + Message: fmt.Sprintf("Album '%s' created for artist '%s'", album.Title, artist.Name), + Resource: resource.Reference(want.Id, ""), + }) + } + + newStatus := &pbresource.Status{ + ObservedGeneration: res.Generation, + Conditions: conditions, + } + + if proto.Equal(res.Status[statusKeyArtistController], newStatus) { + return nil + } + + _, err = rt.Client.WriteStatus(ctx, &pbresource.WriteStatusRequest{ + Id: res.Id, + Key: statusKeyArtistController, + Status: newStatus, + }) + return err +} + +func diffAlbums(want, have []*pbresource.Resource) ([]*pbresource.Resource, []*pbresource.ID, error) { + haveMap := make(map[string]*pbresource.Resource, len(have)) + for _, r := range have { + haveMap[r.Id.Name] = r + } + + wantMap := make(map[string]struct{}, len(want)) + for _, r := range want { + wantMap[r.Id.Name] = struct{}{} + } + + writes := make([]*pbresource.Resource, 0) + for _, w := range want { + h, ok := haveMap[w.Id.Name] + if ok { + var wd, hd pbdemov2.Album + if err := w.Data.UnmarshalTo(&wd); err != nil { + return nil, nil, err + } + if err := h.Data.UnmarshalTo(&hd); err != nil { + return nil, nil, err + } + if proto.Equal(&wd, &hd) { + continue + } + } + + writes = append(writes, w) + } + + deletions := make([]*pbresource.ID, 0) + for _, h := range have { + if _, ok := wantMap[h.Id.Name]; ok { + continue + } + deletions = append(deletions, h.Id) + } + + return writes, deletions, nil +} + +func generateV2AlbumsDeterministic(artistID *pbresource.ID, count int) ([]*pbresource.Resource, error) { + uid, err := ulid.Parse(artistID.Uid) + if err != nil { + return nil, fmt.Errorf("failed to parse Uid: %w", err) + } + rand := rand.New(rand.NewSource(int64(uid.Time()))) + + albums := make([]*pbresource.Resource, count) + for i := 0; i < count; i++ { + album, err := generateV2Album(artistID, rand) + if err != nil { + return nil, err + } + // Add suffix to avoid collisions. + album.Id.Name = fmt.Sprintf("%s-%d", album.Id.Name, i) + albums[i] = album + } + return albums, nil } diff --git a/internal/resource/demo/controller_test.go b/internal/resource/demo/controller_test.go new file mode 100644 index 0000000000..8d4ee79c73 --- /dev/null +++ b/internal/resource/demo/controller_test.go @@ -0,0 +1,102 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package demo + +import ( + "testing" + + "github.com/stretchr/testify/require" + + svctest "github.com/hashicorp/consul/agent/grpc-external/services/resource/testing" + "github.com/hashicorp/consul/internal/controller" + "github.com/hashicorp/consul/proto-public/pbresource" + pbdemov2 "github.com/hashicorp/consul/proto/private/pbdemo/v2" + "github.com/hashicorp/consul/sdk/testutil" +) + +func TestArtistReconciler(t *testing.T) { + client := svctest.RunResourceService(t, RegisterTypes) + + // Seed the database with an artist. + res, err := GenerateV2Artist() + require.NoError(t, err) + + // Set the genre to BLUES to ensure there are 10 albums. + var artist pbdemov2.Artist + require.NoError(t, res.Data.UnmarshalTo(&artist)) + artist.Genre = pbdemov2.Genre_GENRE_BLUES + require.NoError(t, res.Data.MarshalFrom(&artist)) + + ctx := testutil.TestContext(t) + writeRsp, err := client.Write(ctx, &pbresource.WriteRequest{Resource: res}) + require.NoError(t, err) + + // Call the reconciler for that artist. + var rec artistReconciler + runtime := controller.Runtime{ + Client: client, + Logger: testutil.Logger(t), + } + req := controller.Request{ + ID: writeRsp.Resource.Id, + } + require.NoError(t, rec.Reconcile(ctx, runtime, req)) + + // Check the status was updated. + readRsp, err := client.Read(ctx, &pbresource.ReadRequest{Id: writeRsp.Resource.Id}) + require.NoError(t, err) + require.Contains(t, readRsp.Resource.Status, "consul.io/artist-controller") + + status := readRsp.Resource.Status["consul.io/artist-controller"] + require.Equal(t, writeRsp.Resource.Generation, status.ObservedGeneration) + require.Len(t, status.Conditions, 11) + require.Equal(t, "Accepted", status.Conditions[0].Type) + require.Equal(t, "AlbumCreated", status.Conditions[1].Type) + + // Check the albums were created. + listRsp, err := client.List(ctx, &pbresource.ListRequest{ + Type: TypeV2Album, + Tenancy: readRsp.Resource.Id.Tenancy, + }) + require.NoError(t, err) + require.Len(t, listRsp.Resources, 10) + + // Delete an album. + _, err = client.Delete(ctx, &pbresource.DeleteRequest{Id: listRsp.Resources[0].Id}) + require.NoError(t, err) + + // Call the reconciler again. + require.NoError(t, rec.Reconcile(ctx, runtime, req)) + + // Check the album was recreated. + listRsp, err = client.List(ctx, &pbresource.ListRequest{ + Type: TypeV2Album, + Tenancy: readRsp.Resource.Id.Tenancy, + }) + require.NoError(t, err) + require.Len(t, listRsp.Resources, 10) + + // Set the genre to DISCO. + readRsp, err = client.Read(ctx, &pbresource.ReadRequest{Id: writeRsp.Resource.Id}) + require.NoError(t, err) + + res = readRsp.Resource + require.NoError(t, res.Data.UnmarshalTo(&artist)) + artist.Genre = pbdemov2.Genre_GENRE_DISCO + require.NoError(t, res.Data.MarshalFrom(&artist)) + + _, err = client.Write(ctx, &pbresource.WriteRequest{Resource: res}) + require.NoError(t, err) + + // Call the reconciler again. + require.NoError(t, rec.Reconcile(ctx, runtime, req)) + + // Check there are only 3 albums now. + listRsp, err = client.List(ctx, &pbresource.ListRequest{ + Type: TypeV2Album, + Tenancy: readRsp.Resource.Id.Tenancy, + }) + require.NoError(t, err) + require.Len(t, listRsp.Resources, 3) +} diff --git a/internal/resource/demo/demo.go b/internal/resource/demo/demo.go index fde9272aeb..842b75739b 100644 --- a/internal/resource/demo/demo.go +++ b/internal/resource/demo/demo.go @@ -204,6 +204,10 @@ func GenerateV2Artist() (*pbresource.Resource, error) { // GenerateV2Album generates a random Album resource, owned by the Artist with // the given ID. func GenerateV2Album(artistID *pbresource.ID) (*pbresource.Resource, error) { + return generateV2Album(artistID, rand.New(rand.NewSource(time.Now().UnixNano()))) +} + +func generateV2Album(artistID *pbresource.ID, rand *rand.Rand) (*pbresource.Resource, error) { adjective := adjectives[rand.Intn(len(adjectives))] noun := nouns[rand.Intn(len(nouns))]