From 03ab7367a61dabfd03b2e473986818695891444d Mon Sep 17 00:00:00 2001 From: Dan Stough Date: Mon, 22 Apr 2024 14:30:43 -0400 Subject: [PATCH] feat(dataplane): allow token and tenancy information for proxied DNS (#20899) * feat(dataplane): allow token and tenancy information for proxied DNS * changelog --- .changelog/20899.txt | 4 + agent/discovery/query_fetcher_v1.go | 24 ++- agent/discovery/query_fetcher_v1_test.go | 1 + agent/discovery/query_fetcher_v2.go | 1 + agent/dns/context.go | 56 ++++++ agent/dns/context_test.go | 70 ++++++++ agent/dns/discovery_results_fetcher.go | 14 +- agent/dns/discovery_results_fetcher_test.go | 14 +- agent/dns/router.go | 50 +++--- agent/dns/router_test.go | 167 ++++++++++++++++-- agent/grpc-external/services/dns/server_v2.go | 7 +- .../services/dns/server_v2_test.go | 36 +++- 12 files changed, 376 insertions(+), 68 deletions(-) create mode 100644 .changelog/20899.txt create mode 100644 agent/dns/context.go create mode 100644 agent/dns/context_test.go diff --git a/.changelog/20899.txt b/.changelog/20899.txt new file mode 100644 index 0000000000..3823a7514b --- /dev/null +++ b/.changelog/20899.txt @@ -0,0 +1,4 @@ +```release-note:improvement +dns: DNS-over-grpc when using `consul-dataplane` now accepts partition, namespace, token as metadata to default those query parameters. +`consul-dataplane` v1.5+ will send this information automatically. +``` diff --git a/agent/discovery/query_fetcher_v1.go b/agent/discovery/query_fetcher_v1.go index 87c50f93ad..dc897f7728 100644 --- a/agent/discovery/query_fetcher_v1.go +++ b/agent/discovery/query_fetcher_v1.go @@ -116,11 +116,17 @@ func (f *V1DataFetcher) FetchNodes(ctx Context, req *QueryPayload) ([]*Result, e // Nodes are not namespaced, so this is a name error return nil, ErrNotFound } - cfg := f.dynamicConfig.Load().(*V1DataFetcherDynamicConfig) + + // If no datacenter is passed, default to our own + datacenter := cfg.Datacenter + if req.Tenancy.Datacenter != "" { + datacenter = req.Tenancy.Datacenter + } + // Make an RPC request args := &structs.NodeSpecificRequest{ - Datacenter: req.Tenancy.Datacenter, + Datacenter: datacenter, PeerName: req.Tenancy.Peer, Node: req.Name, QueryOptions: structs.QueryOptions{ @@ -299,9 +305,15 @@ func (f *V1DataFetcher) FetchWorkload(ctx Context, req *QueryPayload) (*Result, func (f *V1DataFetcher) FetchPreparedQuery(ctx Context, req *QueryPayload) ([]*Result, error) { cfg := f.dynamicConfig.Load().(*V1DataFetcherDynamicConfig) + // If no datacenter is passed, default to our own + datacenter := cfg.Datacenter + if req.Tenancy.Datacenter != "" { + datacenter = req.Tenancy.Datacenter + } + // Execute the prepared query. args := structs.PreparedQueryExecuteRequest{ - Datacenter: req.Tenancy.Datacenter, + Datacenter: datacenter, QueryIDOrName: req.Name, QueryOptions: structs.QueryOptions{ Token: ctx.Token, @@ -548,7 +560,11 @@ func (f *V1DataFetcher) fetchServiceBasedOnTenancy(ctx Context, req *QueryPayloa return nil, errors.New("sameness groups are not allowed for service lookups based on tenancy") } - datacenter := req.Tenancy.Datacenter + // If no datacenter is passed, default to our own + datacenter := cfg.Datacenter + if req.Tenancy.Datacenter != "" { + datacenter = req.Tenancy.Datacenter + } if req.Tenancy.Peer != "" { datacenter = "" } diff --git a/agent/discovery/query_fetcher_v1_test.go b/agent/discovery/query_fetcher_v1_test.go index a587bc74ff..450b0cb13a 100644 --- a/agent/discovery/query_fetcher_v1_test.go +++ b/agent/discovery/query_fetcher_v1_test.go @@ -128,6 +128,7 @@ func Test_FetchVirtualIP(t *testing.T) { func Test_FetchEndpoints(t *testing.T) { // set these to confirm that RPC call does not use them for this particular RPC rc := &config.RuntimeConfig{ + Datacenter: "dc2", DNSAllowStale: true, DNSMaxStale: 100, DNSUseCache: true, diff --git a/agent/discovery/query_fetcher_v2.go b/agent/discovery/query_fetcher_v2.go index ac474811fa..dc870e76ad 100644 --- a/agent/discovery/query_fetcher_v2.go +++ b/agent/discovery/query_fetcher_v2.go @@ -347,6 +347,7 @@ func queryTenancyToResourceTenancy(qTenancy QueryTenancy) *pbresource.Tenancy { rTenancy.Namespace = qTenancy.Namespace } // In the case of partition, we have the agent's partition as the fallback. + // That is handled in NormalizeRequest. if qTenancy.Partition != "" { rTenancy.Partition = qTenancy.Partition } diff --git a/agent/dns/context.go b/agent/dns/context.go new file mode 100644 index 0000000000..2054d6316a --- /dev/null +++ b/agent/dns/context.go @@ -0,0 +1,56 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package dns + +import ( + "context" + "fmt" + + "github.com/mitchellh/mapstructure" + "google.golang.org/grpc/metadata" +) + +// Context is used augment a DNS message with Consul-specific metadata. +type Context struct { + Token string `mapstructure:"x-consul-token,omitempty"` + DefaultNamespace string `mapstructure:"x-consul-namespace,omitempty"` + DefaultPartition string `mapstructure:"x-consul-partition,omitempty"` +} + +// NewContextFromGRPCContext returns the request context using the gRPC metadata attached to the +// given context. If there is no gRPC metadata, it returns an empty context. +func NewContextFromGRPCContext(ctx context.Context) (Context, error) { + if ctx == nil { + return Context{}, nil + } + + reqCtx := Context{} + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return reqCtx, nil + } + + m := map[string]string{} + for k, v := range md { + m[k] = v[0] + } + + decoderConfig := &mapstructure.DecoderConfig{ + Metadata: nil, + Result: &reqCtx, + WeaklyTypedInput: true, + } + + decoder, err := mapstructure.NewDecoder(decoderConfig) + if err != nil { + return Context{}, fmt.Errorf("error creating mapstructure decoder: %w", err) + } + + err = decoder.Decode(m) + if err != nil { + return Context{}, fmt.Errorf("error decoding metadata: %w", err) + } + + return reqCtx, nil +} diff --git a/agent/dns/context_test.go b/agent/dns/context_test.go new file mode 100644 index 0000000000..44ad9055d0 --- /dev/null +++ b/agent/dns/context_test.go @@ -0,0 +1,70 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package dns + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + "google.golang.org/grpc/metadata" +) + +func TestNewContextFromGRPCContext(t *testing.T) { + t.Parallel() + + md := metadata.MD{} + testMeta := map[string]string{ + "x-consul-token": "test-token", + "x-consul-namespace": "test-namespace", + "x-consul-partition": "test-partition", + } + + for k, v := range testMeta { + md.Set(k, v) + } + testGRPCContext := metadata.NewIncomingContext(context.Background(), md) + + testCases := []struct { + name string + grpcCtx context.Context + expected *Context + error error + }{ + { + name: "nil grpc context", + grpcCtx: nil, + expected: &Context{}, + }, + { + name: "grpc context w/o metadata", + grpcCtx: context.Background(), + expected: &Context{}, + }, + { + name: "grpc context w/ kitchen sink", + grpcCtx: testGRPCContext, + expected: &Context{ + Token: "test-token", + DefaultNamespace: "test-namespace", + DefaultPartition: "test-partition", + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ctx, err := NewContextFromGRPCContext(tc.grpcCtx) + if tc.error != nil { + require.Error(t, err) + require.Equal(t, Context{}, &ctx) + require.Equal(t, tc.error, err) + return + } + + require.NotNil(t, ctx) + require.Equal(t, tc.expected, &ctx) + }) + } +} diff --git a/agent/dns/discovery_results_fetcher.go b/agent/dns/discovery_results_fetcher.go index f68f24865c..0a3b70a4bd 100644 --- a/agent/dns/discovery_results_fetcher.go +++ b/agent/dns/discovery_results_fetcher.go @@ -206,20 +206,24 @@ func getQueryTenancy(reqCtx Context, queryType discovery.QueryType, querySuffixe return discovery.QueryTenancy{}, errNameNotFound } - // If we don't have an explicit partition in the request, try the first fallback + // If we don't have an explicit partition/ns in the request, try the first fallback // which was supplied in the request context. The agent's partition will be used as the last fallback // later in the query processor. if labels.Partition == "" { labels.Partition = reqCtx.DefaultPartition } + if labels.Namespace == "" { + labels.Namespace = reqCtx.DefaultNamespace + } + // If we have a sameness group, we can return early without further data massage. if labels.SamenessGroup != "" { return discovery.QueryTenancy{ Namespace: labels.Namespace, Partition: labels.Partition, SamenessGroup: labels.SamenessGroup, - Datacenter: reqCtx.DefaultDatacenter, + // Datacenter is not supported }, nil } @@ -234,19 +238,19 @@ func getQueryTenancy(reqCtx Context, queryType discovery.QueryType, querySuffixe Namespace: labels.Namespace, Partition: labels.Partition, Peer: labels.Peer, - Datacenter: getEffectiveDatacenter(labels, reqCtx.DefaultDatacenter), + Datacenter: getEffectiveDatacenter(labels), }, nil } // getEffectiveDatacenter returns the effective datacenter from the parsed labels. -func getEffectiveDatacenter(labels *parsedLabels, defaultDC string) string { +func getEffectiveDatacenter(labels *parsedLabels) string { switch { case labels.Datacenter != "": return labels.Datacenter case labels.PeerOrDatacenter != "" && labels.Peer != labels.PeerOrDatacenter: return labels.PeerOrDatacenter } - return defaultDC + return "" } // getQueryTypePartsAndSuffixesFromDNSMessage returns the query type, the parts, and suffixes of the query name. diff --git a/agent/dns/discovery_results_fetcher_test.go b/agent/dns/discovery_results_fetcher_test.go index 59e299fe66..01792a0646 100644 --- a/agent/dns/discovery_results_fetcher_test.go +++ b/agent/dns/discovery_results_fetcher_test.go @@ -160,8 +160,7 @@ func Test_buildQueryFromDNSMessage(t *testing.T) { }, }, requestContext: &Context{ - DefaultDatacenter: "default-dc", - DefaultPartition: "default-partition", + DefaultPartition: "default-partition", }, expectedQuery: &discovery.Query{ QueryType: discovery.QueryTypeWorkload, @@ -169,10 +168,9 @@ func Test_buildQueryFromDNSMessage(t *testing.T) { Name: "foo", PortName: "api", Tenancy: discovery.QueryTenancy{ - Namespace: "banana", - Partition: "orange", - Peer: "apple", - Datacenter: "default-dc", + Namespace: "banana", + Partition: "orange", + Peer: "apple", }, }, }, @@ -192,8 +190,7 @@ func Test_buildQueryFromDNSMessage(t *testing.T) { }, }, requestContext: &Context{ - DefaultDatacenter: "default-dc", - DefaultPartition: "default-partition", + DefaultPartition: "default-partition", }, expectedQuery: &discovery.Query{ QueryType: discovery.QueryTypeService, @@ -203,7 +200,6 @@ func Test_buildQueryFromDNSMessage(t *testing.T) { Namespace: "banana", Partition: "orange", SamenessGroup: "apple", - Datacenter: "default-dc", }, }, }, diff --git a/agent/dns/router.go b/agent/dns/router.go index d03beffdfb..9c0175a776 100644 --- a/agent/dns/router.go +++ b/agent/dns/router.go @@ -14,9 +14,8 @@ import ( "github.com/armon/go-metrics" "github.com/armon/go-radix" - "github.com/miekg/dns" - "github.com/hashicorp/go-hclog" + "github.com/miekg/dns" "github.com/hashicorp/consul/agent/config" "github.com/hashicorp/consul/agent/discovery" @@ -46,13 +45,6 @@ var ( trailingSpacesRE = regexp.MustCompile(" +$") ) -// Context is used augment a DNS message with Consul-specific metadata. -type Context struct { - Token string - DefaultPartition string - DefaultDatacenter string -} - // RouterDynamicConfig is the dynamic configuration that can be hot-reloaded type RouterDynamicConfig struct { ARecordLimit int @@ -114,13 +106,12 @@ type dnsRecursor interface { // Router replaces miekg/dns.ServeMux with a simpler router that only checks for the 2-3 valid domains // that Consul supports and forwards to a single DiscoveryQueryProcessor handler. If there is no match, it will recurse. type Router struct { - processor DiscoveryQueryProcessor - recursor dnsRecursor - domain string - altDomain string - datacenter string - nodeName string - logger hclog.Logger + processor DiscoveryQueryProcessor + recursor dnsRecursor + domain string + altDomain string + nodeName string + logger hclog.Logger tokenFunc func() string translateAddressFunc func(dc string, addr string, taggedAddresses map[string]string, accept dnsutil.TranslateAddressAccept) string @@ -146,7 +137,6 @@ func NewRouter(cfg Config) (*Router, error) { recursor: newRecursor(logger), domain: domain, altDomain: altDomain, - datacenter: cfg.AgentConfig.Datacenter, logger: logger, nodeName: cfg.AgentConfig.NodeName, tokenFunc: cfg.TokenFunc, @@ -176,6 +166,7 @@ func (r *Router) HandleRequest(req *dns.Msg, reqCtx Context, remoteAddress net.A } r.logger.Trace("received request", "question", req.Question[0].Name, "type", dns.Type(req.Question[0].Qtype).String()) + r.normalizeContext(&reqCtx) defer func(s time.Time, q dns.Question) { metrics.MeasureSinceWithLabels([]string{"dns", "query"}, s, @@ -319,10 +310,9 @@ func (r *Router) trimDomain(questionName string) string { } // ServeDNS implements the miekg/dns.Handler interface. -// This is a standard DNS listener, so we inject a default request context based on the agent's config. +// This is a standard DNS listener. func (r *Router) ServeDNS(w dns.ResponseWriter, req *dns.Msg) { - reqCtx := r.defaultAgentDNSRequestContext() - out := r.HandleRequest(req, reqCtx, w.RemoteAddr()) + out := r.HandleRequest(req, Context{}, w.RemoteAddr()) w.WriteMsg(out) } @@ -420,16 +410,6 @@ func (r *Router) GetConfig() *RouterDynamicConfig { return r.dynamicConfig.Load().(*RouterDynamicConfig) } -// defaultAgentDNSRequestContext returns a default request context based on the agent's config. -func (r *Router) defaultAgentDNSRequestContext() Context { - return Context{ - Token: r.tokenFunc(), - DefaultDatacenter: r.datacenter, - // We don't need to specify the agent's partition here because that will be handled further down the stack - // in the query processor. - } -} - // getErrorFromECSNotGlobalError returns the underlying error from an ECSNotGlobalError, if it exists. func getErrorFromECSNotGlobalError(err error) error { if errors.Is(err, discovery.ErrECSNotGlobal) { @@ -471,6 +451,16 @@ func validateAndNormalizeRequest(req *dns.Msg) error { return nil } +// normalizeContext makes sure context information is populated with agent defaults as needed. +// Right now this is just the ACL token. We do this in the router with the token because DNS doesn't +// allow a token to be passed in the request, and we expect ACL tokens upfront in APIs when they are enabled. +// Tenancy information is left out because it is safe/expected to assume agent defaults in the backend lookup. +func (r *Router) normalizeContext(ctx *Context) { + if ctx.Token == "" { + ctx.Token = r.tokenFunc() + } +} + // stripAnyFailoverSuffix strips off the suffixes that may have been added to the request name. func stripAnyFailoverSuffix(target string) (string, bool) { enableFailover := false diff --git a/agent/dns/router_test.go b/agent/dns/router_test.go index fe1a034bf8..717ee9e16b 100644 --- a/agent/dns/router_test.go +++ b/agent/dns/router_test.go @@ -52,22 +52,6 @@ type HandleTestCase struct { response *dns.Msg } -var testSOA = &dns.SOA{ - Hdr: dns.RR_Header{ - Name: "consul.", - Rrtype: dns.TypeSOA, - Class: dns.ClassINET, - Ttl: 4, - }, - Ns: "ns.consul.", - Mbox: "hostmaster.consul.", - Serial: uint32(time.Now().Unix()), - Refresh: 1, - Retry: 2, - Expire: 3, - Minttl: 4, -} - func Test_HandleRequest_Validation(t *testing.T) { testCases := []HandleTestCase{ { @@ -93,6 +77,157 @@ func Test_HandleRequest_Validation(t *testing.T) { Extra: nil, }, }, + // Context Tests + { + name: "When a request context is provided, use those field in the query", + request: &dns.Msg{ + MsgHdr: dns.MsgHdr{ + Opcode: dns.OpcodeQuery, + }, + Question: []dns.Question{ + { + Name: "foo.service.consul.", + Qtype: dns.TypeA, + Qclass: dns.ClassINET, + }, + }, + }, + requestContext: &Context{ + Token: "test-token", + DefaultNamespace: "test-namespace", + DefaultPartition: "test-partition", + }, + configureDataFetcher: func(fetcher discovery.CatalogDataFetcher) { + result := []*discovery.Result{ + { + Type: discovery.ResultTypeNode, + Node: &discovery.Location{Name: "foo", Address: "1.2.3.4"}, + Tenancy: discovery.ResultTenancy{ + Namespace: "test-namespace", + Partition: "test-partition", + }, + }, + } + + fetcher.(*discovery.MockCatalogDataFetcher). + On("FetchEndpoints", mock.Anything, mock.Anything, mock.Anything). + Return(result, nil). + Run(func(args mock.Arguments) { + ctx := args.Get(0).(discovery.Context) + req := args.Get(1).(*discovery.QueryPayload) + reqType := args.Get(2).(discovery.LookupType) + + require.Equal(t, "test-token", ctx.Token) + + require.Equal(t, "foo", req.Name) + require.Equal(t, "test-namespace", req.Tenancy.Namespace) + require.Equal(t, "test-partition", req.Tenancy.Partition) + + require.Equal(t, discovery.LookupTypeService, reqType) + }) + }, + validateAndNormalizeExpected: true, + response: &dns.Msg{ + MsgHdr: dns.MsgHdr{ + Response: true, + Authoritative: true, + }, + Compress: true, + Question: []dns.Question{ + { + Name: "foo.service.consul.", + Qtype: dns.TypeA, + Qclass: dns.ClassINET, + }, + }, + Answer: []dns.RR{ + &dns.A{ + Hdr: dns.RR_Header{ + Name: "foo.service.consul.", + Rrtype: dns.TypeA, + Class: dns.ClassINET, + Ttl: 123, + }, + A: net.ParseIP("1.2.3.4"), + }, + }, + }, + }, + { + name: "When a request context is provided, values do not override explicit tenancy", + request: &dns.Msg{ + MsgHdr: dns.MsgHdr{ + Opcode: dns.OpcodeQuery, + }, + Question: []dns.Question{ + { + Name: "foo.service.bar.ns.baz.ap.consul.", + Qtype: dns.TypeA, + Qclass: dns.ClassINET, + }, + }, + }, + requestContext: &Context{ + Token: "test-token", + DefaultNamespace: "test-namespace", + DefaultPartition: "test-partition", + }, + configureDataFetcher: func(fetcher discovery.CatalogDataFetcher) { + result := []*discovery.Result{ + { + Type: discovery.ResultTypeNode, + Node: &discovery.Location{Name: "foo", Address: "1.2.3.4"}, + Tenancy: discovery.ResultTenancy{ + Namespace: "bar", + Partition: "baz", + }, + }, + } + + fetcher.(*discovery.MockCatalogDataFetcher). + On("FetchEndpoints", mock.Anything, mock.Anything, mock.Anything). + Return(result, nil). + Run(func(args mock.Arguments) { + ctx := args.Get(0).(discovery.Context) + req := args.Get(1).(*discovery.QueryPayload) + reqType := args.Get(2).(discovery.LookupType) + + require.Equal(t, "test-token", ctx.Token) + + require.Equal(t, "foo", req.Name) + require.Equal(t, "bar", req.Tenancy.Namespace) + require.Equal(t, "baz", req.Tenancy.Partition) + + require.Equal(t, discovery.LookupTypeService, reqType) + }) + }, + validateAndNormalizeExpected: true, + response: &dns.Msg{ + MsgHdr: dns.MsgHdr{ + Response: true, + Authoritative: true, + }, + Compress: true, + Question: []dns.Question{ + { + Name: "foo.service.bar.ns.baz.ap.consul.", + Qtype: dns.TypeA, + Qclass: dns.ClassINET, + }, + }, + Answer: []dns.RR{ + &dns.A{ + Hdr: dns.RR_Header{ + Name: "foo.service.bar.ns.baz.ap.consul.", + Rrtype: dns.TypeA, + Class: dns.ClassINET, + Ttl: 123, + }, + A: net.ParseIP("1.2.3.4"), + }, + }, + }, + }, } for _, tc := range testCases { diff --git a/agent/grpc-external/services/dns/server_v2.go b/agent/grpc-external/services/dns/server_v2.go index 64cf22012f..848361535f 100644 --- a/agent/grpc-external/services/dns/server_v2.go +++ b/agent/grpc-external/services/dns/server_v2.go @@ -72,9 +72,10 @@ func (s *ServerV2) Query(ctx context.Context, req *pbdns.QueryRequest) (*pbdns.Q return nil, status.Error(codes.Internal, fmt.Sprintf("failure decoding dns request: %s", err.Error())) } - // TODO (v2-dns): parse token and other context metadata from the grpc request/metadata (NET-7885) - reqCtx := agentdns.Context{ - Token: s.TokenFunc(), + reqCtx, err := agentdns.NewContextFromGRPCContext(ctx) + if err != nil { + s.Logger.Error("error parsing DNS context from grpc metadata", "err", err) + return nil, status.Error(codes.Internal, fmt.Sprintf("error parsing DNS context from grpc metadata: %s", err.Error())) } resp := s.DNSRouter.HandleRequest(msg, reqCtx, remote) diff --git a/agent/grpc-external/services/dns/server_v2_test.go b/agent/grpc-external/services/dns/server_v2_test.go index 7001029353..06c7d4f96a 100644 --- a/agent/grpc-external/services/dns/server_v2_test.go +++ b/agent/grpc-external/services/dns/server_v2_test.go @@ -10,6 +10,8 @@ import ( "github.com/hashicorp/go-hclog" "github.com/miekg/dns" "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "google.golang.org/grpc/metadata" agentdns "github.com/hashicorp/consul/agent/dns" "github.com/hashicorp/consul/proto-public/pbdns" @@ -50,6 +52,7 @@ func (s *DNSTestSuite) TestProxy_V2Success() { question string configureRouter func(router *agentdns.MockDNSRouter) clientQuery func(qR *pbdns.QueryRequest) + metadata map[string]string expectedErr error }{ @@ -73,6 +76,28 @@ func (s *DNSTestSuite) TestProxy_V2Success() { qR.Protocol = pbdns.Protocol_PROTOCOL_TCP }, }, + "happy path with context variables set": { + question: "abc.com.", + configureRouter: func(router *agentdns.MockDNSRouter) { + router.On("HandleRequest", mock.Anything, mock.Anything, mock.Anything). + Run(func(args mock.Arguments) { + ctx, ok := args.Get(1).(agentdns.Context) + require.True(s.T(), ok, "error casting to agentdns.Context") + require.Equal(s.T(), "test-token", ctx.Token, "token not set in context") + require.Equal(s.T(), "test-namespace", ctx.DefaultNamespace, "namespace not set in context") + require.Equal(s.T(), "test-partition", ctx.DefaultPartition, "partition not set in context") + }). + Return(basicResponse(), nil) + }, + clientQuery: func(qR *pbdns.QueryRequest) { + qR.Protocol = pbdns.Protocol_PROTOCOL_UDP + }, + metadata: map[string]string{ + "x-consul-token": "test-token", + "x-consul-namespace": "test-namespace", + "x-consul-partition": "test-partition", + }, + }, "No protocol set": { question: "abc.com.", clientQuery: func(qR *pbdns.QueryRequest) {}, @@ -108,9 +133,18 @@ func (s *DNSTestSuite) TestProxy_V2Success() { bytes, _ := req.Pack() + ctx := context.Background() + if len(tc.metadata) > 0 { + md := metadata.MD{} + for k, v := range tc.metadata { + md.Set(k, v) + } + ctx = metadata.NewOutgoingContext(ctx, md) + } + clientReq := &pbdns.QueryRequest{Msg: bytes} tc.clientQuery(clientReq) - clientResp, err := client.Query(context.Background(), clientReq) + clientResp, err := client.Query(ctx, clientReq) if tc.expectedErr != nil { s.Require().Error(err, "no errror calling gRPC endpoint") s.Require().ErrorContains(err, tc.expectedErr.Error())