From 9c40aa729f63d7a2d4a306b568061bfb972ab903 Mon Sep 17 00:00:00 2001 From: Daniel Nephin Date: Wed, 23 Dec 2020 17:12:36 -0500 Subject: [PATCH] proxycfg: pass context around where it is needed context.Context should never be stored on a struct (as it says in the godoc) because it is easy to to end up with the wrong context when it is stored. Also see https://blog.golang.org/context-and-structs This change is also in preparation for splitting state into kind-specific handlers so that the implementation of each kind is grouped together. --- agent/proxycfg/state.go | 140 ++++++++++++++++++----------------- agent/proxycfg/state_test.go | 9 ++- 2 files changed, 76 insertions(+), 73 deletions(-) diff --git a/agent/proxycfg/state.go b/agent/proxycfg/state.go index 63269378fa..e5c94526be 100644 --- a/agent/proxycfg/state.go +++ b/agent/proxycfg/state.go @@ -67,8 +67,8 @@ type state struct { serverSNIFn ServerSNIFunc intentionDefaultAllow bool - // ctx and cancel store the context created during initWatches call - ctx context.Context + // cancel is set by Watch and called by Close to stop the goroutine started + // in Watch. cancel func() kind structs.ServiceKind @@ -183,16 +183,17 @@ func newState(ns *structs.NodeService, token string) (*state, error) { // ConfigSnapshot that contains all necessary config state. The chan is closed // when the state is Closed. func (s *state) Watch() (<-chan ConfigSnapshot, error) { - s.ctx, s.cancel = context.WithCancel(context.Background()) + var ctx context.Context + ctx, s.cancel = context.WithCancel(context.Background()) snap := s.initialConfigSnapshot() - err := s.initWatches(&snap) + err := s.initWatches(ctx, &snap) if err != nil { s.cancel() return nil, err } - go s.run(&snap) + go s.run(ctx, &snap) return s.snapCh, nil } @@ -206,16 +207,16 @@ func (s *state) Close() error { } // initWatches sets up the watches needed for the particular service -func (s *state) initWatches(snap *ConfigSnapshot) error { +func (s *state) initWatches(ctx context.Context, snap *ConfigSnapshot) error { switch s.kind { case structs.ServiceKindConnectProxy: - return s.initWatchesConnectProxy(snap) + return s.initWatchesConnectProxy(ctx, snap) case structs.ServiceKindTerminatingGateway: - return s.initWatchesTerminatingGateway() + return s.initWatchesTerminatingGateway(ctx) case structs.ServiceKindMeshGateway: - return s.initWatchesMeshGateway() + return s.initWatchesMeshGateway(ctx) case structs.ServiceKindIngressGateway: - return s.initWatchesIngressGateway() + return s.initWatchesIngressGateway(ctx) default: return fmt.Errorf("Unsupported service kind") } @@ -234,9 +235,9 @@ func (s *state) watchMeshGateway(ctx context.Context, dc string, upstreamID stri // initWatchesConnectProxy sets up the watches needed based on current proxy registration // state. -func (s *state) initWatchesConnectProxy(snap *ConfigSnapshot) error { +func (s *state) initWatchesConnectProxy(ctx context.Context, snap *ConfigSnapshot) error { // Watch for root changes - err := s.cache.Notify(s.ctx, cachetype.ConnectCARootName, &structs.DCSpecificRequest{ + err := s.cache.Notify(ctx, cachetype.ConnectCARootName, &structs.DCSpecificRequest{ Datacenter: s.source.Datacenter, QueryOptions: structs.QueryOptions{Token: s.token}, Source: *s.source, @@ -246,7 +247,7 @@ func (s *state) initWatchesConnectProxy(snap *ConfigSnapshot) error { } // Watch the leaf cert - err = s.cache.Notify(s.ctx, cachetype.ConnectCALeafName, &cachetype.ConnectCALeafRequest{ + err = s.cache.Notify(ctx, cachetype.ConnectCALeafName, &cachetype.ConnectCALeafRequest{ Datacenter: s.source.Datacenter, Token: s.token, Service: s.proxyCfg.DestinationServiceName, @@ -257,7 +258,7 @@ func (s *state) initWatchesConnectProxy(snap *ConfigSnapshot) error { } // Watch for intention updates - err = s.cache.Notify(s.ctx, cachetype.IntentionMatchName, &structs.IntentionQueryRequest{ + err = s.cache.Notify(ctx, cachetype.IntentionMatchName, &structs.IntentionQueryRequest{ Datacenter: s.source.Datacenter, QueryOptions: structs.QueryOptions{Token: s.token}, Match: &structs.IntentionQueryMatch{ @@ -275,7 +276,7 @@ func (s *state) initWatchesConnectProxy(snap *ConfigSnapshot) error { } // Watch for service check updates - err = s.cache.Notify(s.ctx, cachetype.ServiceHTTPChecksName, &cachetype.ServiceHTTPChecksRequest{ + err = s.cache.Notify(ctx, cachetype.ServiceHTTPChecksName, &cachetype.ServiceHTTPChecksRequest{ ServiceID: s.proxyCfg.DestinationServiceID, EnterpriseMeta: s.proxyID.EnterpriseMeta, }, svcChecksWatchIDPrefix+structs.ServiceIDString(s.proxyCfg.DestinationServiceID, &s.proxyID.EnterpriseMeta), s.ch) @@ -288,7 +289,7 @@ func (s *state) initWatchesConnectProxy(snap *ConfigSnapshot) error { if s.proxyCfg.Mode == structs.ProxyModeTransparent { // When in transparent proxy we will infer upstreams from intentions with this source - err := s.cache.Notify(s.ctx, cachetype.IntentionUpstreamsName, &structs.ServiceSpecificRequest{ + err := s.cache.Notify(ctx, cachetype.IntentionUpstreamsName, &structs.ServiceSpecificRequest{ Datacenter: s.source.Datacenter, QueryOptions: structs.QueryOptions{Token: s.token}, ServiceName: s.proxyCfg.DestinationServiceName, @@ -298,7 +299,7 @@ func (s *state) initWatchesConnectProxy(snap *ConfigSnapshot) error { return err } - err = s.cache.Notify(s.ctx, cachetype.ConfigEntryName, &structs.ConfigEntryQuery{ + err = s.cache.Notify(ctx, cachetype.ConfigEntryName, &structs.ConfigEntryQuery{ Kind: structs.MeshConfig, Name: structs.MeshConfigMesh, Datacenter: s.source.Datacenter, @@ -354,7 +355,7 @@ func (s *state) initWatchesConnectProxy(snap *ConfigSnapshot) error { switch u.DestinationType { case structs.UpstreamDestTypePreparedQuery: - err = s.cache.Notify(s.ctx, cachetype.PreparedQueryName, &structs.PreparedQueryExecuteRequest{ + err = s.cache.Notify(ctx, cachetype.PreparedQueryName, &structs.PreparedQueryExecuteRequest{ Datacenter: dc, QueryOptions: structs.QueryOptions{Token: s.token, MaxAge: defaultPreparedQueryPollInterval}, QueryIDOrName: u.DestinationName, @@ -369,7 +370,7 @@ func (s *state) initWatchesConnectProxy(snap *ConfigSnapshot) error { fallthrough case "": // Treat unset as the default Service type - err = s.cache.Notify(s.ctx, cachetype.CompiledDiscoveryChainName, &structs.DiscoveryChainRequest{ + err = s.cache.Notify(ctx, cachetype.CompiledDiscoveryChainName, &structs.DiscoveryChainRequest{ Datacenter: s.source.Datacenter, QueryOptions: structs.QueryOptions{Token: s.token}, Name: u.DestinationName, @@ -411,9 +412,9 @@ func parseReducedUpstreamConfig(m map[string]interface{}) (reducedUpstreamConfig } // initWatchesTerminatingGateway sets up the initial watches needed based on the terminating-gateway registration -func (s *state) initWatchesTerminatingGateway() error { +func (s *state) initWatchesTerminatingGateway(ctx context.Context) error { // Watch for root changes - err := s.cache.Notify(s.ctx, cachetype.ConnectCARootName, &structs.DCSpecificRequest{ + err := s.cache.Notify(ctx, cachetype.ConnectCARootName, &structs.DCSpecificRequest{ Datacenter: s.source.Datacenter, QueryOptions: structs.QueryOptions{Token: s.token}, Source: *s.source, @@ -425,7 +426,7 @@ func (s *state) initWatchesTerminatingGateway() error { } // Watch for the terminating-gateway's linked services - err = s.cache.Notify(s.ctx, cachetype.GatewayServicesName, &structs.ServiceSpecificRequest{ + err = s.cache.Notify(ctx, cachetype.GatewayServicesName, &structs.ServiceSpecificRequest{ Datacenter: s.source.Datacenter, QueryOptions: structs.QueryOptions{Token: s.token}, ServiceName: s.service, @@ -441,9 +442,9 @@ func (s *state) initWatchesTerminatingGateway() error { } // initWatchesMeshGateway sets up the watches needed based on the current mesh gateway registration -func (s *state) initWatchesMeshGateway() error { +func (s *state) initWatchesMeshGateway(ctx context.Context) error { // Watch for root changes - err := s.cache.Notify(s.ctx, cachetype.ConnectCARootName, &structs.DCSpecificRequest{ + err := s.cache.Notify(ctx, cachetype.ConnectCARootName, &structs.DCSpecificRequest{ Datacenter: s.source.Datacenter, QueryOptions: structs.QueryOptions{Token: s.token}, Source: *s.source, @@ -453,7 +454,7 @@ func (s *state) initWatchesMeshGateway() error { } // Watch for all services - err = s.cache.Notify(s.ctx, cachetype.CatalogServiceListName, &structs.DCSpecificRequest{ + err = s.cache.Notify(ctx, cachetype.CatalogServiceListName, &structs.DCSpecificRequest{ Datacenter: s.source.Datacenter, QueryOptions: structs.QueryOptions{Token: s.token}, Source: *s.source, @@ -468,7 +469,7 @@ func (s *state) initWatchesMeshGateway() error { // Conveniently we can just use this service meta attribute in one // place here to set the machinery in motion and leave the conditional // behavior out of the rest of the package. - err = s.cache.Notify(s.ctx, cachetype.FederationStateListMeshGatewaysName, &structs.DCSpecificRequest{ + err = s.cache.Notify(ctx, cachetype.FederationStateListMeshGatewaysName, &structs.DCSpecificRequest{ Datacenter: s.source.Datacenter, QueryOptions: structs.QueryOptions{Token: s.token}, Source: *s.source, @@ -477,7 +478,7 @@ func (s *state) initWatchesMeshGateway() error { return err } - err = s.health.Notify(s.ctx, structs.ServiceSpecificRequest{ + err = s.health.Notify(ctx, structs.ServiceSpecificRequest{ Datacenter: s.source.Datacenter, QueryOptions: structs.QueryOptions{Token: s.token}, ServiceName: structs.ConsulServiceName, @@ -492,7 +493,7 @@ func (s *state) initWatchesMeshGateway() error { // cannot setup those watches until we know what the services are. from the service list // watch above - err = s.cache.Notify(s.ctx, cachetype.CatalogDatacentersName, &structs.DatacentersRequest{ + err = s.cache.Notify(ctx, cachetype.CatalogDatacentersName, &structs.DatacentersRequest{ QueryOptions: structs.QueryOptions{Token: s.token, MaxAge: 30 * time.Second}, }, datacentersWatchID, s.ch) if err != nil { @@ -504,7 +505,7 @@ func (s *state) initWatchesMeshGateway() error { // know what they are yet. // Watch service-resolvers so we can setup service subset clusters - err = s.cache.Notify(s.ctx, cachetype.ConfigEntriesName, &structs.ConfigEntryQuery{ + err = s.cache.Notify(ctx, cachetype.ConfigEntriesName, &structs.ConfigEntryQuery{ Datacenter: s.source.Datacenter, QueryOptions: structs.QueryOptions{Token: s.token}, Kind: structs.ServiceResolver, @@ -520,9 +521,9 @@ func (s *state) initWatchesMeshGateway() error { return err } -func (s *state) initWatchesIngressGateway() error { +func (s *state) initWatchesIngressGateway(ctx context.Context) error { // Watch for root changes - err := s.cache.Notify(s.ctx, cachetype.ConnectCARootName, &structs.DCSpecificRequest{ + err := s.cache.Notify(ctx, cachetype.ConnectCARootName, &structs.DCSpecificRequest{ Datacenter: s.source.Datacenter, QueryOptions: structs.QueryOptions{Token: s.token}, Source: *s.source, @@ -532,7 +533,7 @@ func (s *state) initWatchesIngressGateway() error { } // Watch this ingress gateway's config entry - err = s.cache.Notify(s.ctx, cachetype.ConfigEntryName, &structs.ConfigEntryQuery{ + err = s.cache.Notify(ctx, cachetype.ConfigEntryName, &structs.ConfigEntryQuery{ Kind: structs.IngressGateway, Name: s.service, Datacenter: s.source.Datacenter, @@ -544,7 +545,7 @@ func (s *state) initWatchesIngressGateway() error { } // Watch the ingress-gateway's list of upstreams - err = s.cache.Notify(s.ctx, cachetype.GatewayServicesName, &structs.ServiceSpecificRequest{ + err = s.cache.Notify(ctx, cachetype.GatewayServicesName, &structs.ServiceSpecificRequest{ Datacenter: s.source.Datacenter, QueryOptions: structs.QueryOptions{Token: s.token}, ServiceName: s.service, @@ -619,7 +620,7 @@ func (s *state) initialConfigSnapshot() ConfigSnapshot { return snap } -func (s *state) run(snap *ConfigSnapshot) { +func (s *state) run(ctx context.Context, snap *ConfigSnapshot) { // Close the channel we return from Watch when we stop so consumers can stop // watching and clean up their goroutines. It's important we do this here and // not in Close since this routine sends on this chan and so might panic if it @@ -635,12 +636,12 @@ func (s *state) run(snap *ConfigSnapshot) { for { select { - case <-s.ctx.Done(): + case <-ctx.Done(): return case u := <-s.ch: s.logger.Trace("A blocking query returned; handling snapshot update") - if err := s.handleUpdate(u, snap); err != nil { + if err := s.handleUpdate(ctx, u, snap); err != nil { s.logger.Error("Failed to handle update from watch", "id", u.CorrelationID, "error", err, ) @@ -729,22 +730,22 @@ func (s *state) run(snap *ConfigSnapshot) { } } -func (s *state) handleUpdate(u cache.UpdateEvent, snap *ConfigSnapshot) error { +func (s *state) handleUpdate(ctx context.Context, u cache.UpdateEvent, snap *ConfigSnapshot) error { switch s.kind { case structs.ServiceKindConnectProxy: - return s.handleUpdateConnectProxy(u, snap) + return s.handleUpdateConnectProxy(ctx, u, snap) case structs.ServiceKindTerminatingGateway: - return s.handleUpdateTerminatingGateway(u, snap) + return s.handleUpdateTerminatingGateway(ctx, u, snap) case structs.ServiceKindMeshGateway: - return s.handleUpdateMeshGateway(u, snap) + return s.handleUpdateMeshGateway(ctx, u, snap) case structs.ServiceKindIngressGateway: - return s.handleUpdateIngressGateway(u, snap) + return s.handleUpdateIngressGateway(ctx, u, snap) default: return fmt.Errorf("Unsupported service kind") } } -func (s *state) handleUpdateConnectProxy(u cache.UpdateEvent, snap *ConfigSnapshot) error { +func (s *state) handleUpdateConnectProxy(ctx context.Context, u cache.UpdateEvent, snap *ConfigSnapshot) error { if u.Err != nil { return fmt.Errorf("error filling agent cache: %v", u.Err) } @@ -818,7 +819,7 @@ func (s *state) handleUpdateConnectProxy(u cache.UpdateEvent, snap *ConfigSnapsh cfg: cfg, meshGateway: meshGateway, } - err = s.watchDiscoveryChain(snap, watchOpts) + err = s.watchDiscoveryChain(ctx, snap, watchOpts) if err != nil { return fmt.Errorf("failed to watch discovery chain for %s: %v", svc.String(), err) } @@ -887,12 +888,12 @@ func (s *state) handleUpdateConnectProxy(u cache.UpdateEvent, snap *ConfigSnapsh snap.ConnectProxy.MeshConfigSet = true default: - return s.handleUpdateUpstreams(u, snap) + return s.handleUpdateUpstreams(ctx, u, snap) } return nil } -func (s *state) handleUpdateUpstreams(u cache.UpdateEvent, snap *ConfigSnapshot) error { +func (s *state) handleUpdateUpstreams(ctx context.Context, u cache.UpdateEvent, snap *ConfigSnapshot) error { if u.Err != nil { return fmt.Errorf("error filling agent cache: %v", u.Err) } @@ -918,7 +919,7 @@ func (s *state) handleUpdateUpstreams(u cache.UpdateEvent, snap *ConfigSnapshot) svc := strings.TrimPrefix(u.CorrelationID, "discovery-chain:") upstreamsSnapshot.DiscoveryChain[svc] = resp.Chain - if err := s.resetWatchesFromChain(svc, resp.Chain, upstreamsSnapshot); err != nil { + if err := s.resetWatchesFromChain(ctx, svc, resp.Chain, upstreamsSnapshot); err != nil { return err } @@ -1010,6 +1011,7 @@ func removeColonPrefix(s string) (string, string, bool) { } func (s *state) resetWatchesFromChain( + ctx context.Context, id string, chain *structs.CompiledDiscoveryChain, snap *ConfigSnapshotUpstreams, @@ -1068,7 +1070,7 @@ func (s *state) resetWatchesFromChain( datacenter: target.Datacenter, entMeta: target.GetEnterpriseMetadata(), } - err := s.watchUpstreamTarget(snap, opts) + err := s.watchUpstreamTarget(ctx, snap, opts) if err != nil { return fmt.Errorf("failed to watch target %q for upstream %q", target.ID, id) } @@ -1102,7 +1104,7 @@ func (s *state) resetWatchesFromChain( datacenter: chain.Datacenter, entMeta: &chainEntMeta, } - err := s.watchUpstreamTarget(snap, opts) + err := s.watchUpstreamTarget(ctx, snap, opts) if err != nil { return fmt.Errorf("failed to watch target %q for upstream %q", chainID, id) } @@ -1119,7 +1121,7 @@ func (s *state) resetWatchesFromChain( "datacenter", dc, ) - ctx, cancel := context.WithCancel(s.ctx) + ctx, cancel := context.WithCancel(ctx) err := s.watchMeshGateway(ctx, dc, id) if err != nil { cancel() @@ -1155,7 +1157,7 @@ type targetWatchOpts struct { entMeta *structs.EnterpriseMeta } -func (s *state) watchUpstreamTarget(snap *ConfigSnapshotUpstreams, opts targetWatchOpts) error { +func (s *state) watchUpstreamTarget(ctx context.Context, snap *ConfigSnapshotUpstreams, opts targetWatchOpts) error { s.logger.Trace("initializing watch of target", "upstream", opts.upstreamID, "chain", opts.service, @@ -1167,7 +1169,7 @@ func (s *state) watchUpstreamTarget(snap *ConfigSnapshotUpstreams, opts targetWa correlationID := "upstream-target:" + opts.chainID + ":" + opts.upstreamID - ctx, cancel := context.WithCancel(s.ctx) + ctx, cancel := context.WithCancel(ctx) err := s.health.Notify(ctx, structs.ServiceSpecificRequest{ Datacenter: opts.datacenter, QueryOptions: structs.QueryOptions{ @@ -1192,7 +1194,7 @@ func (s *state) watchUpstreamTarget(snap *ConfigSnapshotUpstreams, opts targetWa return nil } -func (s *state) handleUpdateTerminatingGateway(u cache.UpdateEvent, snap *ConfigSnapshot) error { +func (s *state) handleUpdateTerminatingGateway(ctx context.Context, u cache.UpdateEvent, snap *ConfigSnapshot) error { if u.Err != nil { return fmt.Errorf("error filling agent cache: %v", u.Err) } @@ -1223,7 +1225,7 @@ func (s *state) handleUpdateTerminatingGateway(u cache.UpdateEvent, snap *Config // Watch the health endpoint to discover endpoints for the service if _, ok := snap.TerminatingGateway.WatchedServices[svc.Service]; !ok { - ctx, cancel := context.WithCancel(s.ctx) + ctx, cancel := context.WithCancel(ctx) err := s.health.Notify(ctx, structs.ServiceSpecificRequest{ Datacenter: s.source.Datacenter, QueryOptions: structs.QueryOptions{Token: s.token}, @@ -1248,7 +1250,7 @@ func (s *state) handleUpdateTerminatingGateway(u cache.UpdateEvent, snap *Config // Watch intentions with this service as their destination // The gateway will enforce intentions for connections to the service if _, ok := snap.TerminatingGateway.WatchedIntentions[svc.Service]; !ok { - ctx, cancel := context.WithCancel(s.ctx) + ctx, cancel := context.WithCancel(ctx) err := s.cache.Notify(ctx, cachetype.IntentionMatchName, &structs.IntentionQueryRequest{ Datacenter: s.source.Datacenter, QueryOptions: structs.QueryOptions{Token: s.token}, @@ -1277,7 +1279,7 @@ func (s *state) handleUpdateTerminatingGateway(u cache.UpdateEvent, snap *Config // Watch leaf certificate for the service // This cert is used to terminate mTLS connections on the service's behalf if _, ok := snap.TerminatingGateway.WatchedLeaves[svc.Service]; !ok { - ctx, cancel := context.WithCancel(s.ctx) + ctx, cancel := context.WithCancel(ctx) err := s.cache.Notify(ctx, cachetype.ConnectCALeafName, &cachetype.ConnectCALeafRequest{ Datacenter: s.source.Datacenter, Token: s.token, @@ -1299,7 +1301,7 @@ func (s *state) handleUpdateTerminatingGateway(u cache.UpdateEvent, snap *Config // Watch service configs for the service. // These are used to determine the protocol for the target service. if _, ok := snap.TerminatingGateway.WatchedConfigs[svc.Service]; !ok { - ctx, cancel := context.WithCancel(s.ctx) + ctx, cancel := context.WithCancel(ctx) err := s.cache.Notify(ctx, cachetype.ResolvedServiceConfigName, &structs.ServiceConfigRequest{ Datacenter: s.source.Datacenter, QueryOptions: structs.QueryOptions{Token: s.token}, @@ -1321,7 +1323,7 @@ func (s *state) handleUpdateTerminatingGateway(u cache.UpdateEvent, snap *Config // Watch service resolvers for the service // These are used to create clusters and endpoints for the service subsets if _, ok := snap.TerminatingGateway.WatchedResolvers[svc.Service]; !ok { - ctx, cancel := context.WithCancel(s.ctx) + ctx, cancel := context.WithCancel(ctx) err := s.cache.Notify(ctx, cachetype.ConfigEntriesName, &structs.ConfigEntryQuery{ Datacenter: s.source.Datacenter, QueryOptions: structs.QueryOptions{Token: s.token}, @@ -1477,7 +1479,7 @@ func (s *state) handleUpdateTerminatingGateway(u cache.UpdateEvent, snap *Config return nil } -func (s *state) handleUpdateMeshGateway(u cache.UpdateEvent, snap *ConfigSnapshot) error { +func (s *state) handleUpdateMeshGateway(ctx context.Context, u cache.UpdateEvent, snap *ConfigSnapshot) error { if u.Err != nil { return fmt.Errorf("error filling agent cache: %v", u.Err) } @@ -1521,7 +1523,7 @@ func (s *state) handleUpdateMeshGateway(u cache.UpdateEvent, snap *ConfigSnapsho svcMap[svc] = struct{}{} if _, ok := snap.MeshGateway.WatchedServices[svc]; !ok { - ctx, cancel := context.WithCancel(s.ctx) + ctx, cancel := context.WithCancel(ctx) err := s.health.Notify(ctx, structs.ServiceSpecificRequest{ Datacenter: s.source.Datacenter, QueryOptions: structs.QueryOptions{Token: s.token}, @@ -1570,7 +1572,7 @@ func (s *state) handleUpdateMeshGateway(u cache.UpdateEvent, snap *ConfigSnapsho } if _, ok := snap.MeshGateway.WatchedDatacenters[dc]; !ok { - ctx, cancel := context.WithCancel(s.ctx) + ctx, cancel := context.WithCancel(ctx) err := s.cache.Notify(ctx, cachetype.InternalServiceDumpName, &structs.ServiceDumpRequest{ Datacenter: dc, QueryOptions: structs.QueryOptions{Token: s.token}, @@ -1678,7 +1680,7 @@ func (s *state) handleUpdateMeshGateway(u cache.UpdateEvent, snap *ConfigSnapsho return nil } -func (s *state) handleUpdateIngressGateway(u cache.UpdateEvent, snap *ConfigSnapshot) error { +func (s *state) handleUpdateIngressGateway(ctx context.Context, u cache.UpdateEvent, snap *ConfigSnapshot) error { if u.Err != nil { return fmt.Errorf("error filling agent cache: %v", u.Err) } @@ -1703,7 +1705,7 @@ func (s *state) handleUpdateIngressGateway(u cache.UpdateEvent, snap *ConfigSnap snap.IngressGateway.TLSEnabled = gatewayConf.TLS.Enabled snap.IngressGateway.TLSSet = true - if err := s.watchIngressLeafCert(snap); err != nil { + if err := s.watchIngressLeafCert(ctx, snap); err != nil { return err } @@ -1725,7 +1727,7 @@ func (s *state) handleUpdateIngressGateway(u cache.UpdateEvent, snap *ConfigSnap name: u.DestinationName, namespace: u.DestinationNamespace, } - err := s.watchDiscoveryChain(snap, watchOpts) + err := s.watchDiscoveryChain(ctx, snap, watchOpts) if err != nil { return fmt.Errorf("failed to watch discovery chain for %s: %v", u.Identifier(), err) } @@ -1748,12 +1750,12 @@ func (s *state) handleUpdateIngressGateway(u cache.UpdateEvent, snap *ConfigSnap } } - if err := s.watchIngressLeafCert(snap); err != nil { + if err := s.watchIngressLeafCert(ctx, snap); err != nil { return err } default: - return s.handleUpdateUpstreams(u, snap) + return s.handleUpdateUpstreams(ctx, u, snap) } return nil @@ -1785,12 +1787,12 @@ type discoveryChainWatchOpts struct { meshGateway structs.MeshGatewayConfig } -func (s *state) watchDiscoveryChain(snap *ConfigSnapshot, opts discoveryChainWatchOpts) error { +func (s *state) watchDiscoveryChain(ctx context.Context, snap *ConfigSnapshot, opts discoveryChainWatchOpts) error { if _, ok := snap.ConnectProxy.WatchedDiscoveryChains[opts.id]; ok { return nil } - ctx, cancel := context.WithCancel(s.ctx) + ctx, cancel := context.WithCancel(ctx) err := s.cache.Notify(ctx, cachetype.CompiledDiscoveryChainName, &structs.DiscoveryChainRequest{ Datacenter: s.source.Datacenter, QueryOptions: structs.QueryOptions{Token: s.token}, @@ -1856,7 +1858,7 @@ func (s *state) generateIngressDNSSANs(snap *ConfigSnapshot) []string { return dnsNames } -func (s *state) watchIngressLeafCert(snap *ConfigSnapshot) error { +func (s *state) watchIngressLeafCert(ctx context.Context, snap *ConfigSnapshot) error { if !snap.IngressGateway.TLSSet || !snap.IngressGateway.HostsSet { return nil } @@ -1865,7 +1867,7 @@ func (s *state) watchIngressLeafCert(snap *ConfigSnapshot) error { if snap.IngressGateway.LeafCertWatchCancel != nil { snap.IngressGateway.LeafCertWatchCancel() } - ctx, cancel := context.WithCancel(s.ctx) + ctx, cancel := context.WithCancel(ctx) err := s.cache.Notify(ctx, cachetype.ConnectCALeafName, &cachetype.ConnectCALeafRequest{ Datacenter: s.source.Datacenter, Token: s.token, diff --git a/agent/proxycfg/state_test.go b/agent/proxycfg/state_test.go index 22e99477e5..ba3806d3b7 100644 --- a/agent/proxycfg/state_test.go +++ b/agent/proxycfg/state_test.go @@ -3,7 +3,6 @@ package proxycfg import ( "context" "fmt" - "github.com/hashicorp/consul/agent/connect" "sync" "testing" "time" @@ -12,6 +11,7 @@ import ( "github.com/hashicorp/consul/agent/cache" cachetype "github.com/hashicorp/consul/agent/cache-types" + "github.com/hashicorp/consul/agent/connect" "github.com/hashicorp/consul/agent/consul/discoverychain" "github.com/hashicorp/consul/agent/rpcclient/health" "github.com/hashicorp/consul/agent/structs" @@ -1973,13 +1973,14 @@ func TestState_WatchesAndUpdates(t *testing.T) { } // setup the ctx as initWatches expects this to be there - state.ctx, state.cancel = context.WithCancel(context.Background()) + var ctx context.Context + ctx, state.cancel = context.WithCancel(context.Background()) // get the initial configuration snapshot snap := state.initialConfigSnapshot() // ensure the initial watch setup did not error - require.NoError(t, state.initWatches(&snap)) + require.NoError(t, state.initWatches(ctx, &snap)) //-------------------------------------------------------------------- // @@ -2006,7 +2007,7 @@ func TestState_WatchesAndUpdates(t *testing.T) { // therefore we just tell it about the updates for eveIdx, event := range stage.events { require.True(t, t.Run(fmt.Sprintf("update-%d", eveIdx), func(t *testing.T) { - require.NoError(t, state.handleUpdate(event, &snap)) + require.NoError(t, state.handleUpdate(ctx, event, &snap)) })) }