diff --git a/agent/agent.go b/agent/agent.go index 9dd1b1cade..ef73242cd9 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -1111,7 +1111,7 @@ func (a *Agent) listenAndServeV2DNS() error { // Check the catalog version and decide which implementation of the data fetcher to implement if a.baseDeps.UseV2Resources() { - a.catalogDataFetcher = discovery.NewV2DataFetcher(a.config) + a.catalogDataFetcher = discovery.NewV2DataFetcher(a.config, a.delegate.ResourceServiceClient(), a.logger.Named("catalog-data-fetcher")) } else { a.catalogDataFetcher = discovery.NewV1DataFetcher(a.config, a.AgentEnterpriseMeta(), diff --git a/agent/discovery/discovery.go b/agent/discovery/discovery.go index 92e6644d2a..0439ca20bc 100644 --- a/agent/discovery/discovery.go +++ b/agent/discovery/discovery.go @@ -7,14 +7,13 @@ import ( "fmt" "net" - "github.com/hashicorp/consul/acl" "github.com/hashicorp/consul/agent/config" - "github.com/hashicorp/consul/agent/structs" ) var ( - ErrNoData = fmt.Errorf("no data") ErrECSNotGlobal = fmt.Errorf("ECS response is not global") + ErrNoData = fmt.Errorf("no data") + ErrNotFound = fmt.Errorf("not found") ErrNotSupported = fmt.Errorf("not supported") ) @@ -65,18 +64,16 @@ const ( // Context is used to pass information about the request. type Context struct { - Token string - DefaultPartition string - DefaultNamespace string - DefaultLocality *structs.Locality + Token string } // QueryTenancy is used to filter catalog data based on tenancy. type QueryTenancy struct { - EnterpriseMeta acl.EnterpriseMeta - SamenessGroup string - Peer string - Datacenter string + Namespace string + Partition string + SamenessGroup string + Peer string + Datacenter string } // QueryPayload represents all information needed by the data backend @@ -89,7 +86,7 @@ type QueryPayload struct { Tenancy QueryTenancy // tenancy includes any additional labels specified before the domain // v2 fields only - DisableFailover bool + EnableFailover bool } // ResultType indicates the Consul resource that a discovery record represents. @@ -107,11 +104,12 @@ const ( // It is the responsibility of the DNS encoder to know what to do with // each Result, based on the query type. type Result struct { - Address string // A/AAAA/CNAME records - could be used in the Extra section. CNAME is required to handle hostname addresses in workloads & nodes. - Weight uint32 // SRV queries - Port uint32 // SRV queries - Metadata map[string]string // Used to collect metadata into TXT Records - Type ResultType // Used to reconstruct the fqdn name of the resource + Address string // A/AAAA/CNAME records - could be used in the Extra section. CNAME is required to handle hostname addresses in workloads & nodes. + Weight uint32 // SRV queries + PortName string // Used to generate a fgdn when a specifc port was queried + PortNumber uint32 // SRV queries + Metadata map[string]string // Used to collect metadata into TXT Records + Type ResultType // Used to reconstruct the fqdn name of the resource // Used in SRV & PTR queries to point at an A/AAAA Record. Target string @@ -121,9 +119,10 @@ type Result struct { // ResultTenancy is used to reconstruct the fqdn name of the resource. type ResultTenancy struct { - PeerName string - Datacenter string - EnterpriseMeta acl.EnterpriseMeta // TODO (v2-dns): need something that is compatible with the V2 catalog + Namespace string + Partition string + PeerName string + Datacenter string } // LookupType is used by the CatalogDataFetcher to properly filter endpoints. @@ -138,6 +137,8 @@ const ( // CatalogDataFetcher is an interface that abstracts data collection // for Discovery queries. It is assumed that the instantiation also // includes any agent configuration that influences catalog queries. +// +//go:generate mockery --name CatalogDataFetcher --inpackage type CatalogDataFetcher interface { // LoadConfig is used to hot-reload the data fetcher with new agent config. LoadConfig(config *config.RuntimeConfig) @@ -162,6 +163,13 @@ type CatalogDataFetcher interface { // FetchPreparedQuery evaluates the results of a prepared query. // deprecated in V2 FetchPreparedQuery(ctx Context, req *QueryPayload) ([]*Result, error) + + // NormalizeRequest mutates the original request based on data fetcher configuration, like + // defaulting tenancy to the agent's partition. + NormalizeRequest(req *QueryPayload) + + // ValidateRequest throws an error is any of the input fields are invalid for this data fetcher. + ValidateRequest(ctx Context, req *QueryPayload) error } // QueryProcessor is used to process a Discovery Query and return the results. @@ -178,6 +186,12 @@ func NewQueryProcessor(dataFetcher CatalogDataFetcher) *QueryProcessor { // QueryByName is used to look up a service, node, workload, or prepared query. func (p *QueryProcessor) QueryByName(query *Query, ctx Context) ([]*Result, error) { + if err := p.dataFetcher.ValidateRequest(ctx, &query.QueryPayload); err != nil { + return nil, err + } + + p.dataFetcher.NormalizeRequest(&query.QueryPayload) + switch query.QueryType { case QueryTypeNode: return p.dataFetcher.FetchNodes(ctx, &query.QueryPayload) diff --git a/agent/discovery/discovery_test.go b/agent/discovery/discovery_test.go new file mode 100644 index 0000000000..af7fd148b3 --- /dev/null +++ b/agent/discovery/discovery_test.go @@ -0,0 +1,221 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package discovery + +import ( + "errors" + "net" + "testing" + + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +var ( + testContext = Context{ + Token: "bar", + } + + testErr = errors.New("test error") + + testIP = net.ParseIP("1.2.3.4") + + testPayload = QueryPayload{ + Name: "foo", + } + + testResult = &Result{ + Address: "1.2.3.4", + Type: ResultTypeNode, // This isn't correct for some test cases, but we are only asserting the right data fetcher functions are called + Target: "foo", + } +) + +func TestQueryByName(t *testing.T) { + + type testCase struct { + name string + reqType QueryType + configureDataFetcher func(*testing.T, *MockCatalogDataFetcher) + expectedResults []*Result + expectedError error + } + + run := func(t *testing.T, tc testCase) { + + fetcher := NewMockCatalogDataFetcher(t) + tc.configureDataFetcher(t, fetcher) + + qp := NewQueryProcessor(fetcher) + + q := Query{ + QueryType: tc.reqType, + QueryPayload: testPayload, + } + + results, err := qp.QueryByName(&q, testContext) + if tc.expectedError != nil { + require.Error(t, err) + require.True(t, errors.Is(err, tc.expectedError)) + require.Nil(t, results) + return + } + require.NoError(t, err) + require.Equal(t, tc.expectedResults, results) + } + + testCases := []testCase{ + { + name: "query node", + reqType: QueryTypeNode, + configureDataFetcher: func(t *testing.T, fetcher *MockCatalogDataFetcher) { + + fetcher.On("ValidateRequest", mock.Anything, mock.Anything).Return(nil) + fetcher.On("NormalizeRequest", mock.Anything) + fetcher.On("FetchNodes", mock.Anything, mock.Anything).Return([]*Result{testResult}, nil) + }, + expectedResults: []*Result{testResult}, + }, + { + name: "query service", + reqType: QueryTypeService, + configureDataFetcher: func(t *testing.T, fetcher *MockCatalogDataFetcher) { + + fetcher.On("ValidateRequest", mock.Anything, mock.Anything).Return(nil) + fetcher.On("NormalizeRequest", mock.Anything) + fetcher.On("FetchEndpoints", mock.Anything, mock.Anything, mock.Anything).Return([]*Result{testResult}, nil) + }, + expectedResults: []*Result{testResult}, + }, + { + name: "query connect", + reqType: QueryTypeConnect, + configureDataFetcher: func(t *testing.T, fetcher *MockCatalogDataFetcher) { + + fetcher.On("ValidateRequest", mock.Anything, mock.Anything).Return(nil) + fetcher.On("NormalizeRequest", mock.Anything) + fetcher.On("FetchEndpoints", mock.Anything, mock.Anything, mock.Anything).Return([]*Result{testResult}, nil) + }, + expectedResults: []*Result{testResult}, + }, + { + name: "query ingress", + reqType: QueryTypeIngress, + configureDataFetcher: func(t *testing.T, fetcher *MockCatalogDataFetcher) { + + fetcher.On("ValidateRequest", mock.Anything, mock.Anything).Return(nil) + fetcher.On("NormalizeRequest", mock.Anything) + fetcher.On("FetchEndpoints", mock.Anything, mock.Anything, mock.Anything).Return([]*Result{testResult}, nil) + }, + expectedResults: []*Result{testResult}, + }, + { + name: "query virtual ip", + reqType: QueryTypeVirtual, + configureDataFetcher: func(t *testing.T, fetcher *MockCatalogDataFetcher) { + + fetcher.On("ValidateRequest", mock.Anything, mock.Anything).Return(nil) + fetcher.On("NormalizeRequest", mock.Anything) + fetcher.On("FetchVirtualIP", mock.Anything, mock.Anything).Return(testResult, nil) + }, + expectedResults: []*Result{testResult}, + }, + { + name: "query workload", + reqType: QueryTypeWorkload, + configureDataFetcher: func(t *testing.T, fetcher *MockCatalogDataFetcher) { + + fetcher.On("ValidateRequest", mock.Anything, mock.Anything).Return(nil) + fetcher.On("NormalizeRequest", mock.Anything) + fetcher.On("FetchWorkload", mock.Anything, mock.Anything).Return(testResult, nil) + }, + expectedResults: []*Result{testResult}, + }, + { + name: "query prepared query", + reqType: QueryTypePreparedQuery, + configureDataFetcher: func(t *testing.T, fetcher *MockCatalogDataFetcher) { + + fetcher.On("ValidateRequest", mock.Anything, mock.Anything).Return(nil) + fetcher.On("NormalizeRequest", mock.Anything) + fetcher.On("FetchPreparedQuery", mock.Anything, mock.Anything).Return([]*Result{testResult}, nil) + }, + expectedResults: []*Result{testResult}, + }, + { + name: "returns error from validation", + reqType: QueryTypePreparedQuery, + configureDataFetcher: func(t *testing.T, fetcher *MockCatalogDataFetcher) { + fetcher.On("ValidateRequest", mock.Anything, mock.Anything).Return(testErr) + }, + expectedError: testErr, + }, + { + name: "returns error from fetcher", + reqType: QueryTypePreparedQuery, + configureDataFetcher: func(t *testing.T, fetcher *MockCatalogDataFetcher) { + fetcher.On("ValidateRequest", mock.Anything, mock.Anything).Return(nil) + fetcher.On("NormalizeRequest", mock.Anything) + fetcher.On("FetchPreparedQuery", mock.Anything, mock.Anything).Return(nil, testErr) + }, + expectedError: testErr, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + run(t, tc) + }) + } +} + +func TestQueryByIP(t *testing.T) { + type testCase struct { + name string + configureDataFetcher func(*testing.T, *MockCatalogDataFetcher) + expectedResults []*Result + expectedError error + } + + run := func(t *testing.T, tc testCase) { + + fetcher := NewMockCatalogDataFetcher(t) + tc.configureDataFetcher(t, fetcher) + + qp := NewQueryProcessor(fetcher) + + results, err := qp.QueryByIP(testIP, testContext) + if tc.expectedError != nil { + require.Error(t, err) + require.True(t, errors.Is(err, tc.expectedError)) + require.Nil(t, results) + return + } + require.NoError(t, err) + require.Equal(t, tc.expectedResults, results) + } + + testCases := []testCase{ + { + name: "query by IP", + configureDataFetcher: func(t *testing.T, fetcher *MockCatalogDataFetcher) { + fetcher.On("FetchRecordsByIp", mock.Anything, mock.Anything).Return([]*Result{testResult}, nil) + }, + expectedResults: []*Result{testResult}, + }, + { + name: "returns error from fetcher", + configureDataFetcher: func(t *testing.T, fetcher *MockCatalogDataFetcher) { + fetcher.On("FetchRecordsByIp", mock.Anything, mock.Anything).Return(nil, testErr) + }, + expectedError: testErr, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + run(t, tc) + }) + } +} diff --git a/agent/discovery/mock_CatalogDataFetcher.go b/agent/discovery/mock_CatalogDataFetcher.go index 5a035ecfe3..f80a6010d2 100644 --- a/agent/discovery/mock_CatalogDataFetcher.go +++ b/agent/discovery/mock_CatalogDataFetcher.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.32.4. DO NOT EDIT. +// Code generated by mockery v2.37.1. DO NOT EDIT. package discovery @@ -175,6 +175,25 @@ func (_m *MockCatalogDataFetcher) LoadConfig(_a0 *config.RuntimeConfig) { _m.Called(_a0) } +// NormalizeRequest provides a mock function with given fields: req +func (_m *MockCatalogDataFetcher) NormalizeRequest(req *QueryPayload) { + _m.Called(req) +} + +// ValidateRequest provides a mock function with given fields: ctx, req +func (_m *MockCatalogDataFetcher) ValidateRequest(ctx Context, req *QueryPayload) error { + ret := _m.Called(ctx, req) + + var r0 error + if rf, ok := ret.Get(0).(func(Context, *QueryPayload) error); ok { + r0 = rf(ctx, req) + } else { + r0 = ret.Error(0) + } + + return r0 +} + // NewMockCatalogDataFetcher creates a new instance of MockCatalogDataFetcher. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. // The first argument is typically a *testing.T value. func NewMockCatalogDataFetcher(t interface { diff --git a/agent/discovery/query_fetcher_v1.go b/agent/discovery/query_fetcher_v1.go index 0b3318a372..81c73dca66 100644 --- a/agent/discovery/query_fetcher_v1.go +++ b/agent/discovery/query_fetcher_v1.go @@ -12,15 +12,15 @@ import ( "time" "github.com/armon/go-metrics" - cachetype "github.com/hashicorp/consul/agent/cache-types" - "github.com/hashicorp/consul/api" "github.com/hashicorp/go-hclog" "github.com/hashicorp/consul/acl" "github.com/hashicorp/consul/agent/cache" + cachetype "github.com/hashicorp/consul/agent/cache-types" "github.com/hashicorp/consul/agent/config" "github.com/hashicorp/consul/agent/structs" + "github.com/hashicorp/consul/api" ) const ( @@ -31,16 +31,14 @@ const ( // v1DataFetcherDynamicConfig is used to store the dynamic configuration of the V1 data fetcher. type v1DataFetcherDynamicConfig struct { // Default request tenancy - defaultEntMeta acl.EnterpriseMeta - datacenter string + datacenter string // Catalog configuration - allowStale bool - maxStale time.Duration - useCache bool - cacheMaxAge time.Duration - onlyPassing bool - enterpriseDNSConfig EnterpriseDNSConfig + allowStale bool + maxStale time.Duration + useCache bool + cacheMaxAge time.Duration + onlyPassing bool } // V1DataFetcher is used to fetch data from the V1 catalog. @@ -82,15 +80,12 @@ func NewV1DataFetcher(config *config.RuntimeConfig, // LoadConfig loads the configuration for the V1 data fetcher. func (f *V1DataFetcher) LoadConfig(config *config.RuntimeConfig) { dynamicConfig := &v1DataFetcherDynamicConfig{ - allowStale: config.DNSAllowStale, - maxStale: config.DNSMaxStale, - useCache: config.DNSUseCache, - cacheMaxAge: config.DNSCacheMaxAge, - onlyPassing: config.DNSOnlyPassing, - enterpriseDNSConfig: GetEnterpriseDNSConfig(config), - datacenter: config.Datacenter, - // TODO (v2-dns): make this work - //defaultEntMeta: config.EnterpriseRuntimeConfig.DefaultEntMeta, + allowStale: config.DNSAllowStale, + maxStale: config.DNSMaxStale, + useCache: config.DNSUseCache, + cacheMaxAge: config.DNSCacheMaxAge, + onlyPassing: config.DNSOnlyPassing, + datacenter: config.Datacenter, } f.dynamicConfig.Store(dynamicConfig) } @@ -107,7 +102,7 @@ func (f *V1DataFetcher) FetchNodes(ctx Context, req *QueryPayload) ([]*Result, e Token: ctx.Token, AllowStale: cfg.allowStale, }, - EnterpriseMeta: req.Tenancy.EnterpriseMeta, + EnterpriseMeta: queryTenancyToEntMeta(req.Tenancy), } out, err := f.fetchNode(cfg, args) if err != nil { @@ -128,8 +123,9 @@ func (f *V1DataFetcher) FetchNodes(ctx Context, req *QueryPayload) ([]*Result, e Metadata: node.Meta, Target: node.Node, Tenancy: ResultTenancy{ - EnterpriseMeta: cfg.defaultEntMeta, - Datacenter: cfg.datacenter, + // Namespace is not required because nodes are not namespaced + Partition: node.GetEnterpriseMeta().PartitionOrDefault(), + Datacenter: node.Datacenter, }, }) @@ -155,7 +151,7 @@ func (f *V1DataFetcher) FetchVirtualIP(ctx Context, req *QueryPayload) (*Result, // within a DC, therefore their uniqueness is not guaranteed globally. PeerName: req.Tenancy.Peer, ServiceName: req.Name, - EnterpriseMeta: req.Tenancy.EnterpriseMeta, + EnterpriseMeta: queryTenancyToEntMeta(req.Tenancy), QueryOptions: structs.QueryOptions{ Token: ctx.Token, }, @@ -176,6 +172,10 @@ func (f *V1DataFetcher) FetchVirtualIP(ctx Context, req *QueryPayload) (*Result, // FetchRecordsByIp is used for PTR requests to look up a service/node from an IP. // The search is performed in the agent's partition and over all namespaces (or those allowed by the ACL token). func (f *V1DataFetcher) FetchRecordsByIp(reqCtx Context, ip net.IP) ([]*Result, error) { + if ip == nil { + return nil, ErrNotSupported + } + configCtx := f.dynamicConfig.Load().(*v1DataFetcherDynamicConfig) targetIP := ip.String() @@ -200,8 +200,9 @@ func (f *V1DataFetcher) FetchRecordsByIp(reqCtx Context, ip net.IP) ([]*Result, Type: ResultTypeNode, Target: n.Node, Tenancy: ResultTenancy{ - EnterpriseMeta: f.defaultEnterpriseMeta, - Datacenter: configCtx.datacenter, + Namespace: f.defaultEnterpriseMeta.NamespaceOrDefault(), + Partition: f.defaultEnterpriseMeta.PartitionOrDefault(), + Datacenter: configCtx.datacenter, }, }) return results, nil @@ -229,8 +230,9 @@ func (f *V1DataFetcher) FetchRecordsByIp(reqCtx Context, ip net.IP) ([]*Result, Type: ResultTypeService, Target: n.ServiceName, Tenancy: ResultTenancy{ - EnterpriseMeta: f.defaultEnterpriseMeta, - Datacenter: configCtx.datacenter, + Namespace: f.defaultEnterpriseMeta.NamespaceOrDefault(), + Partition: f.defaultEnterpriseMeta.PartitionOrDefault(), + Datacenter: configCtx.datacenter, }, }) return results, nil @@ -257,6 +259,16 @@ func (f *V1DataFetcher) FetchPreparedQuery(ctx Context, req *QueryPayload) ([]*R return nil, nil } +func (f *V1DataFetcher) ValidateRequest(_ Context, req *QueryPayload) error { + if req.EnableFailover { + return ErrNotSupported + } + if req.PortName != "" { + return ErrNotSupported + } + return validateEnterpriseTenancy(req.Tenancy) +} + // fetchNode is used to look up a node in the Consul catalog within NodeServices. // If the config is set to UseCache, it will get the record from the agent cache. func (f *V1DataFetcher) fetchNode(cfg *v1DataFetcherDynamicConfig, args *structs.NodeSpecificRequest) (*structs.IndexedNodeServices, error) { @@ -336,7 +348,7 @@ func (f *V1DataFetcher) fetchServiceBasedOnTenancy(ctx Context, req *QueryPayloa UseCache: cfg.useCache, MaxStaleDuration: cfg.maxStale, }, - EnterpriseMeta: req.Tenancy.EnterpriseMeta, + EnterpriseMeta: queryTenancyToEntMeta(req.Tenancy), } out, _, err := f.rpcFuncForServiceNodes(context.TODO(), args) @@ -365,15 +377,16 @@ func (f *V1DataFetcher) fetchServiceBasedOnTenancy(ctx Context, req *QueryPayloa address, target, resultType := getAddressTargetAndResultType(node) results = append(results, &Result{ - Address: address, - Type: resultType, - Target: target, - Weight: uint32(findWeight(node)), - Port: uint32(f.translateServicePortFunc(node.Node.Datacenter, node.Service.Port, node.Service.TaggedAddresses)), - Metadata: node.Node.Meta, + Address: address, + Type: resultType, + Target: target, + Weight: uint32(findWeight(node)), + PortNumber: uint32(f.translateServicePortFunc(node.Node.Datacenter, node.Service.Port, node.Service.TaggedAddresses)), + Metadata: node.Node.Meta, Tenancy: ResultTenancy{ - EnterpriseMeta: cfg.defaultEntMeta, - Datacenter: cfg.datacenter, + Namespace: node.Service.NamespaceOrEmpty(), + Partition: node.Service.PartitionOrEmpty(), + Datacenter: node.Node.Datacenter, }, }) } diff --git a/agent/discovery/query_fetcher_v1_ce.go b/agent/discovery/query_fetcher_v1_ce.go index 6540dea7fe..2bb2a774dd 100644 --- a/agent/discovery/query_fetcher_v1_ce.go +++ b/agent/discovery/query_fetcher_v1_ce.go @@ -8,8 +8,26 @@ package discovery import ( "errors" "fmt" + + "github.com/hashicorp/consul/acl" ) +func (f *V1DataFetcher) NormalizeRequest(req *QueryPayload) { + // Nothing to do for CE + return +} + +func validateEnterpriseTenancy(req QueryTenancy) error { + if req.Namespace != "" || req.Partition != "" { + return ErrNotSupported + } + return nil +} + +func queryTenancyToEntMeta(_ QueryTenancy) acl.EnterpriseMeta { + return acl.EnterpriseMeta{} +} + // fetchServiceFromSamenessGroup fetches a service from a sameness group. func (f *V1DataFetcher) fetchServiceFromSamenessGroup(ctx Context, req *QueryPayload, cfg *v1DataFetcherDynamicConfig) ([]*Result, error) { f.logger.Debug(fmt.Sprintf("fetchServiceFromSamenessGroup - req: %+v", req)) diff --git a/agent/discovery/query_fetcher_v1_ce_test.go b/agent/discovery/query_fetcher_v1_ce_test.go index 7376e50560..717475c9dc 100644 --- a/agent/discovery/query_fetcher_v1_ce_test.go +++ b/agent/discovery/query_fetcher_v1_ce_test.go @@ -5,7 +5,7 @@ package discovery -import "github.com/hashicorp/consul/acl" - -// defaultEntMeta is the default enterprise meta used for testing. -var defaultEntMeta = acl.EnterpriseMeta{} +const ( + defaultTestNamespace = "" + defaultTestPartition = "" +) diff --git a/agent/discovery/query_fetcher_v1_test.go b/agent/discovery/query_fetcher_v1_test.go index 9abcccf160..703548f3e5 100644 --- a/agent/discovery/query_fetcher_v1_test.go +++ b/agent/discovery/query_fetcher_v1_test.go @@ -9,12 +9,11 @@ import ( "testing" "time" - "github.com/hashicorp/consul/agent/cache" - "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "github.com/hashicorp/consul/acl" + "github.com/hashicorp/consul/agent/cache" cachetype "github.com/hashicorp/consul/agent/cache-types" "github.com/hashicorp/consul/agent/config" "github.com/hashicorp/consul/agent/structs" @@ -43,8 +42,9 @@ func Test_FetchVirtualIP(t *testing.T) { queryPayload: &QueryPayload{ Name: "db", Tenancy: QueryTenancy{ - Peer: "test-peer", - EnterpriseMeta: defaultEntMeta, + Peer: "test-peer", + Namespace: defaultTestNamespace, + Partition: defaultTestPartition, }, }, context: Context{ @@ -61,9 +61,9 @@ func Test_FetchVirtualIP(t *testing.T) { queryPayload: &QueryPayload{ Name: "db", Tenancy: QueryTenancy{ - Peer: "test-peer", - EnterpriseMeta: defaultEntMeta, - }, + Peer: "test-peer", + Namespace: defaultTestNamespace, + Partition: defaultTestPartition}, }, context: Context{ Token: "test-token", @@ -90,7 +90,8 @@ func Test_FetchVirtualIP(t *testing.T) { // validate RPC options are set correctly from the queryPayload and context require.Equal(t, tc.queryPayload.Tenancy.Peer, req.PeerName) - require.Equal(t, tc.queryPayload.Tenancy.EnterpriseMeta, req.EnterpriseMeta) + require.Equal(t, tc.queryPayload.Tenancy.Namespace, req.EnterpriseMeta.NamespaceOrEmpty()) + require.Equal(t, tc.queryPayload.Tenancy.Partition, req.EnterpriseMeta.PartitionOrEmpty()) require.Equal(t, tc.context.Token, req.QueryOptions.Token) if tc.expectedErr == nil { @@ -143,7 +144,8 @@ func Test_FetchEndpoints(t *testing.T) { queryPayload: &QueryPayload{ Name: "service-name", Tenancy: QueryTenancy{ - EnterpriseMeta: defaultEntMeta, + Namespace: defaultTestNamespace, + Partition: defaultTestPartition, }, }, rpcFuncForServiceNodes: func(ctx context.Context, req structs.ServiceSpecificRequest) (structs.IndexedCheckServiceNodes, cache.ResultMeta, error) { @@ -151,12 +153,14 @@ func Test_FetchEndpoints(t *testing.T) { Nodes: []structs.CheckServiceNode{ { Node: &structs.Node{ - Address: "node-address", - Node: "node-name", + Address: "node-address", + Node: "node-name", + Partition: defaultTestPartition, }, Service: &structs.NodeService{ - Address: "127.0.0.1", - Service: "service-name", + Address: "127.0.0.1", + Service: "service-name", + EnterpriseMeta: acl.NewEnterpriseMetaWithPartition(defaultTestPartition, defaultTestNamespace), }, }, }, @@ -171,6 +175,10 @@ func Test_FetchEndpoints(t *testing.T) { Target: "service-name", Type: ResultTypeService, Weight: 1, + Tenancy: ResultTenancy{ + Namespace: defaultTestNamespace, + Partition: defaultTestPartition, + }, }, }, expectedErr: nil, @@ -180,7 +188,8 @@ func Test_FetchEndpoints(t *testing.T) { queryPayload: &QueryPayload{ Name: "service-name", Tenancy: QueryTenancy{ - EnterpriseMeta: defaultEntMeta, + Namespace: defaultTestNamespace, + Partition: defaultTestPartition, }, }, rpcFuncForServiceNodes: func(ctx context.Context, req structs.ServiceSpecificRequest) (structs.IndexedCheckServiceNodes, cache.ResultMeta, error) { @@ -188,12 +197,14 @@ func Test_FetchEndpoints(t *testing.T) { Nodes: []structs.CheckServiceNode{ { Node: &structs.Node{ - Address: "node-address", - Node: "node-name", + Address: "node-address", + Node: "node-name", + Partition: defaultTestPartition, }, Service: &structs.NodeService{ - Address: "2001:db8:1:2:cafe::1337", - Service: "service-name", + Address: "2001:db8:1:2:cafe::1337", + Service: "service-name", + EnterpriseMeta: acl.NewEnterpriseMetaWithPartition(defaultTestPartition, defaultTestNamespace), }, }, }, @@ -208,6 +219,10 @@ func Test_FetchEndpoints(t *testing.T) { Target: "service-name", Type: ResultTypeService, Weight: 1, + Tenancy: ResultTenancy{ + Namespace: defaultTestNamespace, + Partition: defaultTestPartition, + }, }, }, expectedErr: nil, @@ -217,7 +232,8 @@ func Test_FetchEndpoints(t *testing.T) { queryPayload: &QueryPayload{ Name: "service-name", Tenancy: QueryTenancy{ - EnterpriseMeta: defaultEntMeta, + Namespace: defaultTestNamespace, + Partition: defaultTestPartition, }, }, rpcFuncForServiceNodes: func(ctx context.Context, req structs.ServiceSpecificRequest) (structs.IndexedCheckServiceNodes, cache.ResultMeta, error) { @@ -225,12 +241,14 @@ func Test_FetchEndpoints(t *testing.T) { Nodes: []structs.CheckServiceNode{ { Node: &structs.Node{ - Address: "node-address", - Node: "node-name", + Address: "node-address", + Node: "node-name", + Partition: defaultTestPartition, }, Service: &structs.NodeService{ - Address: "foo", - Service: "service-name", + Address: "foo", + Service: "service-name", + EnterpriseMeta: acl.NewEnterpriseMetaWithPartition(defaultTestPartition, defaultTestNamespace), }, }, }, @@ -245,6 +263,10 @@ func Test_FetchEndpoints(t *testing.T) { Target: "foo", Type: ResultTypeNode, Weight: 1, + Tenancy: ResultTenancy{ + Namespace: defaultTestNamespace, + Partition: defaultTestPartition, + }, }, }, expectedErr: nil, @@ -254,7 +276,8 @@ func Test_FetchEndpoints(t *testing.T) { queryPayload: &QueryPayload{ Name: "service-name", Tenancy: QueryTenancy{ - EnterpriseMeta: defaultEntMeta, + Namespace: defaultTestNamespace, + Partition: defaultTestPartition, }, }, rpcFuncForServiceNodes: func(ctx context.Context, req structs.ServiceSpecificRequest) (structs.IndexedCheckServiceNodes, cache.ResultMeta, error) { @@ -262,12 +285,14 @@ func Test_FetchEndpoints(t *testing.T) { Nodes: []structs.CheckServiceNode{ { Node: &structs.Node{ - Address: "node-address", - Node: "node-name", + Address: "node-address", + Node: "node-name", + Partition: defaultTestPartition, }, Service: &structs.NodeService{ - Address: "", - Service: "service-name", + Address: "", + Service: "service-name", + EnterpriseMeta: acl.NewEnterpriseMetaWithPartition(defaultTestPartition, defaultTestNamespace), }, }, }, @@ -282,6 +307,10 @@ func Test_FetchEndpoints(t *testing.T) { Target: "node-name", Type: ResultTypeNode, Weight: 1, + Tenancy: ResultTenancy{ + Namespace: defaultTestNamespace, + Partition: defaultTestPartition, + }, }, }, expectedErr: nil, diff --git a/agent/discovery/query_fetcher_v2.go b/agent/discovery/query_fetcher_v2.go index 0bcd69d62f..5371b6f4b0 100644 --- a/agent/discovery/query_fetcher_v2.go +++ b/agent/discovery/query_fetcher_v2.go @@ -4,10 +4,22 @@ package discovery import ( + "context" + "fmt" "net" + "strings" "sync/atomic" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" + + "github.com/hashicorp/go-hclog" + "github.com/hashicorp/consul/agent/config" + "github.com/hashicorp/consul/internal/resource" + pbcatalog "github.com/hashicorp/consul/proto-public/pbcatalog/v2beta1" + "github.com/hashicorp/consul/proto-public/pbresource" ) // v2DataFetcherDynamicConfig is used to store the dynamic configuration of the V2 data fetcher. @@ -17,12 +29,22 @@ type v2DataFetcherDynamicConfig struct { // V2DataFetcher is used to fetch data from the V2 catalog. type V2DataFetcher struct { + client pbresource.ResourceServiceClient + logger hclog.Logger + + // Requests inherit the partition of the agent unless otherwise specified. + defaultPartition string + dynamicConfig atomic.Value } // NewV2DataFetcher creates a new V2 data fetcher. -func NewV2DataFetcher(config *config.RuntimeConfig) *V2DataFetcher { - f := &V2DataFetcher{} +func NewV2DataFetcher(config *config.RuntimeConfig, client pbresource.ResourceServiceClient, logger hclog.Logger) *V2DataFetcher { + f := &V2DataFetcher{ + client: client, + logger: logger, + defaultPartition: config.PartitionOrDefault(), + } f.LoadConfig(config) return f } @@ -35,14 +57,13 @@ func (f *V2DataFetcher) LoadConfig(config *config.RuntimeConfig) { f.dynamicConfig.Store(dynamicConfig) } -// TODO (v2-dns): Implementation of the V2 data fetcher - // FetchNodes fetches A/AAAA/CNAME func (f *V2DataFetcher) FetchNodes(ctx Context, req *QueryPayload) ([]*Result, error) { return nil, nil } // FetchEndpoints fetches records for A/AAAA/CNAME or SRV requests for services +// TODO (v2-dns): Validate lookupType func (f *V2DataFetcher) FetchEndpoints(ctx Context, req *QueryPayload, lookupType LookupType) ([]*Result, error) { return nil, nil } @@ -53,14 +74,81 @@ func (f *V2DataFetcher) FetchVirtualIP(ctx Context, req *QueryPayload) (*Result, } // FetchRecordsByIp is used for PTR requests to look up a service/node from an IP. +// TODO (v2-dns): Validate non-nil IP func (f *V2DataFetcher) FetchRecordsByIp(ctx Context, ip net.IP) ([]*Result, error) { return nil, nil } // FetchWorkload is used to fetch a single workload from the V2 catalog. // V2-only. -func (f *V2DataFetcher) FetchWorkload(ctx Context, req *QueryPayload) (*Result, error) { - return nil, nil +func (f *V2DataFetcher) FetchWorkload(reqContext Context, req *QueryPayload) (*Result, error) { + // Query the resource service for the workload by name and tenancy + resourceReq := pbresource.ReadRequest{ + Id: &pbresource.ID{ + Name: req.Name, + Type: pbcatalog.WorkloadType, + Tenancy: queryTenancyToResourceTenancy(req.Tenancy), + }, + } + + f.logger.Debug("fetching workload", "name", req.Name) + resourceCtx := metadata.AppendToOutgoingContext(context.Background(), "x-consul-token", reqContext.Token) + + // If the workload is not found, return nil and an error equivalent to NXDOMAIN + response, err := f.client.Read(resourceCtx, &resourceReq) + switch { + case grpcNotFoundErr(err): + f.logger.Debug("workload not found", "name", req.Name) + return nil, ErrNotFound + case err != nil: + f.logger.Error("error fetching workload", "name", req.Name) + return nil, fmt.Errorf("error fetching workload: %w", err) + // default: fallthrough + } + + workload := &pbcatalog.Workload{} + data := response.GetResource().GetData() + if err := data.UnmarshalTo(workload); err != nil { + f.logger.Error("error unmarshalling workload", "name", req.Name) + return nil, fmt.Errorf("error unmarshalling workload: %w", err) + } + + // TODO: (v2-dns): we will need to intelligently return the right workload address based on either the translate + // address setting or the locality of the requester. Workloads must have at least one. + // We also need to make sure that we filter out unix sockets here. + address := workload.Addresses[0].GetHost() + if strings.HasPrefix(address, "unix://") { + f.logger.Error("unix sockets are currently unsupported in workload results", "name", req.Name) + return nil, ErrNotFound + } + + tenancy := response.GetResource().GetId().GetTenancy() + result := &Result{ + Address: address, + Type: ResultTypeWorkload, + Tenancy: ResultTenancy{ + Namespace: tenancy.GetNamespace(), + Partition: tenancy.GetPartition(), + }, + Target: response.GetResource().GetId().GetName(), + } + + if req.PortName == "" { + return result, nil + } + + // If a port is specified, make sure the workload implements that port name. + for name, port := range workload.Ports { + if name == req.PortName { + result.PortName = req.PortName + result.PortNumber = port.Port + return result, nil + } + } + + f.logger.Debug("could not find matching port for workload", "name", req.Name, "port", req.PortName) + // Return an ErrNotFound, which is equivalent to NXDOMAIN + return nil, ErrNotFound } // FetchPreparedQuery is used to fetch a prepared query from the V2 catalog. @@ -68,3 +156,45 @@ func (f *V2DataFetcher) FetchWorkload(ctx Context, req *QueryPayload) (*Result, func (f *V2DataFetcher) FetchPreparedQuery(ctx Context, req *QueryPayload) ([]*Result, error) { return nil, ErrNotSupported } + +func (f *V2DataFetcher) NormalizeRequest(req *QueryPayload) { + // If we do not have an explicit partition in the request, we use the agent's + if req.Tenancy.Partition == "" { + req.Tenancy.Partition = f.defaultPartition + } +} + +// ValidateRequest throws an error is any of the deprecated V1 input fields are used in a QueryByName for this data fetcher. +func (f *V2DataFetcher) ValidateRequest(_ Context, req *QueryPayload) error { + if req.Tag != "" { + return ErrNotSupported + } + if req.RemoteAddr != nil { + return ErrNotSupported + } + return nil +} + +func queryTenancyToResourceTenancy(qTenancy QueryTenancy) *pbresource.Tenancy { + rTenancy := resource.DefaultNamespacedTenancy() + + // If the request has any tenancy specified, it overrides the defaults. + if qTenancy.Namespace != "" { + rTenancy.Namespace = qTenancy.Namespace + } + // In the case of partition, we have the agent's partition as the fallback. + if qTenancy.Partition != "" { + rTenancy.Partition = qTenancy.Partition + } + + return rTenancy +} + +// grpcNotFoundErr returns true if the error is a gRPC status error with a code of NotFound. +func grpcNotFoundErr(err error) bool { + if err == nil { + return false + } + s, ok := status.FromError(err) + return ok && s.Code() == codes.NotFound +} diff --git a/agent/discovery/query_fetcher_v2_test.go b/agent/discovery/query_fetcher_v2_test.go new file mode 100644 index 0000000000..86e2af6b63 --- /dev/null +++ b/agent/discovery/query_fetcher_v2_test.go @@ -0,0 +1,259 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package discovery + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/types/known/anypb" + + "github.com/hashicorp/consul/agent/config" + mockpbresource "github.com/hashicorp/consul/grpcmocks/proto-public/pbresource" + "github.com/hashicorp/consul/internal/resource" + pbcatalog "github.com/hashicorp/consul/proto-public/pbcatalog/v2beta1" + "github.com/hashicorp/consul/proto-public/pbresource" + "github.com/hashicorp/consul/sdk/testutil" +) + +// Test_FetchService tests the FetchService method in scenarios where the RPC +// call succeeds and fails. +func Test_FetchWorkload(t *testing.T) { + + rc := &config.RuntimeConfig{ + DNSOnlyPassing: false, + } + + unknownErr := errors.New("I don't feel so good") + + tests := []struct { + name string + queryPayload *QueryPayload + context Context + configureMockClient func(mockClient *mockpbresource.ResourceServiceClient_Expecter) + expectedResult *Result + expectedErr error + }{ + { + name: "FetchWorkload returns result", + queryPayload: &QueryPayload{ + Name: "foo-1234", + }, + context: Context{ + Token: "test-token", + }, + configureMockClient: func(mockClient *mockpbresource.ResourceServiceClient_Expecter) { + result := getTestWorkloadResponse(t, "", "") + mockClient.Read(mock.Anything, mock.Anything). + Return(result, nil). + Once(). + Run(func(args mock.Arguments) { + req := args.Get(1).(*pbresource.ReadRequest) + require.Equal(t, result.GetResource().GetId().GetName(), req.Id.Name) + }) + }, + expectedResult: &Result{ + Address: "1.2.3.4", + Type: ResultTypeWorkload, + Tenancy: ResultTenancy{ + Namespace: resource.DefaultNamespaceName, + Partition: resource.DefaultPartitionName, + }, + Target: "foo-1234", + }, + expectedErr: nil, + }, + { + name: "FetchWorkload for non-existent workload", + queryPayload: &QueryPayload{ + Name: "foo-1234", + }, + context: Context{ + Token: "test-token", + }, + configureMockClient: func(mockClient *mockpbresource.ResourceServiceClient_Expecter) { + input := getTestWorkloadResponse(t, "", "") + mockClient.Read(mock.Anything, mock.Anything). + Return(nil, status.Error(codes.NotFound, "not found")). + Once(). + Run(func(args mock.Arguments) { + req := args.Get(1).(*pbresource.ReadRequest) + require.Equal(t, input.GetResource().GetId().GetName(), req.Id.Name) + }) + }, + expectedResult: nil, + expectedErr: ErrNotFound, + }, + { + name: "FetchWorkload encounters a resource client error", + queryPayload: &QueryPayload{ + Name: "foo-1234", + }, + context: Context{ + Token: "test-token", + }, + configureMockClient: func(mockClient *mockpbresource.ResourceServiceClient_Expecter) { + input := getTestWorkloadResponse(t, "", "") + mockClient.Read(mock.Anything, mock.Anything). + Return(nil, unknownErr). + Once(). + Run(func(args mock.Arguments) { + req := args.Get(1).(*pbresource.ReadRequest) + require.Equal(t, input.GetResource().GetId().GetName(), req.Id.Name) + }) + }, + expectedResult: nil, + expectedErr: unknownErr, + }, + { + name: "FetchWorkload with a matching port", + queryPayload: &QueryPayload{ + Name: "foo-1234", + PortName: "api", + }, + context: Context{ + Token: "test-token", + }, + configureMockClient: func(mockClient *mockpbresource.ResourceServiceClient_Expecter) { + result := getTestWorkloadResponse(t, "", "") + mockClient.Read(mock.Anything, mock.Anything). + Return(result, nil). + Once(). + Run(func(args mock.Arguments) { + req := args.Get(1).(*pbresource.ReadRequest) + require.Equal(t, result.GetResource().GetId().GetName(), req.Id.Name) + }) + }, + expectedResult: &Result{ + Address: "1.2.3.4", + Type: ResultTypeWorkload, + PortName: "api", + PortNumber: 5678, + Tenancy: ResultTenancy{ + Namespace: resource.DefaultNamespaceName, + Partition: resource.DefaultPartitionName, + }, + Target: "foo-1234", + }, + expectedErr: nil, + }, + { + name: "FetchWorkload with a matching port", + queryPayload: &QueryPayload{ + Name: "foo-1234", + PortName: "not-api", + }, + context: Context{ + Token: "test-token", + }, + configureMockClient: func(mockClient *mockpbresource.ResourceServiceClient_Expecter) { + result := getTestWorkloadResponse(t, "", "") + mockClient.Read(mock.Anything, mock.Anything). + Return(result, nil). + Once(). + Run(func(args mock.Arguments) { + req := args.Get(1).(*pbresource.ReadRequest) + require.Equal(t, result.GetResource().GetId().GetName(), req.Id.Name) + }) + }, + expectedResult: nil, + expectedErr: ErrNotFound, + }, + { + name: "FetchWorkload returns result for non-default tenancy", + queryPayload: &QueryPayload{ + Name: "foo-1234", + Tenancy: QueryTenancy{ + Namespace: "test-namespace", + Partition: "test-partition", + }, + }, + context: Context{ + Token: "test-token", + }, + configureMockClient: func(mockClient *mockpbresource.ResourceServiceClient_Expecter) { + result := getTestWorkloadResponse(t, "test-namespace", "test-partition") + mockClient.Read(mock.Anything, mock.Anything). + Return(result, nil). + Once(). + Run(func(args mock.Arguments) { + req := args.Get(1).(*pbresource.ReadRequest) + require.Equal(t, result.GetResource().GetId().GetName(), req.Id.Name) + require.Equal(t, result.GetResource().GetId().GetTenancy().GetNamespace(), req.Id.Tenancy.Namespace) + require.Equal(t, result.GetResource().GetId().GetTenancy().GetPartition(), req.Id.Tenancy.Partition) + }) + }, + expectedResult: &Result{ + Address: "1.2.3.4", + Type: ResultTypeWorkload, + Tenancy: ResultTenancy{ + Namespace: "test-namespace", + Partition: "test-partition", + }, + Target: "foo-1234", + }, + expectedErr: nil, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + logger := testutil.Logger(t) + + client := mockpbresource.NewResourceServiceClient(t) + mockClient := client.EXPECT() + tc.configureMockClient(mockClient) + + df := NewV2DataFetcher(rc, client, logger) + + result, err := df.FetchWorkload(tc.context, tc.queryPayload) + require.True(t, errors.Is(err, tc.expectedErr)) + require.Equal(t, tc.expectedResult, result) + }) + } +} + +func getTestWorkloadResponse(t *testing.T, nsOverride string, partitionOverride string) *pbresource.ReadResponse { + workload := &pbcatalog.Workload{ + Addresses: []*pbcatalog.WorkloadAddress{ + { + Host: "1.2.3.4", + Ports: []string{"api"}, + }, + }, + Ports: map[string]*pbcatalog.WorkloadPort{ + "api": { + Port: 5678, + }, + }, + Identity: "test-identity", + } + + data, err := anypb.New(workload) + require.NoError(t, err) + + resp := &pbresource.ReadResponse{ + Resource: &pbresource.Resource{ + Id: &pbresource.ID{ + Name: "foo-1234", + Type: pbcatalog.WorkloadType, + Tenancy: resource.DefaultNamespacedTenancy(), // TODO (v2-dns): tenancy + }, + Data: data, + }, + } + + if nsOverride != "" { + resp.Resource.Id.Tenancy.Namespace = nsOverride + } + if partitionOverride != "" { + resp.Resource.Id.Tenancy.Partition = partitionOverride + } + + return resp +} diff --git a/agent/discovery/query_locality.go b/agent/discovery/query_locality.go deleted file mode 100644 index 55b77352e9..0000000000 --- a/agent/discovery/query_locality.go +++ /dev/null @@ -1,61 +0,0 @@ -// Copyright (c) HashiCorp, Inc. -// SPDX-License-Identifier: BUSL-1.1 - -package discovery - -import "github.com/hashicorp/consul/acl" - -// QueryLocality is the locality parsed from a DNS query. -type QueryLocality struct { - // Datacenter is the datacenter parsed from a label that has an explicit datacenter part. - // Example query: .virtual..ns..ap..dc.consul - Datacenter string - - // Peer is the peer name parsed from a label that has explicit parts. - // Example query: .virtual..ns..peer..ap.consul - Peer string - - // PeerOrDatacenter is parsed from DNS queries where the datacenter and peer name are - // specified in the same query part. - // Example query: .virtual..consul - // - // Note that this field should only be a "peer" for virtual queries, since virtual IPs should - // not be shared between datacenters. In all other cases, it should be considered a DC. - PeerOrDatacenter string - - acl.EnterpriseMeta -} - -// EffectiveDatacenter returns the datacenter parsed from a query, or a default -// value if none is specified. -func (l QueryLocality) EffectiveDatacenter(defaultDC string) string { - // Prefer the value parsed from a query with explicit parts: .ns..ap..dc - if l.Datacenter != "" { - return l.Datacenter - } - // Fall back to the ambiguously parsed DC or Peer. - if l.PeerOrDatacenter != "" { - return l.PeerOrDatacenter - } - // If all are empty, use a default value. - return defaultDC -} - -// GetQueryTenancyBasedOnLocality returns a discovery.QueryTenancy from a DNS message. -func GetQueryTenancyBasedOnLocality(locality QueryLocality, defaultDatacenter string) (QueryTenancy, error) { - datacenter := locality.EffectiveDatacenter(defaultDatacenter) - // Only one of dc or peer can be used. - if locality.Peer != "" { - datacenter = "" - } - - return QueryTenancy{ - EnterpriseMeta: locality.EnterpriseMeta, - // The datacenter of the request is not specified because cross-datacenter virtual IP - // queries are not supported. This guard rail is in place because virtual IPs are allocated - // within a DC, therefore their uniqueness is not guaranteed globally. - Peer: locality.Peer, - Datacenter: datacenter, - SamenessGroup: "", // this should be nil since the single locality was directly used to configure tenancy. - }, nil -} diff --git a/agent/discovery/query_locality_ce.go b/agent/discovery/query_locality_ce.go deleted file mode 100644 index 4cc4f312d4..0000000000 --- a/agent/discovery/query_locality_ce.go +++ /dev/null @@ -1,57 +0,0 @@ -// Copyright (c) HashiCorp, Inc. -// SPDX-License-Identifier: BUSL-1.1 - -//go:build !consulent - -package discovery - -import ( - "github.com/hashicorp/consul/acl" - "github.com/hashicorp/consul/agent/config" -) - -// ParseLocality can parse peer name or datacenter from a DNS query's labels. -// Peer name is parsed from the same query part that datacenter is, so given this ambiguity -// we parse a "peerOrDatacenter". The caller or RPC handler are responsible for disambiguating. -func ParseLocality(labels []string, defaultEnterpriseMeta acl.EnterpriseMeta, _ EnterpriseDNSConfig) (QueryLocality, bool) { - locality := QueryLocality{ - EnterpriseMeta: defaultEnterpriseMeta, - } - - switch len(labels) { - case 2, 4: - // Support the following formats: - // - [..dc] - // - [..peer] - for i := 0; i < len(labels); i += 2 { - switch labels[i+1] { - case "dc": - locality.Datacenter = labels[i] - case "peer": - locality.Peer = labels[i] - default: - return QueryLocality{}, false - } - } - // Return error when both datacenter and peer are specified. - if locality.Datacenter != "" && locality.Peer != "" { - return QueryLocality{}, false - } - return locality, true - case 1: - return QueryLocality{PeerOrDatacenter: labels[0]}, true - - case 0: - return QueryLocality{}, true - } - - return QueryLocality{}, false -} - -// EnterpriseDNSConfig is the configuration for enterprise DNS. -type EnterpriseDNSConfig struct{} - -// GetEnterpriseDNSConfig returns the enterprise DNS configuration. -func GetEnterpriseDNSConfig(conf *config.RuntimeConfig) EnterpriseDNSConfig { - return EnterpriseDNSConfig{} -} diff --git a/agent/discovery/query_locality_ce_test.go b/agent/discovery/query_locality_ce_test.go deleted file mode 100644 index 5f720c2121..0000000000 --- a/agent/discovery/query_locality_ce_test.go +++ /dev/null @@ -1,60 +0,0 @@ -// Copyright (c) HashiCorp, Inc. -// SPDX-License-Identifier: BUSL-1.1 - -//go:build !consulent - -package discovery - -import ( - "github.com/hashicorp/consul/acl" -) - -func getTestCases() []testCaseParseLocality { - testCases := []testCaseParseLocality{ - { - name: "test [..dc]", - labels: []string{"test-dc", "dc"}, - enterpriseDNSConfig: EnterpriseDNSConfig{}, - expectedResult: QueryLocality{ - EnterpriseMeta: acl.EnterpriseMeta{}, - Datacenter: "test-dc", - }, - expectedOK: true, - }, - { - name: "test [..peer]", - labels: []string{"test-peer", "peer"}, - enterpriseDNSConfig: EnterpriseDNSConfig{}, - expectedResult: QueryLocality{ - EnterpriseMeta: acl.EnterpriseMeta{}, - Peer: "test-peer", - }, - expectedOK: true, - }, - { - name: "test 1 label", - labels: []string{"test-peer"}, - enterpriseDNSConfig: EnterpriseDNSConfig{}, - expectedResult: QueryLocality{ - EnterpriseMeta: acl.EnterpriseMeta{}, - PeerOrDatacenter: "test-peer", - }, - expectedOK: true, - }, - { - name: "test 0 labels", - labels: []string{}, - enterpriseDNSConfig: EnterpriseDNSConfig{}, - expectedResult: QueryLocality{}, - expectedOK: true, - }, - { - name: "test 3 labels returns not found", - labels: []string{"test-dc", "dc", "test-blah"}, - enterpriseDNSConfig: EnterpriseDNSConfig{}, - expectedResult: QueryLocality{}, - expectedOK: false, - }, - } - return testCases -} diff --git a/agent/discovery/query_locality_test.go b/agent/discovery/query_locality_test.go deleted file mode 100644 index 2c1ce28c9d..0000000000 --- a/agent/discovery/query_locality_test.go +++ /dev/null @@ -1,73 +0,0 @@ -// Copyright (c) HashiCorp, Inc. -// SPDX-License-Identifier: BUSL-1.1 -package discovery - -import ( - "testing" - - "github.com/hashicorp/consul/acl" - "github.com/stretchr/testify/require" -) - -type testCaseParseLocality struct { - name string - labels []string - defaultMeta acl.EnterpriseMeta - enterpriseDNSConfig EnterpriseDNSConfig - expectedResult QueryLocality - expectedOK bool -} - -func Test_parseLocality(t *testing.T) { - testCases := getTestCases() - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - actualResult, actualOK := ParseLocality(tc.labels, tc.defaultMeta, tc.enterpriseDNSConfig) - require.Equal(t, tc.expectedOK, actualOK) - require.Equal(t, tc.expectedResult, actualResult) - - }) - } - -} - -func Test_effectiveDatacenter(t *testing.T) { - type testCase struct { - name string - QueryLocality QueryLocality - defaultDC string - expected string - } - testCases := []testCase{ - { - name: "return Datacenter first", - QueryLocality: QueryLocality{ - Datacenter: "test-dc", - PeerOrDatacenter: "test-peer", - }, - defaultDC: "default-dc", - expected: "test-dc", - }, - { - name: "return PeerOrDatacenter second", - QueryLocality: QueryLocality{ - PeerOrDatacenter: "test-peer", - }, - defaultDC: "default-dc", - expected: "test-peer", - }, - { - name: "return defaultDC as fallback", - QueryLocality: QueryLocality{}, - defaultDC: "default-dc", - expected: "default-dc", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - got := tc.QueryLocality.EffectiveDatacenter(tc.defaultDC) - require.Equal(t, tc.expected, got) - }) - } -} diff --git a/agent/dns/mock_DNSRouter.go b/agent/dns/mock_DNSRouter.go index a10f6cf185..788c894f58 100644 --- a/agent/dns/mock_DNSRouter.go +++ b/agent/dns/mock_DNSRouter.go @@ -4,8 +4,6 @@ package dns import ( config "github.com/hashicorp/consul/agent/config" - discovery "github.com/hashicorp/consul/agent/discovery" - miekgdns "github.com/miekg/dns" mock "github.com/stretchr/testify/mock" @@ -19,11 +17,11 @@ type MockDNSRouter struct { } // HandleRequest provides a mock function with given fields: req, reqCtx, remoteAddress -func (_m *MockDNSRouter) HandleRequest(req *miekgdns.Msg, reqCtx discovery.Context, remoteAddress net.Addr) *miekgdns.Msg { +func (_m *MockDNSRouter) HandleRequest(req *miekgdns.Msg, reqCtx Context, remoteAddress net.Addr) *miekgdns.Msg { ret := _m.Called(req, reqCtx, remoteAddress) var r0 *miekgdns.Msg - if rf, ok := ret.Get(0).(func(*miekgdns.Msg, discovery.Context, net.Addr) *miekgdns.Msg); ok { + if rf, ok := ret.Get(0).(func(*miekgdns.Msg, Context, net.Addr) *miekgdns.Msg); ok { r0 = rf(req, reqCtx, remoteAddress) } else { if ret.Get(0) != nil { diff --git a/agent/dns/parser.go b/agent/dns/parser.go new file mode 100644 index 0000000000..dd23e91591 --- /dev/null +++ b/agent/dns/parser.go @@ -0,0 +1,89 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package dns + +// parsedLabels defines valid DNS labels that are possible for ALL DNS query in Consul. (v1 and v2, CE and ENT) +// It is the job of the parser to populate the struct, the routers to call the query processor, +// and the query processor to validate is the labels. +type parsedLabels struct { + Datacenter string + Namespace string + Partition string + Peer string + PeerOrDatacenter string // deprecated: use Datacenter or Peer + SamenessGroup string +} + +// ParseLabels can parse a DNS query's labels and returns a parsedLabels. +// It also does light validation according to invariants across all possible DNS queries for all Consul versions +func parseLabels(labels []string) (*parsedLabels, bool) { + var result parsedLabels + + switch len(labels) { + case 2, 4, 6: + // Supports the following formats: + // - [..ns][..ap][..dc] + // - . + // - [..ns][..ap][..peer] + // - [..sg][..ap][..ns] + for i := 0; i < len(labels); i += 2 { + switch labels[i+1] { + case "ns": + result.Namespace = labels[i] + case "ap": + result.Partition = labels[i] + case "dc": // TODO (v2-dns): This should also include "cluster" for the new notation. + result.Datacenter = labels[i] + case "sg": + result.SamenessGroup = labels[i] + case "peer": + result.Peer = labels[i] + default: + // The only case in which labels[i+1] is allowed to be a value + // other than ns, ap, or dc is if n == 2 to support the format: + // .. + if len(labels) == 2 { + result.PeerOrDatacenter = labels[1] + result.Namespace = labels[0] + return &result, true + } + return nil, false + } + } + + // VALIDATIONS + // Return nil result and false boolean when both datacenter and peer are specified. + if result.Datacenter != "" && result.Peer != "" { + return nil, false + } + + // Validation e need to validate that this a valid DNS including sg + if result.SamenessGroup != "" && (result.Datacenter != "" || result.Peer != "") { + return nil, false + } + + return &result, true + + case 1: + result.PeerOrDatacenter = labels[0] + return &result, true + + case 0: + return &result, true + } + + return &result, false +} + +// parsePort looks through the query parts for a named port label. +// It assumes the only valid input format is["", "port", ""]. +// The other expected formats are [""] and ["", ""]. +// It is expected that the queryProcessor validates if the label is allowed for the query type. +func parsePort(parts []string) string { + // The minimum number of parts would be + if len(parts) != 3 || parts[1] != "port" { + return "" + } + return parts[0] +} diff --git a/agent/dns/parser_test.go b/agent/dns/parser_test.go new file mode 100644 index 0000000000..a18dd9b184 --- /dev/null +++ b/agent/dns/parser_test.go @@ -0,0 +1,6 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package dns + +// TODO (v2-dns): parser tests diff --git a/agent/dns/router.go b/agent/dns/router.go index 06b1caa026..94732ee5de 100644 --- a/agent/dns/router.go +++ b/agent/dns/router.go @@ -14,14 +14,15 @@ import ( "time" "github.com/armon/go-radix" - "github.com/hashicorp/go-hclog" "github.com/miekg/dns" - "github.com/hashicorp/consul/acl" + "github.com/hashicorp/go-hclog" + "github.com/hashicorp/consul/agent/config" "github.com/hashicorp/consul/agent/discovery" "github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/internal/dnsutil" + "github.com/hashicorp/consul/internal/resource" "github.com/hashicorp/consul/logging" ) @@ -40,6 +41,7 @@ const ( var ( errInvalidQuestion = fmt.Errorf("invalid question") errNameNotFound = fmt.Errorf("name not found") + errNotImplemented = fmt.Errorf("not implemented") errRecursionFailed = fmt.Errorf("recursion failed") trailingSpacesRE = regexp.MustCompile(" +$") @@ -47,6 +49,12 @@ var ( // TODO (v2-dns): metrics +// Context is used augment a DNS message with Consul-specific metadata. +type Context struct { + Token string + DefaultPartition string +} + // RouterDynamicConfig is the dynamic configuration that can be hot-reloaded type RouterDynamicConfig struct { ARecordLimit int @@ -64,26 +72,6 @@ type RouterDynamicConfig struct { // TTLStrict sets TTLs to service by full name match. It Has higher priority than TTLRadix TTLStrict map[string]time.Duration UDPAnswerLimit int - - discovery.EnterpriseDNSConfig -} - -// GetTTLForService Find the TTL for a given service. -// return ttl, true if found, 0, false otherwise -func (cfg *RouterDynamicConfig) GetTTLForService(service string) (time.Duration, bool) { - if cfg.TTLStrict != nil { - ttl, ok := cfg.TTLStrict[service] - if ok { - return ttl, true - } - } - if cfg.TTLRadix != nil { - _, ttlRaw, ok := cfg.TTLRadix.LongestPrefix(service) - if ok { - return ttlRaw.(time.Duration), true - } - } - return 0, false } type SOAConfig struct { @@ -120,10 +108,6 @@ type Router struct { tokenFunc func() string - defaultEntMeta acl.EnterpriseMeta - - // TODO (v2-dns): default locality for request context? - // dynamicConfig stores the config as an atomic value (for hot-reloading). // It is always of type *RouterDynamicConfig dynamicConfig atomic.Value @@ -142,13 +126,12 @@ func NewRouter(cfg Config) (*Router, error) { logger := cfg.Logger.Named(logging.DNS) router := &Router{ - processor: cfg.Processor, - recursor: newRecursor(logger), - domain: domain, - altDomain: altDomain, - logger: logger, - tokenFunc: cfg.TokenFunc, - defaultEntMeta: cfg.EntMeta, + processor: cfg.Processor, + recursor: newRecursor(logger), + domain: domain, + altDomain: altDomain, + logger: logger, + tokenFunc: cfg.TokenFunc, } if err := router.ReloadConfig(cfg.AgentConfig); err != nil { @@ -158,13 +141,13 @@ func NewRouter(cfg Config) (*Router, error) { } // HandleRequest is used to process an individual DNS request. It returns a message in success or fail cases. -func (r *Router) HandleRequest(req *dns.Msg, reqCtx discovery.Context, remoteAddress net.Addr) *dns.Msg { +func (r *Router) HandleRequest(req *dns.Msg, reqCtx Context, remoteAddress net.Addr) *dns.Msg { return r.handleRequestRecursively(req, reqCtx, remoteAddress, maxRecursionLevelDefault) } // handleRequestRecursively is used to process an individual DNS request. It will recurse as needed // a maximum number of times and returns a message in success or fail cases. -func (r *Router) handleRequestRecursively(req *dns.Msg, reqCtx discovery.Context, +func (r *Router) handleRequestRecursively(req *dns.Msg, reqCtx Context, remoteAddress net.Addr, maxRecursionLevel int) *dns.Msg { configCtx := r.dynamicConfig.Load().(*RouterDynamicConfig) @@ -204,14 +187,28 @@ func (r *Router) handleRequestRecursively(req *dns.Msg, reqCtx discovery.Context } reqType := parseRequestType(req) - results, query, err := r.getQueryResults(req, reqCtx, reqType, configCtx, qName) + results, query, err := r.getQueryResults(req, reqCtx, reqType, qName) switch { case errors.Is(err, errNameNotFound): r.logger.Error("name not found", "name", qName) ecsGlobal := !errors.Is(err, discovery.ErrECSNotGlobal) return createAuthoritativeResponse(req, configCtx, responseDomain, dns.RcodeNameError, ecsGlobal) - // TODO (v2-dns): there is another case here where the discovery service returns "name not found" + case errors.Is(err, errNotImplemented): + r.logger.Error("query not implemented", "name", qName, "type", dns.Type(req.Question[0].Qtype).String()) + + ecsGlobal := !errors.Is(err, discovery.ErrECSNotGlobal) + return createAuthoritativeResponse(req, configCtx, responseDomain, dns.RcodeNotImplemented, ecsGlobal) + case errors.Is(err, discovery.ErrNotSupported): + r.logger.Debug("query name syntax not supported", "name", req.Question[0].Name) + + ecsGlobal := !errors.Is(err, discovery.ErrECSNotGlobal) + return createAuthoritativeResponse(req, configCtx, responseDomain, dns.RcodeNameError, ecsGlobal) + case errors.Is(err, discovery.ErrNotFound): + r.logger.Debug("query name not found", "name", req.Question[0].Name) + + ecsGlobal := !errors.Is(err, discovery.ErrECSNotGlobal) + return createAuthoritativeResponse(req, configCtx, responseDomain, dns.RcodeNameError, ecsGlobal) case errors.Is(err, discovery.ErrNoData): r.logger.Debug("no data available", "name", qName) @@ -249,13 +246,22 @@ func (r *Router) trimDomain(questionName string) string { // getTTLForResult returns the TTL for a given result. func getTTLForResult(name string, query *discovery.Query, cfg *RouterDynamicConfig) uint32 { - switch { + // In the case we are not making a discovery query, such as addr. or arpa. lookups, + // use the node TTL by convention + if query == nil { + return uint32(cfg.NodeTTL / time.Second) + } + + switch query.QueryType { // TODO (v2-dns): currently have to do this related to the results type being changed to node whe // the v1 data fetcher encounters a blank service address and uses the node address instead. // we will revisiting this when look at modifying the discovery result struct to // possibly include additional metadata like the node address. - case query != nil && query.QueryType == discovery.QueryTypeService: - ttl, ok := cfg.GetTTLForService(name) + case discovery.QueryTypeWorkload: + // TODO (v2-dns): we need to discuss what we want to do for workload TTLs + return 0 + case discovery.QueryTypeService: + ttl, ok := cfg.getTTLForService(name) if ok { return uint32(ttl / time.Second) } @@ -266,9 +272,7 @@ func getTTLForResult(name string, query *discovery.Query, cfg *RouterDynamicConf } // getQueryResults returns a discovery.Result from a DNS message. -func (r *Router) getQueryResults(req *dns.Msg, reqCtx discovery.Context, - reqType requestType, cfg *RouterDynamicConfig, qName string) ([]*discovery.Result, *discovery.Query, error) { - var query *discovery.Query +func (r *Router) getQueryResults(req *dns.Msg, reqCtx Context, reqType requestType, qName string) ([]*discovery.Result, *discovery.Query, error) { switch reqType { case requestTypeConsul: // This is a special case of discovery.QueryByName where we know that we need to query the consul service @@ -277,19 +281,26 @@ func (r *Router) getQueryResults(req *dns.Msg, reqCtx discovery.Context, QueryType: discovery.QueryTypeService, QueryPayload: discovery.QueryPayload{ Name: structs.ConsulServiceName, + Tenancy: discovery.QueryTenancy{ + // We specify the partition here so that in the case we are a client agent in a non-default partition. + // We don't want the query processors default partition to be used. + // This is a small hack because for V1 CE, this is not the correct default partition name, but we + // need to add something to disambiguate the empty field. + Partition: resource.DefaultPartitionName, + }, }, Limit: 3, // TODO (v2-dns): need to thread this through to the backend and make sure we shuffle the results } - results, err := r.processor.QueryByName(query, reqCtx) + results, err := r.processor.QueryByName(query, discovery.Context{Token: reqCtx.Token}) return results, query, err case requestTypeName: - query, err := buildQueryFromDNSMessage(req, r.domain, r.altDomain, cfg, r.defaultEntMeta, r.datacenter) + query, err := buildQueryFromDNSMessage(req, reqCtx, r.domain, r.altDomain) if err != nil { r.logger.Error("error building discovery query from DNS request", "error", err) return nil, query, err } - results, err := r.processor.QueryByName(query, reqCtx) + results, err := r.processor.QueryByName(query, discovery.Context{Token: reqCtx.Token}) if err != nil { r.logger.Error("error processing discovery query", "error", err) return nil, query, err @@ -301,17 +312,17 @@ func (r *Router) getQueryResults(req *dns.Msg, reqCtx discovery.Context, r.logger.Error("error building IP from DNS request", "name", qName) return nil, nil, errNameNotFound } - results, err := r.processor.QueryByIP(ip, reqCtx) - return results, query, err + results, err := r.processor.QueryByIP(ip, discovery.Context{Token: reqCtx.Token}) + return results, nil, err case requestTypeAddress: results, err := buildAddressResults(req) if err != nil { r.logger.Error("error processing discovery query", "error", err) - return nil, query, err + return nil, nil, err } - return results, query, nil + return results, nil, nil } - return nil, query, errors.New("invalid request type") + return nil, nil, errors.New("invalid request type") } // ServeDNS implements the miekg/dns.Handler interface. @@ -332,6 +343,24 @@ func (r *Router) ReloadConfig(newCfg *config.RuntimeConfig) error { return nil } +// getTTLForService Find the TTL for a given service. +// return ttl, true if found, 0, false otherwise +func (cfg *RouterDynamicConfig) getTTLForService(service string) (time.Duration, bool) { + if cfg.TTLStrict != nil { + ttl, ok := cfg.TTLStrict[service] + if ok { + return ttl, true + } + } + if cfg.TTLRadix != nil { + _, ttlRaw, ok := cfg.TTLRadix.LongestPrefix(service) + if ok { + return ttlRaw.(time.Duration), true + } + } + return 0, false +} + // Request type is similar to miekg/dns.Type, but correlates to the different query processors we might need to invoke. type requestType string @@ -391,7 +420,7 @@ func parseRequestType(req *dns.Msg) requestType { } // serializeQueryResults converts a discovery.Result into a DNS message. -func (r *Router) serializeQueryResults(req *dns.Msg, reqCtx discovery.Context, +func (r *Router) serializeQueryResults(req *dns.Msg, reqCtx Context, query *discovery.Query, results []*discovery.Result, cfg *RouterDynamicConfig, responseDomain string, remoteAddress net.Addr, maxRecursionLevel int) (*dns.Msg, error) { resp := new(dns.Msg) @@ -430,7 +459,7 @@ func (r *Router) serializeQueryResults(req *dns.Msg, reqCtx discovery.Context, // appendResultsToDNSResponse builds dns message from the discovery results and // appends them to the dns response. -func (r *Router) appendResultsToDNSResponse(req *dns.Msg, reqCtx discovery.Context, +func (r *Router) appendResultsToDNSResponse(req *dns.Msg, reqCtx Context, query *discovery.Query, resp *dns.Msg, results []*discovery.Result, cfg *RouterDynamicConfig, responseDomain string, remoteAddress net.Addr, maxRecursionLevel int) { @@ -487,19 +516,19 @@ func (r *Router) appendResultsToDNSResponse(req *dns.Msg, reqCtx discovery.Conte } // defaultAgentDNSRequestContext returns a default request context based on the agent's config. -func (r *Router) defaultAgentDNSRequestContext() discovery.Context { - return discovery.Context{ +func (r *Router) defaultAgentDNSRequestContext() Context { + return Context{ Token: r.tokenFunc(), - // TODO (v2-dns): tenancy information; maybe we choose not to specify and use the default - // attached to the Router (from the agent's config) + // We don't need to specify the agent's partition here because that will be handled further down the stack + // in the query processor. } } // resolveCNAME is used to recursively resolve CNAME records -func (r *Router) resolveCNAME(cfg *RouterDynamicConfig, name string, reqCtx discovery.Context, +func (r *Router) resolveCNAME(cfg *RouterDynamicConfig, name string, reqCtx Context, remoteAddress net.Addr, maxRecursionLevel int) []dns.RR { // If the CNAME record points to a Consul address, resolve it internally - // Convert query to lowercase because DNS is case insensitive; d.domain and + // Convert query to lowercase because DNS is case-insensitive; d.domain and // d.altDomain are already converted if ln := strings.ToLower(name); strings.HasSuffix(ln, "."+r.domain) || strings.HasSuffix(ln, "."+r.altDomain) { @@ -609,7 +638,6 @@ func getDynamicRouterConfig(conf *config.RuntimeConfig) (*RouterDynamicConfig, e Refresh: conf.DNSSOA.Refresh, Retry: conf.DNSSOA.Retry, }, - EnterpriseDNSConfig: discovery.GetEnterpriseDNSConfig(conf), } if conf.DNSServiceTTL != nil { @@ -765,7 +793,7 @@ func buildAddressResults(req *dns.Msg) ([]*discovery.Result, error) { } // getAnswerAndExtra creates the dns answer and extra from discovery results. -func (r *Router) getAnswerExtraAndNs(result *discovery.Result, req *dns.Msg, reqCtx discovery.Context, +func (r *Router) getAnswerExtraAndNs(result *discovery.Result, req *dns.Msg, reqCtx Context, query *discovery.Query, cfg *RouterDynamicConfig, domain string, remoteAddress net.Addr, maxRecursionLevel int) (answer []dns.RR, extra []dns.RR, ns []dns.RR) { target := newDNSAddress(result.Target) @@ -829,7 +857,7 @@ func (r *Router) getAnswerExtraAndNs(result *discovery.Result, req *dns.Msg, req // getAnswerExtrasForAddressAndTarget creates the dns answer and extra from address and target dnsAddress pairs. func (r *Router) getAnswerExtrasForAddressAndTarget(address *dnsAddress, target *dnsAddress, req *dns.Msg, - reqCtx discovery.Context, result *discovery.Result, ttl uint32, remoteAddress net.Addr, + reqCtx Context, result *discovery.Result, ttl uint32, remoteAddress net.Addr, cfg *RouterDynamicConfig, maxRecursionLevel int) (answer []dns.RR, extra []dns.RR) { qName := req.Question[0].Name reqType := parseRequestType(req) @@ -854,7 +882,7 @@ func (r *Router) getAnswerExtrasForAddressAndTarget(address *dnsAddress, target // Target is FQDN that point to IP case target.IsFQDN() && address.IsIP(): var a, e []dns.RR - if result.Type == discovery.ResultTypeNode { + if result.Type == discovery.ResultTypeNode || result.Type == discovery.ResultTypeWorkload { // if it is a node record it means the service address pointed to a node // and the node address was used. So we create an A record for the node address, // as well as a CNAME for the service to node mapping. @@ -977,7 +1005,7 @@ func makeIPBasedRecord(name string, addr *dnsAddress, ttl uint32) dns.RR { } func (r *Router) makeRecordFromFQDN(fqdn string, result *discovery.Result, - req *dns.Msg, reqCtx discovery.Context, cfg *RouterDynamicConfig, ttl uint32, + req *dns.Msg, reqCtx Context, cfg *RouterDynamicConfig, ttl uint32, remoteAddress net.Addr, maxRecursionLevel int) ([]dns.RR, []dns.RR) { edns := req.IsEdns0() != nil q := req.Question[0] @@ -1039,7 +1067,7 @@ func makeSRVRecord(name, target string, result *discovery.Result, ttl uint32) *d }, Priority: 1, Weight: uint16(result.Weight), - Port: uint16(result.Port), + Port: uint16(result.PortNumber), Target: target, } } diff --git a/agent/dns/router_ce.go b/agent/dns/router_ce.go index 5ffc9f51ce..3a44ca1cdc 100644 --- a/agent/dns/router_ce.go +++ b/agent/dns/router_ce.go @@ -28,6 +28,9 @@ func canonicalNameForResult(result *discovery.Result, domain string) string { // Return a simpler format for non-peering nodes. return fmt.Sprintf("%s.node.%s.%s", result.Target, result.Tenancy.Datacenter, domain) case discovery.ResultTypeWorkload: + if result.PortName != "" { + return fmt.Sprintf("%s.port.%s.workload.%s", result.PortName, result.Target, domain) + } return fmt.Sprintf("%s.workload.%s", result.Target, domain) } return "" diff --git a/agent/dns/router_query.go b/agent/dns/router_query.go index 847dc45c5e..13a4935be0 100644 --- a/agent/dns/router_query.go +++ b/agent/dns/router_query.go @@ -4,37 +4,39 @@ package dns import ( - "errors" "strings" "github.com/miekg/dns" - "github.com/hashicorp/consul/acl" "github.com/hashicorp/consul/agent/discovery" ) // buildQueryFromDNSMessage returns a discovery.Query from a DNS message. -func buildQueryFromDNSMessage(req *dns.Msg, domain, altDomain string, - cfg *RouterDynamicConfig, defaultEntMeta acl.EnterpriseMeta, defaultDatacenter string) (*discovery.Query, error) { +func buildQueryFromDNSMessage(req *dns.Msg, reqCtx Context, domain, altDomain string) (*discovery.Query, error) { queryType, queryParts, querySuffixes := getQueryTypePartsAndSuffixesFromDNSMessage(req, domain, altDomain) - queryTenancy, err := getQueryTenancy(queryType, querySuffixes, defaultEntMeta, cfg, defaultDatacenter) + queryTenancy, err := getQueryTenancy(reqCtx, queryType, querySuffixes) if err != nil { return nil, err } name, tag := getQueryNameAndTagFromParts(queryType, queryParts) + portName := parsePort(queryParts) + + if queryType == discovery.QueryTypeWorkload && req.Question[0].Qtype == dns.TypeSRV { + // Currently we do not support SRV records for workloads + return nil, errNotImplemented + } + return &discovery.Query{ QueryType: queryType, QueryPayload: discovery.QueryPayload{ - Name: name, - Tenancy: queryTenancy, - Tag: tag, - // TODO (v2-dns): what should these be? - //PortName: "", - //RemoteAddr: nil, - //DisableFailover: false, + Name: name, + Tenancy: queryTenancy, + Tag: tag, + PortName: portName, + //RemoteAddr: nil, // TODO (v2-dns): Prepared Queries for V1 Catalog }, }, nil } @@ -64,30 +66,48 @@ func getQueryNameAndTagFromParts(queryType discovery.QueryType, queryParts []str } // getQueryTenancy returns a discovery.QueryTenancy from a DNS message. -func getQueryTenancy(queryType discovery.QueryType, querySuffixes []string, - defaultEntMeta acl.EnterpriseMeta, cfg *RouterDynamicConfig, defaultDatacenter string) (discovery.QueryTenancy, error) { - if queryType == discovery.QueryTypeService { - return getQueryTenancyForService(querySuffixes, defaultEntMeta, cfg, defaultDatacenter) +func getQueryTenancy(reqCtx Context, queryType discovery.QueryType, querySuffixes []string) (discovery.QueryTenancy, error) { + labels, ok := parseLabels(querySuffixes) + if !ok { + return discovery.QueryTenancy{}, errNameNotFound } - locality, ok := discovery.ParseLocality(querySuffixes, defaultEntMeta, cfg.EnterpriseDNSConfig) - if !ok { - return discovery.QueryTenancy{}, errors.New("invalid locality") + // If we don't have an explicit partition in the request, try the first fallback + // which was supplied in the request context. The agent's partition will be used as the last fallback + // later in the query processor. + if labels.Partition == "" { + labels.Partition = reqCtx.DefaultPartition + } + + // If we have a sameness group, we can return early without further data massage. + if labels.SamenessGroup != "" { + return discovery.QueryTenancy{ + Namespace: labels.Namespace, + Partition: labels.Partition, + SamenessGroup: labels.SamenessGroup, + }, nil } if queryType == discovery.QueryTypeVirtual { - if locality.Peer == "" { + if labels.Peer == "" { // If the peer name was not explicitly defined, fall back to the ambiguously-parsed version. - locality.Peer = locality.PeerOrDatacenter + labels.Peer = labels.PeerOrDatacenter } } - return discovery.GetQueryTenancyBasedOnLocality(locality, defaultDatacenter) + return discovery.QueryTenancy{ + Namespace: labels.Namespace, + Partition: labels.Partition, + Peer: labels.Peer, + Datacenter: labels.Datacenter, + }, nil } // getQueryTypePartsAndSuffixesFromDNSMessage returns the query type, the parts, and suffixes of the query name. func getQueryTypePartsAndSuffixesFromDNSMessage(req *dns.Msg, domain, altDomain string) (queryType discovery.QueryType, parts []string, suffixes []string) { // Get the QName without the domain suffix + // TODO (v2-dns): we will also need to handle the "failover" and "no-failover" suffixes here. + // They come AFTER the domain. See `stripSuffix` in router.go qName := trimDomainFromQuestionName(req.Question[0].Name, domain, altDomain) // Split into the label parts @@ -97,7 +117,7 @@ func getQueryTypePartsAndSuffixesFromDNSMessage(req *dns.Msg, domain, altDomain for i := len(labels) - 1; i >= 0 && !done; i-- { queryType = getQueryTypeFromLabels(labels[i]) switch queryType { - case discovery.QueryTypeService, + case discovery.QueryTypeService, discovery.QueryTypeWorkload, discovery.QueryTypeConnect, discovery.QueryTypeVirtual, discovery.QueryTypeIngress, discovery.QueryTypeNode, discovery.QueryTypePreparedQuery: parts = labels[:i] @@ -122,7 +142,7 @@ func getQueryTypePartsAndSuffixesFromDNSMessage(req *dns.Msg, domain, altDomain // trimDomainFromQuestionName returns the question name without the domain suffix. func trimDomainFromQuestionName(questionName, domain, altDomain string) string { - qName := strings.ToLower(dns.Fqdn(questionName)) + qName := dns.CanonicalName(questionName) longer := domain shorter := altDomain diff --git a/agent/dns/router_query_ce.go b/agent/dns/router_query_ce.go deleted file mode 100644 index bbe868a2c8..0000000000 --- a/agent/dns/router_query_ce.go +++ /dev/null @@ -1,24 +0,0 @@ -// Copyright (c) HashiCorp, Inc. -// SPDX-License-Identifier: BUSL-1.1 - -//go:build !consulent - -package dns - -import ( - "errors" - - "github.com/hashicorp/consul/acl" - "github.com/hashicorp/consul/agent/discovery" -) - -// getQueryTenancy returns a discovery.QueryTenancy from a DNS message. -func getQueryTenancyForService(querySuffixes []string, - defaultEntMeta acl.EnterpriseMeta, cfg *RouterDynamicConfig, defaultDatacenter string) (discovery.QueryTenancy, error) { - locality, ok := discovery.ParseLocality(querySuffixes, defaultEntMeta, cfg.EnterpriseDNSConfig) - if !ok { - return discovery.QueryTenancy{}, errors.New("invalid locality") - } - - return discovery.GetQueryTenancyBasedOnLocality(locality, defaultDatacenter) -} diff --git a/agent/dns/router_query_ce_test.go b/agent/dns/router_query_ce_test.go deleted file mode 100644 index 13337dfbe0..0000000000 --- a/agent/dns/router_query_ce_test.go +++ /dev/null @@ -1,81 +0,0 @@ -// Copyright (c) HashiCorp, Inc. -// SPDX-License-Identifier: BUSL-1.1 - -//go:build !consulent - -package dns - -import ( - "github.com/miekg/dns" - - "github.com/hashicorp/consul/acl" - "github.com/hashicorp/consul/agent/discovery" -) - -func getBuildQueryFromDNSMessageTestCases() []testCaseBuildQueryFromDNSMessage { - testCases := []testCaseBuildQueryFromDNSMessage{ - // virtual ip queries - { - name: "test A 'virtual.' query, ipv4 response", - request: &dns.Msg{ - MsgHdr: dns.MsgHdr{ - Opcode: dns.OpcodeQuery, - }, - Question: []dns.Question{ - { - Name: "db.virtual.consul", // "intentionally missing the trailing dot" - Qtype: dns.TypeA, - Qclass: dns.ClassINET, - }, - }, - }, - expectedQuery: &discovery.Query{ - QueryType: discovery.QueryTypeVirtual, - QueryPayload: discovery.QueryPayload{ - Name: "db", - PortName: "", - Tag: "", - Tenancy: discovery.QueryTenancy{ - EnterpriseMeta: acl.EnterpriseMeta{}, - SamenessGroup: "", - Peer: "consul", - Datacenter: "", - }, - DisableFailover: false, - }, - }, - }, - { - name: "test A 'virtual.' with peer query, ipv4 response", - request: &dns.Msg{ - MsgHdr: dns.MsgHdr{ - Opcode: dns.OpcodeQuery, - }, - Question: []dns.Question{ - { - Name: "db.virtual.consul", // "intentionally missing the trailing dot" - Qtype: dns.TypeA, - Qclass: dns.ClassINET, - }, - }, - }, - expectedQuery: &discovery.Query{ - QueryType: discovery.QueryTypeVirtual, - QueryPayload: discovery.QueryPayload{ - Name: "db", - PortName: "", - Tag: "", - Tenancy: discovery.QueryTenancy{ - EnterpriseMeta: acl.EnterpriseMeta{}, - SamenessGroup: "", - Peer: "consul", // this gets set in the query building after ParseLocality processes. - Datacenter: "", - }, - DisableFailover: false, - }, - }, - }, - } - - return testCases -} diff --git a/agent/dns/router_query_test.go b/agent/dns/router_query_test.go index 726ef32ba8..dc4ea6592e 100644 --- a/agent/dns/router_query_test.go +++ b/agent/dns/router_query_test.go @@ -7,29 +7,206 @@ import ( "testing" "github.com/miekg/dns" - - "github.com/hashicorp/consul/acl" - "github.com/hashicorp/consul/agent/discovery" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/hashicorp/consul/agent/discovery" ) // testCaseBuildQueryFromDNSMessage is a test case for the buildQueryFromDNSMessage function. type testCaseBuildQueryFromDNSMessage struct { name string request *dns.Msg - requestContext *discovery.Context + requestContext *Context expectedQuery *discovery.Query } // Test_buildQueryFromDNSMessage tests the buildQueryFromDNSMessage function. func Test_buildQueryFromDNSMessage(t *testing.T) { - testCases := getBuildQueryFromDNSMessageTestCases() + testCases := []testCaseBuildQueryFromDNSMessage{ + // virtual ip queries + { + name: "test A 'virtual.' query", + request: &dns.Msg{ + MsgHdr: dns.MsgHdr{ + Opcode: dns.OpcodeQuery, + }, + Question: []dns.Question{ + { + Name: "db.virtual.consul", // "intentionally missing the trailing dot" + Qtype: dns.TypeA, + Qclass: dns.ClassINET, + }, + }, + }, + expectedQuery: &discovery.Query{ + QueryType: discovery.QueryTypeVirtual, + QueryPayload: discovery.QueryPayload{ + Name: "db", + Tenancy: discovery.QueryTenancy{}, + }, + }, + }, + { + name: "test A 'virtual.' with kitchen sink labels", + request: &dns.Msg{ + MsgHdr: dns.MsgHdr{ + Opcode: dns.OpcodeQuery, + }, + Question: []dns.Question{ + { + Name: "db.virtual.banana.ns.orange.ap.foo.peer.consul", // "intentionally missing the trailing dot" + Qtype: dns.TypeA, + Qclass: dns.ClassINET, + }, + }, + }, + expectedQuery: &discovery.Query{ + QueryType: discovery.QueryTypeVirtual, + QueryPayload: discovery.QueryPayload{ + Name: "db", + Tenancy: discovery.QueryTenancy{ + Peer: "foo", + Namespace: "banana", + Partition: "orange", + }, + }, + }, + }, + { + name: "test A 'virtual.' with implicit peer", + request: &dns.Msg{ + MsgHdr: dns.MsgHdr{ + Opcode: dns.OpcodeQuery, + }, + Question: []dns.Question{ + { + Name: "db.virtual.foo.consul", // "intentionally missing the trailing dot" + Qtype: dns.TypeA, + Qclass: dns.ClassINET, + }, + }, + }, + expectedQuery: &discovery.Query{ + QueryType: discovery.QueryTypeVirtual, + QueryPayload: discovery.QueryPayload{ + Name: "db", + Tenancy: discovery.QueryTenancy{ + Peer: "foo", + }, + }, + }, + }, + { + name: "test A 'virtual.' with implicit peer and namespace query", + request: &dns.Msg{ + MsgHdr: dns.MsgHdr{ + Opcode: dns.OpcodeQuery, + }, + Question: []dns.Question{ + { + Name: "db.virtual.frontend.foo.consul", // "intentionally missing the trailing dot" + Qtype: dns.TypeA, + Qclass: dns.ClassINET, + }, + }, + }, + expectedQuery: &discovery.Query{ + QueryType: discovery.QueryTypeVirtual, + QueryPayload: discovery.QueryPayload{ + Name: "db", + Tenancy: discovery.QueryTenancy{ + Namespace: "frontend", + Peer: "foo", + }, + }, + }, + }, + { + name: "test A 'workload.'", + request: &dns.Msg{ + MsgHdr: dns.MsgHdr{ + Opcode: dns.OpcodeQuery, + }, + Question: []dns.Question{ + { + Name: "foo.workload.consul", // "intentionally missing the trailing dot" + Qtype: dns.TypeA, + Qclass: dns.ClassINET, + }, + }, + }, + expectedQuery: &discovery.Query{ + QueryType: discovery.QueryTypeWorkload, + QueryPayload: discovery.QueryPayload{ + Name: "foo", + Tenancy: discovery.QueryTenancy{}, + }, + }, + }, + { + name: "test A 'workload.' with all possible labels", + request: &dns.Msg{ + MsgHdr: dns.MsgHdr{ + Opcode: dns.OpcodeQuery, + }, + Question: []dns.Question{ + { + Name: "api.port.foo.workload.banana.ns.orange.ap.apple.peer.consul", // "intentionally missing the trailing dot" + Qtype: dns.TypeA, + Qclass: dns.ClassINET, + }, + }, + }, + expectedQuery: &discovery.Query{ + QueryType: discovery.QueryTypeWorkload, + QueryPayload: discovery.QueryPayload{ + Name: "foo", + PortName: "api", + Tenancy: discovery.QueryTenancy{ + Namespace: "banana", + Partition: "orange", + Peer: "apple", + }, + }, + }, + }, + { + name: "test sameness group with all possible labels", + request: &dns.Msg{ + MsgHdr: dns.MsgHdr{ + Opcode: dns.OpcodeQuery, + }, + Question: []dns.Question{ + { + Name: "foo.service.apple.sg.banana.ns.orange.ap.consul", // "intentionally missing the trailing dot" + Qtype: dns.TypeA, + Qclass: dns.ClassINET, + }, + }, + }, + expectedQuery: &discovery.Query{ + QueryType: discovery.QueryTypeService, + QueryPayload: discovery.QueryPayload{ + Name: "foo", + Tenancy: discovery.QueryTenancy{ + Namespace: "banana", + Partition: "orange", + SamenessGroup: "apple", + }, + }, + }, + }, + } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - query, err := buildQueryFromDNSMessage(tc.request, "domain", "altDomain", &RouterDynamicConfig{}, acl.EnterpriseMeta{}, "defaultDatacenter") + context := tc.requestContext + if context == nil { + context = &Context{} + } + query, err := buildQueryFromDNSMessage(tc.request, *context, "consul.", ".") require.NoError(t, err) assert.Equal(t, tc.expectedQuery, query) }) diff --git a/agent/dns/router_test.go b/agent/dns/router_test.go index 33fcbdb72e..aa38d91ef3 100644 --- a/agent/dns/router_test.go +++ b/agent/dns/router_test.go @@ -9,11 +9,12 @@ import ( "testing" "time" - "github.com/hashicorp/go-hclog" "github.com/miekg/dns" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" + "github.com/hashicorp/go-hclog" + "github.com/hashicorp/consul/acl" "github.com/hashicorp/consul/agent/config" "github.com/hashicorp/consul/agent/discovery" @@ -29,15 +30,16 @@ import ( // 4. Test the edns settings. type HandleTestCase struct { - name string - agentConfig *config.RuntimeConfig // This will override the default test Router Config - configureDataFetcher func(fetcher discovery.CatalogDataFetcher) - configureRecursor func(recursor dnsRecursor) - mockProcessorError error - request *dns.Msg - requestContext *discovery.Context - remoteAddress net.Addr - response *dns.Msg + name string + agentConfig *config.RuntimeConfig // This will override the default test Router Config + configureDataFetcher func(fetcher discovery.CatalogDataFetcher) + validateAndNormalizeExpected bool + configureRecursor func(recursor dnsRecursor) + mockProcessorError error + request *dns.Msg + requestContext *Context + remoteAddress net.Addr + response *dns.Msg } func Test_HandleRequest(t *testing.T) { @@ -719,6 +721,7 @@ func Test_HandleRequest(t *testing.T) { Type: discovery.ResultTypeVirtual, }, nil) }, + validateAndNormalizeExpected: true, response: &dns.Msg{ MsgHdr: dns.MsgHdr{ Opcode: dns.OpcodeQuery, @@ -768,6 +771,7 @@ func Test_HandleRequest(t *testing.T) { Type: discovery.ResultTypeVirtual, }, nil) }, + validateAndNormalizeExpected: true, response: &dns.Msg{ MsgHdr: dns.MsgHdr{ Opcode: dns.OpcodeQuery, @@ -833,6 +837,7 @@ func Test_HandleRequest(t *testing.T) { require.Equal(t, structs.ConsulServiceName, req.Name) }) }, + validateAndNormalizeExpected: true, response: &dns.Msg{ MsgHdr: dns.MsgHdr{ Opcode: dns.OpcodeQuery, @@ -954,6 +959,7 @@ func Test_HandleRequest(t *testing.T) { require.Equal(t, structs.ConsulServiceName, req.Name) }) }, + validateAndNormalizeExpected: true, response: &dns.Msg{ MsgHdr: dns.MsgHdr{ Opcode: dns.OpcodeQuery, @@ -1275,6 +1281,7 @@ func Test_HandleRequest(t *testing.T) { require.Equal(t, "foo", req.Name) }) }, + validateAndNormalizeExpected: true, response: &dns.Msg{ MsgHdr: dns.MsgHdr{ Opcode: dns.OpcodeQuery, @@ -1309,12 +1316,201 @@ func Test_HandleRequest(t *testing.T) { }, }, // TODO (v2-dns): add a test to make sure only 3 records are returned + // V2 Workload Lookup + { + name: "workload A query w/ port, returns A record", + request: &dns.Msg{ + MsgHdr: dns.MsgHdr{ + Opcode: dns.OpcodeQuery, + }, + Question: []dns.Question{ + { + Name: "api.port.foo.workload.consul.", + Qtype: dns.TypeA, + Qclass: dns.ClassINET, + }, + }, + }, + configureDataFetcher: func(fetcher discovery.CatalogDataFetcher) { + result := &discovery.Result{ + Address: "1.2.3.4", + Type: discovery.ResultTypeWorkload, + Tenancy: discovery.ResultTenancy{}, + PortName: "api", + PortNumber: 5678, + Target: "foo", + } + + fetcher.(*discovery.MockCatalogDataFetcher). + On("FetchWorkload", mock.Anything, mock.Anything). + Return(result, nil). //TODO + Run(func(args mock.Arguments) { + req := args.Get(1).(*discovery.QueryPayload) + + require.Equal(t, "foo", req.Name) + require.Equal(t, "api", req.PortName) + }) + }, + validateAndNormalizeExpected: true, + response: &dns.Msg{ + MsgHdr: dns.MsgHdr{ + Opcode: dns.OpcodeQuery, + Response: true, + Authoritative: true, + }, + Compress: true, + Question: []dns.Question{ + { + Name: "api.port.foo.workload.consul.", + Qtype: dns.TypeA, + Qclass: dns.ClassINET, + }, + }, + Answer: []dns.RR{ + &dns.A{ + Hdr: dns.RR_Header{ + Name: "api.port.foo.workload.consul.", + Rrtype: dns.TypeA, + Class: dns.ClassINET, + }, + A: net.ParseIP("1.2.3.4"), + }, + }, + }, + }, + { + name: "workload ANY query w/o port, returns A record", + request: &dns.Msg{ + MsgHdr: dns.MsgHdr{ + Opcode: dns.OpcodeQuery, + }, + Question: []dns.Question{ + { + Name: "foo.workload.consul.", + Qtype: dns.TypeANY, + Qclass: dns.ClassINET, + }, + }, + }, + configureDataFetcher: func(fetcher discovery.CatalogDataFetcher) { + result := &discovery.Result{ + Address: "1.2.3.4", + Type: discovery.ResultTypeWorkload, + Tenancy: discovery.ResultTenancy{}, + Target: "foo", + } + + fetcher.(*discovery.MockCatalogDataFetcher). + On("FetchWorkload", mock.Anything, mock.Anything). + Return(result, nil). //TODO + Run(func(args mock.Arguments) { + req := args.Get(1).(*discovery.QueryPayload) + + require.Equal(t, "foo", req.Name) + require.Empty(t, req.PortName) + }) + }, + validateAndNormalizeExpected: true, + response: &dns.Msg{ + MsgHdr: dns.MsgHdr{ + Opcode: dns.OpcodeQuery, + Response: true, + Authoritative: true, + }, + Compress: true, + Question: []dns.Question{ + { + Name: "foo.workload.consul.", + Qtype: dns.TypeANY, + Qclass: dns.ClassINET, + }, + }, + Answer: []dns.RR{ + &dns.A{ + Hdr: dns.RR_Header{ + Name: "foo.workload.consul.", + Rrtype: dns.TypeA, + Class: dns.ClassINET, + }, + A: net.ParseIP("1.2.3.4"), + }, + }, + }, + }, + { + name: "workload AAAA query with namespace, partition, and cluster id; returns A record", + request: &dns.Msg{ + MsgHdr: dns.MsgHdr{ + Opcode: dns.OpcodeQuery, + }, + Question: []dns.Question{ + { + Name: "foo.workload.bar.ns.baz.ap.dc3.dc.consul.", + Qtype: dns.TypeAAAA, + Qclass: dns.ClassINET, + }, + }, + }, + configureDataFetcher: func(fetcher discovery.CatalogDataFetcher) { + result := &discovery.Result{ + Address: "1.2.3.4", + Type: discovery.ResultTypeWorkload, + Tenancy: discovery.ResultTenancy{ + Namespace: "bar", + Partition: "baz", + Datacenter: "dc3", + }, + Target: "foo", + } + + fetcher.(*discovery.MockCatalogDataFetcher). + On("FetchWorkload", mock.Anything, mock.Anything). + Return(result, nil). + Run(func(args mock.Arguments) { + req := args.Get(1).(*discovery.QueryPayload) + + require.Equal(t, "foo", req.Name) + require.Empty(t, req.PortName) + }) + }, + validateAndNormalizeExpected: true, + response: &dns.Msg{ + MsgHdr: dns.MsgHdr{ + Opcode: dns.OpcodeQuery, + Response: true, + Authoritative: true, + }, + Compress: true, + Question: []dns.Question{ + { + Name: "foo.workload.bar.ns.baz.ap.dc3.dc.consul.", + Qtype: dns.TypeAAAA, + Qclass: dns.ClassINET, + }, + }, + Extra: []dns.RR{ + &dns.A{ + Hdr: dns.RR_Header{ + Name: "foo.workload.bar.ns.baz.ap.dc3.dc.consul.", + Rrtype: dns.TypeA, + Class: dns.ClassINET, + }, + A: net.ParseIP("1.2.3.4"), + }, + }, + }, + }, } - testCases = append(testCases, getAdditionalTestCases(t)...) + //testCases = append(testCases, getAdditionalTestCases(t)...) run := func(t *testing.T, tc HandleTestCase) { cdf := discovery.NewMockCatalogDataFetcher(t) + if tc.validateAndNormalizeExpected { + cdf.On("ValidateRequest", mock.Anything, mock.Anything).Return(nil) + cdf.On("NormalizeRequest", mock.Anything).Return() + } + if tc.configureDataFetcher != nil { tc.configureDataFetcher(cdf) } @@ -1331,7 +1527,7 @@ func Test_HandleRequest(t *testing.T) { ctx := tc.requestContext if ctx == nil { - ctx = &discovery.Context{} + ctx = &Context{} } actual := router.HandleRequest(tc.request, *ctx, tc.remoteAddress) require.Equal(t, tc.response, actual) @@ -1391,7 +1587,7 @@ func TestRouterDynamicConfig_GetTTLForService(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - actual, ok := cfg.GetTTLForService(tc.inputKey) + actual, ok := cfg.getTTLForService(tc.inputKey) require.Equal(t, tc.shouldMatch, ok) require.Equal(t, tc.expectedDuration, actual) }) diff --git a/agent/dns/server.go b/agent/dns/server.go index 8620f7e83b..9508e34159 100644 --- a/agent/dns/server.go +++ b/agent/dns/server.go @@ -7,12 +7,12 @@ import ( "fmt" "net" - "github.com/hashicorp/go-hclog" "github.com/miekg/dns" + "github.com/hashicorp/go-hclog" + "github.com/hashicorp/consul/acl" "github.com/hashicorp/consul/agent/config" - "github.com/hashicorp/consul/agent/discovery" "github.com/hashicorp/consul/logging" ) @@ -20,7 +20,7 @@ import ( // //go:generate mockery --name DNSRouter --inpackage type DNSRouter interface { - HandleRequest(req *dns.Msg, reqCtx discovery.Context, remoteAddress net.Addr) *dns.Msg + HandleRequest(req *dns.Msg, reqCtx Context, remoteAddress net.Addr) *dns.Msg ServeDNS(w dns.ResponseWriter, req *dns.Msg) ReloadConfig(newCfg *config.RuntimeConfig) error } diff --git a/agent/grpc-external/services/dns/server_v2.go b/agent/grpc-external/services/dns/server_v2.go index 0748152e2f..5e04e02f6d 100644 --- a/agent/grpc-external/services/dns/server_v2.go +++ b/agent/grpc-external/services/dns/server_v2.go @@ -8,14 +8,14 @@ import ( "fmt" "net" - "github.com/hashicorp/go-hclog" "github.com/miekg/dns" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/peer" "google.golang.org/grpc/status" - "github.com/hashicorp/consul/agent/discovery" + "github.com/hashicorp/go-hclog" + agentdns "github.com/hashicorp/consul/agent/dns" "github.com/hashicorp/consul/proto-public/pbdns" ) @@ -73,7 +73,7 @@ func (s *ServerV2) Query(ctx context.Context, req *pbdns.QueryRequest) (*pbdns.Q } // TODO (v2-dns): parse token and other context metadata from the grpc request/metadata - reqCtx := discovery.Context{ + reqCtx := agentdns.Context{ Token: s.TokenFunc(), }