diff --git a/agent/discovery/discovery.go b/agent/discovery/discovery.go index 0439ca20bc..4468eb892e 100644 --- a/agent/discovery/discovery.go +++ b/agent/discovery/discovery.go @@ -43,7 +43,6 @@ func (e ECSNotGlobalError) Unwrap() error { type Query struct { QueryType QueryType QueryPayload QueryPayload - Limit int } // QueryType is used to filter service endpoints. @@ -79,11 +78,12 @@ type QueryTenancy struct { // QueryPayload represents all information needed by the data backend // to decide which records to include. type QueryPayload struct { - Name string - PortName string // v1 - this could optionally be "connect" or "ingress"; v2 - this is the service port name - Tag string // deprecated: use for V1 only - RemoteAddr net.Addr // deprecated: used for prepared queries - Tenancy QueryTenancy // tenancy includes any additional labels specified before the domain + Name string + PortName string // v1 - this could optionally be "connect" or "ingress"; v2 - this is the service port name + Tag string // deprecated: use for V1 only + SourceIP net.IP // deprecated: used for prepared queries + Tenancy QueryTenancy // tenancy includes any additional labels specified before the domain + Limit int // The maximum number of records to return // v2 fields only EnableFailover bool @@ -104,19 +104,23 @@ const ( // It is the responsibility of the DNS encoder to know what to do with // each Result, based on the query type. type Result struct { - Address string // A/AAAA/CNAME records - could be used in the Extra section. CNAME is required to handle hostname addresses in workloads & nodes. + Service *Location // The name and address of the service. + Node *Location // The name and address of the node. 
Weight uint32 // SRV queries PortName string // Used to generate a fgdn when a specifc port was queried PortNumber uint32 // SRV queries Metadata map[string]string // Used to collect metadata into TXT Records Type ResultType // Used to reconstruct the fqdn name of the resource - // Used in SRV & PTR queries to point at an A/AAAA Record. - Target string - Tenancy ResultTenancy } +// Location is used to represent a service, node, or workload. +type Location struct { + Name string + Address string +} + // ResultTenancy is used to reconstruct the fqdn name of the resource. type ResultTenancy struct { Namespace string diff --git a/agent/discovery/discovery_test.go b/agent/discovery/discovery_test.go index af7fd148b3..a53ec7b866 100644 --- a/agent/discovery/discovery_test.go +++ b/agent/discovery/discovery_test.go @@ -26,9 +26,9 @@ var ( } testResult = &Result{ - Address: "1.2.3.4", + Node: &Location{Address: "1.2.3.4"}, Type: ResultTypeNode, // This isn't correct for some test cases, but we are only asserting the right data fetcher functions are called - Target: "foo", + Service: &Location{Name: "foo"}, } ) diff --git a/agent/discovery/query_fetcher_v1.go b/agent/discovery/query_fetcher_v1.go index 81c73dca66..c3146a48ac 100644 --- a/agent/discovery/query_fetcher_v1.go +++ b/agent/discovery/query_fetcher_v1.go @@ -33,6 +33,10 @@ type v1DataFetcherDynamicConfig struct { // Default request tenancy datacenter string + segmentName string + nodeName string + nodePartition string + // Catalog configuration allowStale bool maxStale time.Duration @@ -115,17 +119,19 @@ func (f *V1DataFetcher) FetchNodes(ctx Context, req *QueryPayload) ([]*Result, e } results := make([]*Result, 0, 1) - node := out.NodeServices.Node + n := out.NodeServices.Node results = append(results, &Result{ - Address: node.Address, + Node: &Location{ + Name: n.Node, + Address: n.Address, + }, Type: ResultTypeNode, - Metadata: node.Meta, - Target: node.Node, + Metadata: n.Meta, Tenancy: ResultTenancy{ // 
Namespace is not required because nodes are not namespaced - Partition: node.GetEnterpriseMeta().PartitionOrDefault(), - Datacenter: node.Datacenter, + Partition: n.GetEnterpriseMeta().PartitionOrDefault(), + Datacenter: n.Datacenter, }, }) @@ -163,8 +169,11 @@ func (f *V1DataFetcher) FetchVirtualIP(ctx Context, req *QueryPayload) (*Result, } result := &Result{ - Address: out, - Type: ResultTypeVirtual, + Service: &Location{ + Name: req.Name, + Address: out, + }, + Type: ResultTypeVirtual, } return result, nil } @@ -196,9 +205,11 @@ func (f *V1DataFetcher) FetchRecordsByIp(reqCtx Context, ip net.IP) ([]*Result, for _, n := range out.Nodes { if targetIP == n.Address { results = append(results, &Result{ - Address: n.Address, - Type: ResultTypeNode, - Target: n.Node, + Node: &Location{ + Name: n.Node, + Address: n.Address, + }, + Type: ResultTypeNode, Tenancy: ResultTenancy{ Namespace: f.defaultEnterpriseMeta.NamespaceOrDefault(), Partition: f.defaultEnterpriseMeta.PartitionOrDefault(), @@ -226,13 +237,19 @@ func (f *V1DataFetcher) FetchRecordsByIp(reqCtx Context, ip net.IP) ([]*Result, for _, n := range sout.ServiceNodes { if n.ServiceAddress == targetIP { results = append(results, &Result{ - Address: n.ServiceAddress, - Type: ResultTypeService, - Target: n.ServiceName, + Service: &Location{ + Name: n.ServiceName, + Address: n.ServiceAddress, + }, + Type: ResultTypeService, + Node: &Location{ + Name: n.Node, + Address: n.Address, + }, Tenancy: ResultTenancy{ - Namespace: f.defaultEnterpriseMeta.NamespaceOrDefault(), - Partition: f.defaultEnterpriseMeta.PartitionOrDefault(), - Datacenter: configCtx.datacenter, + Namespace: n.NamespaceOrEmpty(), + Partition: n.PartitionOrEmpty(), + Datacenter: n.Datacenter, }, }) return results, nil @@ -256,7 +273,119 @@ func (f *V1DataFetcher) FetchWorkload(ctx Context, req *QueryPayload) (*Result, // FetchPreparedQuery evaluates the results of a prepared query. 
// deprecated in V2 func (f *V1DataFetcher) FetchPreparedQuery(ctx Context, req *QueryPayload) ([]*Result, error) { - return nil, nil + cfg := f.dynamicConfig.Load().(*v1DataFetcherDynamicConfig) + + // Execute the prepared query. + args := structs.PreparedQueryExecuteRequest{ + Datacenter: req.Tenancy.Datacenter, + QueryIDOrName: req.Name, + QueryOptions: structs.QueryOptions{ + Token: ctx.Token, + AllowStale: cfg.allowStale, + MaxAge: cfg.cacheMaxAge, + }, + + // Always pass the local agent through. In the DNS interface, there + // is no provision for passing additional query parameters, so we + // send the local agent's data through to allow distance sorting + // relative to ourself on the server side. + Agent: structs.QuerySource{ + Datacenter: cfg.datacenter, + Segment: cfg.segmentName, + Node: cfg.nodeName, + NodePartition: cfg.nodePartition, + }, + Source: structs.QuerySource{ + Ip: req.SourceIP.String(), + }, + } + + out, err := f.executePreparedQuery(cfg, args) + if err != nil { + return nil, err + } + + // (v2-dns) TODO: (v2-dns) get TTLS working. They come from the database so not having + // TTL on the discovery result poses challenges. + + /* + // TODO (slackpad) - What's a safe limit we can set here? It seems like + // with dup filtering done at this level we need to get everything to + // match the previous behavior. We can optimize by pushing more filtering + // into the query execution, but for now I think we need to get the full + // response. We could also choose a large arbitrary number that will + // likely work in practice, like 10*maxUDPAnswerLimit which should help + // reduce bandwidth if there are thousands of nodes available. + // Determine the TTL. The parse should never fail since we vet it when + // the query is created, but we check anyway. If the query didn't + // specify a TTL then we will try to use the agent's service-specific + // TTL configs. 
+ var ttl time.Duration + if out.DNS.TTL != "" { + var err error + ttl, err = time.ParseDuration(out.DNS.TTL) + if err != nil { + f.logger.Warn("Failed to parse TTL for prepared query , ignoring", + "ttl", out.DNS.TTL, + "prepared_query", req.Name, + ) + } + } else { + ttl, _ = cfg.GetTTLForService(out.Service) + } + */ + + // If we have no nodes, return not found! + if len(out.Nodes) == 0 { + return nil, ErrNoData + } + + // Perform a random shuffle + out.Nodes.Shuffle() + return f.buildResultsFromServiceNodes(out.Nodes), nil +} + +// executePreparedQuery is used to execute a PreparedQuery against the Consul catalog. +// If the config is set to UseCache, it will use agent cache. +func (f *V1DataFetcher) executePreparedQuery(cfg *v1DataFetcherDynamicConfig, args structs.PreparedQueryExecuteRequest) (*structs.PreparedQueryExecuteResponse, error) { + var out structs.PreparedQueryExecuteResponse + +RPC: + if cfg.useCache { + raw, m, err := f.getFromCacheFunc(context.TODO(), cachetype.PreparedQueryName, &args) + if err != nil { + return nil, err + } + reply, ok := raw.(*structs.PreparedQueryExecuteResponse) + if !ok { + // This should never happen, but err is nil here, so return an explicit error instead of a nil response the caller would dereference + return nil, fmt.Errorf("internal error: response type not correct") + } + + f.logger.Trace("cache results for prepared query", + "cache_hit", m.Hit, + "prepared_query", args.QueryIDOrName, + ) + + out = *reply + } else { + if err := f.rpcFunc(context.Background(), "PreparedQuery.Execute", &args, &out); err != nil { + return nil, err + } + } + + // Verify that request is not too stale, redo the request.
+ if args.AllowStale { + if out.LastContact > cfg.maxStale { + args.AllowStale = false + f.logger.Warn("Query results too stale, re-requesting") + goto RPC + } else if out.LastContact > staleCounterThreshold { + metrics.IncrCounter([]string{"dns", "stale_queries"}, 1) + } + } + + return &out, nil } func (f *V1DataFetcher) ValidateRequest(_ Context, req *QueryPayload) error { @@ -269,6 +398,34 @@ func (f *V1DataFetcher) ValidateRequest(_ Context, req *QueryPayload) error { return validateEnterpriseTenancy(req.Tenancy) } +// buildResultsFromServiceNodes builds a list of results from a list of nodes. +func (f *V1DataFetcher) buildResultsFromServiceNodes(nodes []structs.CheckServiceNode) []*Result { + results := make([]*Result, 0, len(nodes)) + for _, n := range nodes { + + results = append(results, &Result{ + Service: &Location{ + Name: n.Service.Service, + Address: n.Service.Address, + }, + Node: &Location{ + Name: n.Node.Node, + Address: n.Node.Address, + }, + Type: ResultTypeService, + Weight: uint32(findWeight(n)), + PortNumber: uint32(f.translateServicePortFunc(n.Node.Datacenter, n.Service.Port, n.Service.TaggedAddresses)), + Metadata: n.Node.Meta, + Tenancy: ResultTenancy{ + Namespace: n.Service.NamespaceOrEmpty(), + Partition: n.Service.PartitionOrEmpty(), + Datacenter: n.Node.Datacenter, + }, + }) + } + return results +} + // fetchNode is used to look up a node in the Consul catalog within NodeServices. // If the config is set to UseCache, it will get the record from the agent cache. func (f *V1DataFetcher) fetchNode(cfg *v1DataFetcherDynamicConfig, args *structs.NodeSpecificRequest) (*structs.IndexedNodeServices, error) { @@ -353,7 +510,12 @@ func (f *V1DataFetcher) fetchServiceBasedOnTenancy(ctx Context, req *QueryPayloa out, _, err := f.rpcFuncForServiceNodes(context.TODO(), args) if err != nil { - return nil, err + return nil, fmt.Errorf("rpc request failed: %w", err) + } + + // If we have no nodes, return not found!
+ if len(out.Nodes) == 0 { + return nil, ErrNoData } // Filter out any service nodes due to health checks @@ -372,57 +534,7 @@ func (f *V1DataFetcher) fetchServiceBasedOnTenancy(ctx Context, req *QueryPayloa // Perform a random shuffle out.Nodes.Shuffle() - results := make([]*Result, 0, len(out.Nodes)) - for _, node := range out.Nodes { - address, target, resultType := getAddressTargetAndResultType(node) - - results = append(results, &Result{ - Address: address, - Type: resultType, - Target: target, - Weight: uint32(findWeight(node)), - PortNumber: uint32(f.translateServicePortFunc(node.Node.Datacenter, node.Service.Port, node.Service.TaggedAddresses)), - Metadata: node.Node.Meta, - Tenancy: ResultTenancy{ - Namespace: node.Service.NamespaceOrEmpty(), - Partition: node.Service.PartitionOrEmpty(), - Datacenter: node.Node.Datacenter, - }, - }) - } - - return results, nil -} - -// getAddressTargetAndResultType returns the address, target and result type for a check service node. -func getAddressTargetAndResultType(node structs.CheckServiceNode) (string, string, ResultType) { - // Set address and target - // if service address is present, set target and address based on service. - // otherwise get it from the node. - address := node.Service.Address - target := node.Service.Service - resultType := ResultTypeService - - addressIP := net.ParseIP(address) - if addressIP == nil { - resultType = ResultTypeNode - if node.Service.Address != "" { - // cases where service address is foo or foo.node.consul - // For usage in DNS, these discovery results necessitate a CNAME record. - // These cases can be inferred from the discovery result when Type is Node and - // target is not an IP. - target = node.Service.Address - } else { - // cases where service address is empty and the service is bound to - // node with an address. These do not require a CNAME record in. - // For usage in DNS, these discovery results do not require a CNAME record. 
- // These cases can be inferred from the discovery result when Type is Node and - // target is not an IP. - target = node.Node.Node - } - address = node.Node.Address - } - return address, target, resultType + return f.buildResultsFromServiceNodes(out.Nodes), nil } // findWeight returns the weight of a service node. diff --git a/agent/discovery/query_fetcher_v1_test.go b/agent/discovery/query_fetcher_v1_test.go index 703548f3e5..95f3a9fe50 100644 --- a/agent/discovery/query_fetcher_v1_test.go +++ b/agent/discovery/query_fetcher_v1_test.go @@ -51,8 +51,11 @@ func Test_FetchVirtualIP(t *testing.T) { Token: "test-token", }, expectedResult: &Result{ - Address: "192.168.10.10", - Type: ResultTypeVirtual, + Service: &Location{ + Name: "db", + Address: "192.168.10.10", + }, + Type: ResultTypeVirtual, }, expectedErr: nil, }, @@ -97,7 +100,7 @@ func Test_FetchVirtualIP(t *testing.T) { if tc.expectedErr == nil { // set the out parameter to ensure that it is used to formulate the result.Address reply := args.Get(3).(*string) - *reply = tc.expectedResult.Address + *reply = tc.expectedResult.Service.Address } }) // TODO (v2-dns): mock these properly @@ -131,210 +134,62 @@ func Test_FetchEndpoints(t *testing.T) { DNSUseCache: true, DNSCacheMaxAge: 100, } - tests := []struct { - name string - queryPayload *QueryPayload - context Context - rpcFuncForServiceNodes func(ctx context.Context, req structs.ServiceSpecificRequest) (structs.IndexedCheckServiceNodes, cache.ResultMeta, error) - expectedResults []*Result - expectedErr error - }{ + ctx := Context{ + Token: "test-token", + } + expectedResults := []*Result{ { - name: "when service address is IPv4, result type is service, address is service address and target is service name", - queryPayload: &QueryPayload{ - Name: "service-name", - Tenancy: QueryTenancy{ - Namespace: defaultTestNamespace, - Partition: defaultTestPartition, - }, + Node: &Location{ + Name: "node-name", + Address: "node-address", }, - rpcFuncForServiceNodes: 
func(ctx context.Context, req structs.ServiceSpecificRequest) (structs.IndexedCheckServiceNodes, cache.ResultMeta, error) { - return structs.IndexedCheckServiceNodes{ - Nodes: []structs.CheckServiceNode{ - { - Node: &structs.Node{ - Address: "node-address", - Node: "node-name", - Partition: defaultTestPartition, - }, - Service: &structs.NodeService{ - Address: "127.0.0.1", - Service: "service-name", - EnterpriseMeta: acl.NewEnterpriseMetaWithPartition(defaultTestPartition, defaultTestNamespace), - }, - }, - }, - }, cache.ResultMeta{}, nil + Service: &Location{ + Name: "service-name", + Address: "service-address", }, - context: Context{ - Token: "test-token", - }, - expectedResults: []*Result{ - { - Address: "127.0.0.1", - Target: "service-name", - Type: ResultTypeService, - Weight: 1, - Tenancy: ResultTenancy{ - Namespace: defaultTestNamespace, - Partition: defaultTestPartition, - }, - }, - }, - expectedErr: nil, - }, - { - name: "when service address is IPv6, result type is service, address is service address and target is service name", - queryPayload: &QueryPayload{ - Name: "service-name", - Tenancy: QueryTenancy{ - Namespace: defaultTestNamespace, - Partition: defaultTestPartition, - }, - }, - rpcFuncForServiceNodes: func(ctx context.Context, req structs.ServiceSpecificRequest) (structs.IndexedCheckServiceNodes, cache.ResultMeta, error) { - return structs.IndexedCheckServiceNodes{ - Nodes: []structs.CheckServiceNode{ - { - Node: &structs.Node{ - Address: "node-address", - Node: "node-name", - Partition: defaultTestPartition, - }, - Service: &structs.NodeService{ - Address: "2001:db8:1:2:cafe::1337", - Service: "service-name", - EnterpriseMeta: acl.NewEnterpriseMetaWithPartition(defaultTestPartition, defaultTestNamespace), - }, - }, - }, - }, cache.ResultMeta{}, nil - }, - context: Context{ - Token: "test-token", - }, - expectedResults: []*Result{ - { - Address: "2001:db8:1:2:cafe::1337", - Target: "service-name", - Type: ResultTypeService, - Weight: 1, - 
Tenancy: ResultTenancy{ - Namespace: defaultTestNamespace, - Partition: defaultTestPartition, - }, - }, - }, - expectedErr: nil, - }, - { - name: "when service address is not IP but is not empty, result type is node, address is node address, and target is service address", - queryPayload: &QueryPayload{ - Name: "service-name", - Tenancy: QueryTenancy{ - Namespace: defaultTestNamespace, - Partition: defaultTestPartition, - }, - }, - rpcFuncForServiceNodes: func(ctx context.Context, req structs.ServiceSpecificRequest) (structs.IndexedCheckServiceNodes, cache.ResultMeta, error) { - return structs.IndexedCheckServiceNodes{ - Nodes: []structs.CheckServiceNode{ - { - Node: &structs.Node{ - Address: "node-address", - Node: "node-name", - Partition: defaultTestPartition, - }, - Service: &structs.NodeService{ - Address: "foo", - Service: "service-name", - EnterpriseMeta: acl.NewEnterpriseMetaWithPartition(defaultTestPartition, defaultTestNamespace), - }, - }, - }, - }, cache.ResultMeta{}, nil - }, - context: Context{ - Token: "test-token", - }, - expectedResults: []*Result{ - { - Address: "node-address", - Target: "foo", - Type: ResultTypeNode, - Weight: 1, - Tenancy: ResultTenancy{ - Namespace: defaultTestNamespace, - Partition: defaultTestPartition, - }, - }, - }, - expectedErr: nil, - }, - { - name: "when service address is empty, result type is node, address is node address, and target is node name", - queryPayload: &QueryPayload{ - Name: "service-name", - Tenancy: QueryTenancy{ - Namespace: defaultTestNamespace, - Partition: defaultTestPartition, - }, - }, - rpcFuncForServiceNodes: func(ctx context.Context, req structs.ServiceSpecificRequest) (structs.IndexedCheckServiceNodes, cache.ResultMeta, error) { - return structs.IndexedCheckServiceNodes{ - Nodes: []structs.CheckServiceNode{ - { - Node: &structs.Node{ - Address: "node-address", - Node: "node-name", - Partition: defaultTestPartition, - }, - Service: &structs.NodeService{ - Address: "", - Service: "service-name", 
- EnterpriseMeta: acl.NewEnterpriseMetaWithPartition(defaultTestPartition, defaultTestNamespace), - }, - }, - }, - }, cache.ResultMeta{}, nil - }, - context: Context{ - Token: "test-token", - }, - expectedResults: []*Result{ - { - Address: "node-address", - Target: "node-name", - Type: ResultTypeNode, - Weight: 1, - Tenancy: ResultTenancy{ - Namespace: defaultTestNamespace, - Partition: defaultTestPartition, - }, - }, - }, - expectedErr: nil, + Type: ResultTypeService, + Weight: 1, }, } - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - logger := testutil.Logger(t) - mockRPC := cachetype.NewMockRPC(t) - // TODO (v2-dns): mock these properly - translateServicePortFunc := func(dc string, port int, taggedAddresses map[string]structs.ServiceAddress) int { return 0 } - rpcFuncForSamenessGroup := func(ctx context.Context, req *structs.ConfigEntryQuery) (structs.SamenessGroupConfigEntry, cache.ResultMeta, error) { - return structs.SamenessGroupConfigEntry{}, cache.ResultMeta{}, nil - } - getFromCacheFunc := func(ctx context.Context, t string, r cache.Request) (interface{}, cache.ResultMeta, error) { - return nil, cache.ResultMeta{}, nil - } - - df := NewV1DataFetcher(rc, acl.DefaultEnterpriseMeta(), getFromCacheFunc, mockRPC.RPC, tc.rpcFuncForServiceNodes, rpcFuncForSamenessGroup, translateServicePortFunc, logger) - - results, err := df.FetchEndpoints(tc.context, tc.queryPayload, LookupTypeService) - require.Equal(t, tc.expectedErr, err) - require.Equal(t, tc.expectedResults, results) - }) + logger := testutil.Logger(t) + mockRPC := cachetype.NewMockRPC(t) + // TODO (v2-dns): mock these properly + translateServicePortFunc := func(dc string, port int, taggedAddresses map[string]structs.ServiceAddress) int { return 0 } + rpcFuncForSamenessGroup := func(ctx context.Context, req *structs.ConfigEntryQuery) (structs.SamenessGroupConfigEntry, cache.ResultMeta, error) { + return structs.SamenessGroupConfigEntry{}, cache.ResultMeta{}, nil } + getFromCacheFunc := 
func(ctx context.Context, t string, r cache.Request) (interface{}, cache.ResultMeta, error) { + return nil, cache.ResultMeta{}, nil + } + rpcFuncForServiceNodes := func(ctx context.Context, req structs.ServiceSpecificRequest) (structs.IndexedCheckServiceNodes, cache.ResultMeta, error) { + return structs.IndexedCheckServiceNodes{ + Nodes: []structs.CheckServiceNode{ + { + Node: &structs.Node{ + Address: "node-address", + Node: "node-name", + }, + Service: &structs.NodeService{ + Address: "service-address", + Service: "service-name", + }, + }, + }, + }, cache.ResultMeta{}, nil + } + queryPayload := &QueryPayload{ + Name: "service-name", + Tenancy: QueryTenancy{ + Peer: "test-peer", + Namespace: defaultTestNamespace, + Partition: defaultTestPartition, + }, + } + + df := NewV1DataFetcher(rc, acl.DefaultEnterpriseMeta(), getFromCacheFunc, mockRPC.RPC, rpcFuncForServiceNodes, rpcFuncForSamenessGroup, translateServicePortFunc, logger) + + results, err := df.FetchEndpoints(ctx, queryPayload, LookupTypeService) + require.NoError(t, err) + require.Equal(t, expectedResults, results) } diff --git a/agent/discovery/query_fetcher_v2.go b/agent/discovery/query_fetcher_v2.go index 5371b6f4b0..0033c84dff 100644 --- a/agent/discovery/query_fetcher_v2.go +++ b/agent/discovery/query_fetcher_v2.go @@ -6,6 +6,7 @@ package discovery import ( "context" "fmt" + "math/rand" "net" "strings" "sync/atomic" @@ -13,6 +14,7 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" + "google.golang.org/protobuf/proto" "github.com/hashicorp/go-hclog" @@ -63,9 +65,62 @@ func (f *V2DataFetcher) FetchNodes(ctx Context, req *QueryPayload) ([]*Result, e } // FetchEndpoints fetches records for A/AAAA/CNAME or SRV requests for services -// TODO (v2-dns): Validate lookupType -func (f *V2DataFetcher) FetchEndpoints(ctx Context, req *QueryPayload, lookupType LookupType) ([]*Result, error) { - return nil, nil +func (f *V2DataFetcher) FetchEndpoints(reqContext 
Context, req *QueryPayload, lookupType LookupType) ([]*Result, error) { + if lookupType != LookupTypeService { + return nil, ErrNotSupported + } + + configCtx := f.dynamicConfig.Load().(*v2DataFetcherDynamicConfig) + + serviceEndpoints := pbcatalog.ServiceEndpoints{} + resourceObj, err := f.fetchResource(reqContext, *req, pbcatalog.ServiceEndpointsType, &serviceEndpoints) + if err != nil { + return nil, err + } + + // Shuffle the endpoints slice + shuffleFunc := func(i, j int) { + serviceEndpoints.Endpoints[i], serviceEndpoints.Endpoints[j] = serviceEndpoints.Endpoints[j], serviceEndpoints.Endpoints[i] + } + rand.Shuffle(len(serviceEndpoints.Endpoints), shuffleFunc) + + // Convert the service endpoints to results until the limit is reached; a non-positive limit means no limit, and endpoints filtered out by health do not count toward it + limit := req.Limit + if len(serviceEndpoints.Endpoints) < limit || limit <= 0 { + limit = len(serviceEndpoints.Endpoints) + } + + results := make([]*Result, 0, limit) + for idx := 0; idx < len(serviceEndpoints.Endpoints) && len(results) < limit; idx++ { + endpoint := serviceEndpoints.Endpoints[idx] + + // TODO (v2-dns): filter based on the port name requested + + address, err := f.addressFromWorkloadAddresses(endpoint.Addresses, req.Name) + if err != nil { + return nil, err + } + + weight, ok := getEndpointWeight(endpoint, configCtx) + if !ok { + continue + } + + result := &Result{ + Node: &Location{ + Address: address, + Name: endpoint.GetTargetRef().GetName(), + }, + Type: ResultTypeWorkload, // TODO (v2-dns): I'm not really sure if it's better to have SERVICE OR WORKLOAD here + Tenancy: ResultTenancy{ + Namespace: resourceObj.GetId().GetTenancy().GetNamespace(), + Partition: resourceObj.GetId().GetTenancy().GetPartition(), + }, + Weight: weight, + } + results = append(results, result) + } + return results, nil } // FetchVirtualIP fetches A/AAAA records for virtual IPs @@ -82,55 +137,28 @@ func (f *V2DataFetcher) FetchRecordsByIp(ctx Context, ip net.IP) ([]*Result, err // FetchWorkload is used to fetch a single workload from the V2 catalog. // V2-only.
func (f *V2DataFetcher) FetchWorkload(reqContext Context, req *QueryPayload) (*Result, error) { - // Query the resource service for the workload by name and tenancy - resourceReq := pbresource.ReadRequest{ - Id: &pbresource.ID{ - Name: req.Name, - Type: pbcatalog.WorkloadType, - Tenancy: queryTenancyToResourceTenancy(req.Tenancy), - }, + workload := pbcatalog.Workload{} + resourceObj, err := f.fetchResource(reqContext, *req, pbcatalog.WorkloadType, &workload) + if err != nil { + return nil, err } - f.logger.Debug("fetching workload", "name", req.Name) - resourceCtx := metadata.AppendToOutgoingContext(context.Background(), "x-consul-token", reqContext.Token) - - // If the workload is not found, return nil and an error equivalent to NXDOMAIN - response, err := f.client.Read(resourceCtx, &resourceReq) - switch { - case grpcNotFoundErr(err): - f.logger.Debug("workload not found", "name", req.Name) - return nil, ErrNotFound - case err != nil: - f.logger.Error("error fetching workload", "name", req.Name) - return nil, fmt.Errorf("error fetching workload: %w", err) - // default: fallthrough + address, err := f.addressFromWorkloadAddresses(workload.Addresses, req.Name) + if err != nil { + return nil, err } - workload := &pbcatalog.Workload{} - data := response.GetResource().GetData() - if err := data.UnmarshalTo(workload); err != nil { - f.logger.Error("error unmarshalling workload", "name", req.Name) - return nil, fmt.Errorf("error unmarshalling workload: %w", err) - } - - // TODO: (v2-dns): we will need to intelligently return the right workload address based on either the translate - // address setting or the locality of the requester. Workloads must have at least one. - // We also need to make sure that we filter out unix sockets here. 
- address := workload.Addresses[0].GetHost() - if strings.HasPrefix(address, "unix://") { - f.logger.Error("unix sockets are currently unsupported in workload results", "name", req.Name) - return nil, ErrNotFound - } - - tenancy := response.GetResource().GetId().GetTenancy() + tenancy := resourceObj.GetId().GetTenancy() result := &Result{ - Address: address, - Type: ResultTypeWorkload, + Node: &Location{ + Address: address, + Name: resourceObj.GetId().GetName(), + }, + Type: ResultTypeWorkload, Tenancy: ResultTenancy{ Namespace: tenancy.GetNamespace(), Partition: tenancy.GetPartition(), }, - Target: response.GetResource().GetId().GetName(), } if req.PortName == "" { @@ -169,12 +197,93 @@ func (f *V2DataFetcher) ValidateRequest(_ Context, req *QueryPayload) error { if req.Tag != "" { return ErrNotSupported } - if req.RemoteAddr != nil { + if req.SourceIP != nil { return ErrNotSupported } return nil } +// fetchResource is used to read a single resource from the V2 catalog and cast into a concrete type. 
+func (f *V2DataFetcher) fetchResource(reqContext Context, req QueryPayload, kind *pbresource.Type, payload proto.Message) (*pbresource.Resource, error) { + // Query the resource service for a resource of the given kind by name and tenancy + resourceReq := pbresource.ReadRequest{ + Id: &pbresource.ID{ + Name: req.Name, + Type: kind, + Tenancy: queryTenancyToResourceTenancy(req.Tenancy), + }, + } + + f.logger.Debug("fetching "+kind.String(), "name", req.Name) + resourceCtx := metadata.AppendToOutgoingContext(context.Background(), "x-consul-token", reqContext.Token) + + // If the resource is not found, return nil and ErrNotFound (the error equivalent to NXDOMAIN) + response, err := f.client.Read(resourceCtx, &resourceReq) + switch { + case grpcNotFoundErr(err): + f.logger.Debug(kind.String()+" not found", "name", req.Name) + return nil, ErrNotFound + case err != nil: + f.logger.Error("error fetching "+kind.String(), "name", req.Name) + return nil, fmt.Errorf("error fetching %s: %w", kind.String(), err) + // default: fallthrough + } + + data := response.GetResource().GetData() + if err := data.UnmarshalTo(payload); err != nil { + f.logger.Error("error unmarshalling "+kind.String(), "name", req.Name) + return nil, fmt.Errorf("error unmarshalling %s: %w", kind.String(), err) + } + return response.GetResource(), nil +} + +// addressFromWorkloadAddresses returns one address from the workload addresses. +func (f *V2DataFetcher) addressFromWorkloadAddresses(addresses []*pbcatalog.WorkloadAddress, name string) (string, error) { + // TODO: (v2-dns): we will need to intelligently return the right workload address based on either the translate + // address setting or the locality of the requester. Workloads must have at least one. NOTE(review): addresses[0] below panics on an empty slice - confirm every caller guarantees a non-empty list. + // We also need to make sure that we filter out unix sockets here.
+ address := addresses[0].GetHost() + if strings.HasPrefix(address, "unix://") { + f.logger.Error("unix sockets are currently unsupported in workload results", "name", name) + return "", ErrNotFound + } + return address, nil +} + +// getEndpointWeight returns the weight of the endpoint and a boolean indicating if the endpoint should be included +// based on its health status. +func getEndpointWeight(endpoint *pbcatalog.Endpoint, configCtx *v2DataFetcherDynamicConfig) (uint32, bool) { + health := endpoint.GetHealthStatus().Enum() + if health == nil { + return 0, false + } + + // Filter based on health status and agent config + // This is also a good opportunity to see if SRV weights are set + var weight uint32 + switch *health { + case pbcatalog.Health_HEALTH_PASSING: + weight = endpoint.GetDns().GetWeights().GetPassing() + case pbcatalog.Health_HEALTH_CRITICAL: + return 0, false // always filtered out + case pbcatalog.Health_HEALTH_WARNING: + if configCtx.onlyPassing { + return 0, false // filtered out + } + weight = endpoint.GetDns().GetWeights().GetWarning() + default: + // Everything else can be filtered out + return 0, false + } + + // Important! double-check the weight in the case DNS weights are not set + if weight == 0 { + weight = 1 + } + return weight, true +} + +// queryTenancyToResourceTenancy converts a QueryTenancy to a pbresource.Tenancy. func queryTenancyToResourceTenancy(qTenancy QueryTenancy) *pbresource.Tenancy { rTenancy := resource.DefaultNamespacedTenancy() diff --git a/agent/discovery/query_fetcher_v2_test.go b/agent/discovery/query_fetcher_v2_test.go index 86e2af6b63..9ace561f13 100644 --- a/agent/discovery/query_fetcher_v2_test.go +++ b/agent/discovery/query_fetcher_v2_test.go @@ -21,6 +21,10 @@ import ( "github.com/hashicorp/consul/sdk/testutil" ) +var ( + unknownErr = errors.New("I don't feel so good") +) + // Test_FetchService tests the FetchService method in scenarios where the RPC // call succeeds and fails.
func Test_FetchWorkload(t *testing.T) { @@ -29,8 +33,6 @@ func Test_FetchWorkload(t *testing.T) { DNSOnlyPassing: false, } - unknownErr := errors.New("I don't feel so good") - tests := []struct { name string queryPayload *QueryPayload @@ -58,13 +60,12 @@ func Test_FetchWorkload(t *testing.T) { }) }, expectedResult: &Result{ - Address: "1.2.3.4", - Type: ResultTypeWorkload, + Node: &Location{Name: "foo-1234", Address: "1.2.3.4"}, + Type: ResultTypeWorkload, Tenancy: ResultTenancy{ Namespace: resource.DefaultNamespaceName, Partition: resource.DefaultPartitionName, }, - Target: "foo-1234", }, expectedErr: nil, }, @@ -130,7 +131,7 @@ func Test_FetchWorkload(t *testing.T) { }) }, expectedResult: &Result{ - Address: "1.2.3.4", + Node: &Location{Name: "foo-1234", Address: "1.2.3.4"}, Type: ResultTypeWorkload, PortName: "api", PortNumber: 5678, @@ -138,7 +139,6 @@ func Test_FetchWorkload(t *testing.T) { Namespace: resource.DefaultNamespaceName, Partition: resource.DefaultPartitionName, }, - Target: "foo-1234", }, expectedErr: nil, }, @@ -189,13 +189,12 @@ func Test_FetchWorkload(t *testing.T) { }) }, expectedResult: &Result{ - Address: "1.2.3.4", - Type: ResultTypeWorkload, + Node: &Location{Name: "foo-1234", Address: "1.2.3.4"}, + Type: ResultTypeWorkload, Tenancy: ResultTenancy{ Namespace: "test-namespace", Partition: "test-partition", }, - Target: "foo-1234", }, expectedErr: nil, }, @@ -218,6 +217,436 @@ func Test_FetchWorkload(t *testing.T) { } } +// Test_V2FetchEndpoints tests the FetchEndpoints method in scenarios where the RPC +// call succeeds and fails.
+func Test_V2FetchEndpoints(t *testing.T) { + + tests := []struct { + name string + queryPayload *QueryPayload + context Context + configureMockClient func(mockClient *mockpbresource.ResourceServiceClient_Expecter) + rc *config.RuntimeConfig + expectedResult []*Result + expectedErr error + verifyShuffle bool + }{ + { + name: "FetchEndpoints returns result", + queryPayload: &QueryPayload{ + Name: "consul", + }, + context: Context{ + Token: "test-token", + }, + configureMockClient: func(mockClient *mockpbresource.ResourceServiceClient_Expecter) { + results := []*pbcatalog.Endpoint{ + makeEndpoint("consul-1", "1.2.3.4", pbcatalog.Health_HEALTH_PASSING, 0, 0), + } + + result := getTestEndpointsResponse(t, "", "", results...) + mockClient.Read(mock.Anything, mock.Anything). + Return(result, nil). + Once(). + Run(func(args mock.Arguments) { + req := args.Get(1).(*pbresource.ReadRequest) + require.Equal(t, result.GetResource().GetId().GetName(), req.Id.Name) + }) + }, + expectedResult: []*Result{ + { + Node: &Location{Name: "consul-1", Address: "1.2.3.4"}, + Type: ResultTypeWorkload, + Tenancy: ResultTenancy{ + Namespace: resource.DefaultNamespaceName, + Partition: resource.DefaultPartitionName, + }, + Weight: 1, + }, + }, + }, + { + name: "FetchEndpoints returns empty result with no endpoints", + queryPayload: &QueryPayload{ + Name: "consul", + }, + context: Context{ + Token: "test-token", + }, + configureMockClient: func(mockClient *mockpbresource.ResourceServiceClient_Expecter) { + + result := getTestEndpointsResponse(t, "", "") + mockClient.Read(mock.Anything, mock.Anything). + Return(result, nil). + Once(). 
+ Run(func(args mock.Arguments) { + req := args.Get(1).(*pbresource.ReadRequest) + require.Equal(t, result.GetResource().GetId().GetName(), req.Id.Name) + }) + }, + expectedResult: []*Result{}, + }, + { + name: "FetchEndpoints returns a name error when the ServiceEndpoint does not exist", + queryPayload: &QueryPayload{ + Name: "consul", + }, + context: Context{ + Token: "test-token", + }, + configureMockClient: func(mockClient *mockpbresource.ResourceServiceClient_Expecter) { + + result := getTestEndpointsResponse(t, "", "") + mockClient.Read(mock.Anything, mock.Anything). + Return(nil, status.Error(codes.NotFound, "not found")). + Once(). + Run(func(args mock.Arguments) { + req := args.Get(1).(*pbresource.ReadRequest) + require.Equal(t, result.GetResource().GetId().GetName(), req.Id.Name) + }) + }, + expectedErr: ErrNotFound, + }, + { + name: "FetchEndpoints encounters a resource client error", + queryPayload: &QueryPayload{ + Name: "consul", + }, + context: Context{ + Token: "test-token", + }, + configureMockClient: func(mockClient *mockpbresource.ResourceServiceClient_Expecter) { + + result := getTestEndpointsResponse(t, "", "") + mockClient.Read(mock.Anything, mock.Anything). + Return(nil, unknownErr). + Once(). 
+ Run(func(args mock.Arguments) { + req := args.Get(1).(*pbresource.ReadRequest) + require.Equal(t, result.GetResource().GetId().GetName(), req.Id.Name) + }) + }, + expectedErr: unknownErr, + }, + { + name: "FetchEndpoints always filters out critical endpoints; DNS weights applied correctly", + queryPayload: &QueryPayload{ + Name: "consul", + }, + context: Context{ + Token: "test-token", + }, + configureMockClient: func(mockClient *mockpbresource.ResourceServiceClient_Expecter) { + results := []*pbcatalog.Endpoint{ + makeEndpoint("consul-1", "1.2.3.4", pbcatalog.Health_HEALTH_PASSING, 2, 3), + makeEndpoint("consul-2", "2.3.4.5", pbcatalog.Health_HEALTH_WARNING, 2, 3), + makeEndpoint("consul-3", "3.4.5.6", pbcatalog.Health_HEALTH_CRITICAL, 2, 3), + } + + result := getTestEndpointsResponse(t, "", "", results...) + mockClient.Read(mock.Anything, mock.Anything). + Return(result, nil). + Once(). + Run(func(args mock.Arguments) { + req := args.Get(1).(*pbresource.ReadRequest) + require.Equal(t, result.GetResource().GetId().GetName(), req.Id.Name) + }) + }, + expectedResult: []*Result{ + { + Node: &Location{Name: "consul-1", Address: "1.2.3.4"}, + Type: ResultTypeWorkload, + Tenancy: ResultTenancy{ + Namespace: resource.DefaultNamespaceName, + Partition: resource.DefaultPartitionName, + }, + Weight: 2, + }, + { + Node: &Location{Name: "consul-2", Address: "2.3.4.5"}, + Type: ResultTypeWorkload, + Tenancy: ResultTenancy{ + Namespace: resource.DefaultNamespaceName, + Partition: resource.DefaultPartitionName, + }, + Weight: 3, + }, + }, + }, + { + name: "FetchEndpoints filters out warning endpoints when DNSOnlyPassing is true", + queryPayload: &QueryPayload{ + Name: "consul", + }, + context: Context{ + Token: "test-token", + }, + configureMockClient: func(mockClient *mockpbresource.ResourceServiceClient_Expecter) { + results := []*pbcatalog.Endpoint{ + makeEndpoint("consul-1", "1.2.3.4", pbcatalog.Health_HEALTH_PASSING, 2, 3), + makeEndpoint("consul-2", "2.3.4.5", 
pbcatalog.Health_HEALTH_WARNING, 2, 3), + makeEndpoint("consul-3", "3.4.5.6", pbcatalog.Health_HEALTH_CRITICAL, 2, 3), + } + + result := getTestEndpointsResponse(t, "", "", results...) + mockClient.Read(mock.Anything, mock.Anything). + Return(result, nil). + Once(). + Run(func(args mock.Arguments) { + req := args.Get(1).(*pbresource.ReadRequest) + require.Equal(t, result.GetResource().GetId().GetName(), req.Id.Name) + }) + }, + rc: &config.RuntimeConfig{ + DNSOnlyPassing: true, + }, + expectedResult: []*Result{ + { + Node: &Location{Name: "consul-1", Address: "1.2.3.4"}, + Type: ResultTypeWorkload, + Tenancy: ResultTenancy{ + Namespace: resource.DefaultNamespaceName, + Partition: resource.DefaultPartitionName, + }, + Weight: 2, + }, + }, + }, + { + name: "FetchEndpoints shuffles the results", + queryPayload: &QueryPayload{ + Name: "consul", + }, + context: Context{ + Token: "test-token", + }, + configureMockClient: func(mockClient *mockpbresource.ResourceServiceClient_Expecter) { + results := []*pbcatalog.Endpoint{ + // use a set of 10 elements, the odds of getting the same result are 1 in 3628800 + makeEndpoint("consul-1", "10.0.0.1", pbcatalog.Health_HEALTH_PASSING, 0, 0), + makeEndpoint("consul-2", "10.0.0.2", pbcatalog.Health_HEALTH_PASSING, 0, 0), + makeEndpoint("consul-3", "10.0.0.3", pbcatalog.Health_HEALTH_PASSING, 0, 0), + makeEndpoint("consul-4", "10.0.0.4", pbcatalog.Health_HEALTH_PASSING, 0, 0), + makeEndpoint("consul-5", "10.0.0.5", pbcatalog.Health_HEALTH_PASSING, 0, 0), + makeEndpoint("consul-6", "10.0.0.6", pbcatalog.Health_HEALTH_PASSING, 0, 0), + makeEndpoint("consul-7", "10.0.0.7", pbcatalog.Health_HEALTH_PASSING, 0, 0), + makeEndpoint("consul-8", "10.0.0.8", pbcatalog.Health_HEALTH_PASSING, 0, 0), + makeEndpoint("consul-9", "10.0.0.9", pbcatalog.Health_HEALTH_PASSING, 0, 0), + makeEndpoint("consul-10", "10.0.0.10", pbcatalog.Health_HEALTH_PASSING, 0, 0), + } + + result := getTestEndpointsResponse(t, "", "", results...) 
+ mockClient.Read(mock.Anything, mock.Anything). + Return(result, nil). + Once(). + Run(func(args mock.Arguments) { + req := args.Get(1).(*pbresource.ReadRequest) + require.Equal(t, result.GetResource().GetId().GetName(), req.Id.Name) + }) + }, + expectedResult: []*Result{ + { + Node: &Location{Name: "consul-1", Address: "10.0.0.1"}, + Type: ResultTypeWorkload, + Tenancy: ResultTenancy{ + Namespace: resource.DefaultNamespaceName, + Partition: resource.DefaultPartitionName, + }, + Weight: 1, + }, + { + Node: &Location{Name: "consul-2", Address: "10.0.0.2"}, + Type: ResultTypeWorkload, + Tenancy: ResultTenancy{ + Namespace: resource.DefaultNamespaceName, + Partition: resource.DefaultPartitionName, + }, + Weight: 1, + }, + { + Node: &Location{Name: "consul-3", Address: "10.0.0.3"}, + Type: ResultTypeWorkload, + Tenancy: ResultTenancy{ + Namespace: resource.DefaultNamespaceName, + Partition: resource.DefaultPartitionName, + }, + Weight: 1, + }, + { + Node: &Location{Name: "consul-4", Address: "10.0.0.4"}, + Type: ResultTypeWorkload, + Tenancy: ResultTenancy{ + Namespace: resource.DefaultNamespaceName, + Partition: resource.DefaultPartitionName, + }, + Weight: 1, + }, + { + Node: &Location{Name: "consul-5", Address: "10.0.0.5"}, + Type: ResultTypeWorkload, + Tenancy: ResultTenancy{ + Namespace: resource.DefaultNamespaceName, + Partition: resource.DefaultPartitionName, + }, + Weight: 1, + }, + { + Node: &Location{Name: "consul-6", Address: "10.0.0.6"}, + Type: ResultTypeWorkload, + Tenancy: ResultTenancy{ + Namespace: resource.DefaultNamespaceName, + Partition: resource.DefaultPartitionName, + }, + Weight: 1, + }, + { + Node: &Location{Name: "consul-7", Address: "10.0.0.7"}, + Type: ResultTypeWorkload, + Tenancy: ResultTenancy{ + Namespace: resource.DefaultNamespaceName, + Partition: resource.DefaultPartitionName, + }, + Weight: 1, + }, + { + Node: &Location{Name: "consul-8", Address: "10.0.0.8"}, + Type: ResultTypeWorkload, + Tenancy: ResultTenancy{ + Namespace: 
resource.DefaultNamespaceName, + Partition: resource.DefaultPartitionName, + }, + Weight: 1, + }, + { + Node: &Location{Name: "consul-9", Address: "10.0.0.9"}, + Type: ResultTypeWorkload, + Tenancy: ResultTenancy{ + Namespace: resource.DefaultNamespaceName, + Partition: resource.DefaultPartitionName, + }, + Weight: 1, + }, + { + Node: &Location{Name: "consul-10", Address: "10.0.0.10"}, + Type: ResultTypeWorkload, + Tenancy: ResultTenancy{ + Namespace: resource.DefaultNamespaceName, + Partition: resource.DefaultPartitionName, + }, + Weight: 1, + }, + }, + verifyShuffle: true, + }, + { + name: "FetchEndpoints returns only the specified limit", + queryPayload: &QueryPayload{ + Name: "consul", + Limit: 1, + }, + context: Context{ + Token: "test-token", + }, + configureMockClient: func(mockClient *mockpbresource.ResourceServiceClient_Expecter) { + results := []*pbcatalog.Endpoint{ + // intentionally all the same to make this easier to verify + makeEndpoint("consul-1", "10.0.0.1", pbcatalog.Health_HEALTH_PASSING, 0, 0), + makeEndpoint("consul-1", "10.0.0.1", pbcatalog.Health_HEALTH_PASSING, 0, 0), + makeEndpoint("consul-1", "10.0.0.1", pbcatalog.Health_HEALTH_PASSING, 0, 0), + } + + result := getTestEndpointsResponse(t, "", "", results...) + mockClient.Read(mock.Anything, mock.Anything). + Return(result, nil). + Once(). 
+ Run(func(args mock.Arguments) { + req := args.Get(1).(*pbresource.ReadRequest) + require.Equal(t, result.GetResource().GetId().GetName(), req.Id.Name) + }) + }, + expectedResult: []*Result{ + { + Node: &Location{Name: "consul-1", Address: "10.0.0.1"}, + Type: ResultTypeWorkload, + Tenancy: ResultTenancy{ + Namespace: resource.DefaultNamespaceName, + Partition: resource.DefaultPartitionName, + }, + Weight: 1, + }, + }, + }, + { + name: "FetchEndpoints returns results with non-default tenancy", + queryPayload: &QueryPayload{ + Name: "consul", + Tenancy: QueryTenancy{ + Namespace: "test-namespace", + Partition: "test-partition", + }, + }, + context: Context{ + Token: "test-token", + }, + configureMockClient: func(mockClient *mockpbresource.ResourceServiceClient_Expecter) { + results := []*pbcatalog.Endpoint{ + // intentionally all the same to make this easier to verify + makeEndpoint("consul-1", "10.0.0.1", pbcatalog.Health_HEALTH_PASSING, 0, 0), + } + + result := getTestEndpointsResponse(t, "test-namespace", "test-partition", results...) + mockClient.Read(mock.Anything, mock.Anything). + Return(result, nil). + Once(). 
+ Run(func(args mock.Arguments) { + req := args.Get(1).(*pbresource.ReadRequest) + require.Equal(t, result.GetResource().GetId().GetName(), req.Id.Name) + require.Equal(t, result.GetResource().GetId().GetTenancy().GetNamespace(), req.Id.Tenancy.Namespace) + require.Equal(t, result.GetResource().GetId().GetTenancy().GetPartition(), req.Id.Tenancy.Partition) + }) + }, + expectedResult: []*Result{ + { + Node: &Location{Name: "consul-1", Address: "10.0.0.1"}, + Type: ResultTypeWorkload, + Tenancy: ResultTenancy{ + Namespace: "test-namespace", + Partition: "test-partition", + }, + Weight: 1, + }, + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + logger := testutil.Logger(t) + + client := mockpbresource.NewResourceServiceClient(t) + mockClient := client.EXPECT() + tc.configureMockClient(mockClient) + + if tc.rc == nil { + tc.rc = &config.RuntimeConfig{ + DNSOnlyPassing: false, + } + } + + df := NewV2DataFetcher(tc.rc, client, logger) + + result, err := df.FetchEndpoints(tc.context, tc.queryPayload, LookupTypeService) + require.True(t, errors.Is(err, tc.expectedErr)) + + if tc.verifyShuffle { + require.NotEqualf(t, tc.expectedResult, result, "expected result to be shuffled. There is a small probability that it shuffled back to the original order. 
In that case, you may want to play the lottery.") + } + + require.ElementsMatchf(t, tc.expectedResult, result, "elements of results should match") + }) + } +} + func getTestWorkloadResponse(t *testing.T, nsOverride string, partitionOverride string) *pbresource.ReadResponse { workload := &pbcatalog.Workload{ Addresses: []*pbcatalog.WorkloadAddress{ @@ -242,7 +671,61 @@ func getTestWorkloadResponse(t *testing.T, nsOverride string, partitionOverride Id: &pbresource.ID{ Name: "foo-1234", Type: pbcatalog.WorkloadType, - Tenancy: resource.DefaultNamespacedTenancy(), // TODO (v2-dns): tenancy + Tenancy: resource.DefaultNamespacedTenancy(), + }, + Data: data, + }, + } + + if nsOverride != "" { + resp.Resource.Id.Tenancy.Namespace = nsOverride + } + if partitionOverride != "" { + resp.Resource.Id.Tenancy.Partition = partitionOverride + } + + return resp +} + +func makeEndpoint(name string, address string, health pbcatalog.Health, weightPassing, weightWarning uint32) *pbcatalog.Endpoint { + endpoint := &pbcatalog.Endpoint{ + Addresses: []*pbcatalog.WorkloadAddress{ + { + Host: address, + }, + }, + HealthStatus: health, + TargetRef: &pbresource.ID{ + Name: name, + }, + } + + if weightPassing > 0 || weightWarning > 0 { + endpoint.Dns = &pbcatalog.DNSPolicy{ + Weights: &pbcatalog.Weights{ + Passing: weightPassing, + Warning: weightWarning, + }, + } + } + + return endpoint +} + +func getTestEndpointsResponse(t *testing.T, nsOverride string, partitionOverride string, endpoints ...*pbcatalog.Endpoint) *pbresource.ReadResponse { + serviceEndpoints := &pbcatalog.ServiceEndpoints{ + Endpoints: endpoints, + } + + data, err := anypb.New(serviceEndpoints) + require.NoError(t, err) + + resp := &pbresource.ReadResponse{ + Resource: &pbresource.Resource{ + Id: &pbresource.ID{ + Name: "consul", + Type: pbcatalog.ServiceType, + Tenancy: resource.DefaultNamespacedTenancy(), }, Data: data, }, diff --git a/agent/dns/router.go b/agent/dns/router.go index 94732ee5de..405af9ed2a 100644 --- 
a/agent/dns/router.go +++ b/agent/dns/router.go @@ -42,6 +42,7 @@ var ( errInvalidQuestion = fmt.Errorf("invalid question") errNameNotFound = fmt.Errorf("name not found") errNotImplemented = fmt.Errorf("not implemented") + errQueryNotFound = fmt.Errorf("query not found") errRecursionFailed = fmt.Errorf("recursion failed") trailingSpacesRE = regexp.MustCompile(" +$") @@ -93,7 +94,7 @@ type DiscoveryQueryProcessor interface { // //go:generate mockery --name dnsRecursor --inpackage type dnsRecursor interface { - handle(req *dns.Msg, cfgCtx *RouterDynamicConfig, remoteAddr net.Addr) (*dns.Msg, error) + handle(req *dns.Msg, cfgCtx *RouterDynamicConfig, remoteAddress net.Addr) (*dns.Msg, error) } // Router replaces miekg/dns.ServeMux with a simpler router that only checks for the 2-3 valid domains @@ -126,12 +127,13 @@ func NewRouter(cfg Config) (*Router, error) { logger := cfg.Logger.Named(logging.DNS) router := &Router{ - processor: cfg.Processor, - recursor: newRecursor(logger), - domain: domain, - altDomain: altDomain, - logger: logger, - tokenFunc: cfg.TokenFunc, + processor: cfg.Processor, + recursor: newRecursor(logger), + domain: domain, + altDomain: altDomain, + datacenter: cfg.AgentConfig.Datacenter, + logger: logger, + tokenFunc: cfg.TokenFunc, } if err := router.ReloadConfig(cfg.AgentConfig); err != nil { @@ -160,7 +162,7 @@ func (r *Router) handleRequestRecursively(req *dns.Msg, reqCtx Context, return createServerFailureResponse(req, configCtx, false) } - responseDomain, needRecurse := r.parseDomain(req) + responseDomain, needRecurse := r.parseDomain(req.Question[0].Name) if needRecurse && !canRecurse(configCtx) { // This is the same error as an unmatched domain return createRefusedResponse(req) @@ -187,7 +189,7 @@ func (r *Router) handleRequestRecursively(req *dns.Msg, reqCtx Context, } reqType := parseRequestType(req) - results, query, err := r.getQueryResults(req, reqCtx, reqType, qName) + results, query, err := r.getQueryResults(req, reqCtx, reqType, 
qName, remoteAddress) switch { case errors.Is(err, errNameNotFound): r.logger.Error("name not found", "name", qName) @@ -272,7 +274,8 @@ func getTTLForResult(name string, query *discovery.Query, cfg *RouterDynamicConf } // getQueryResults returns a discovery.Result from a DNS message. -func (r *Router) getQueryResults(req *dns.Msg, reqCtx Context, reqType requestType, qName string) ([]*discovery.Result, *discovery.Query, error) { +func (r *Router) getQueryResults(req *dns.Msg, reqCtx Context, reqType requestType, + qName string, remoteAddress net.Addr) ([]*discovery.Result, *discovery.Query, error) { switch reqType { case requestTypeConsul: // This is a special case of discovery.QueryByName where we know that we need to query the consul service @@ -288,14 +291,14 @@ func (r *Router) getQueryResults(req *dns.Msg, reqCtx Context, reqType requestTy // need to add something to disambiguate the empty field. Partition: resource.DefaultPartitionName, }, + Limit: 3, }, - Limit: 3, // TODO (v2-dns): need to thread this through to the backend and make sure we shuffle the results } results, err := r.processor.QueryByName(query, discovery.Context{Token: reqCtx.Token}) return results, query, err case requestTypeName: - query, err := buildQueryFromDNSMessage(req, reqCtx, r.domain, r.altDomain) + query, err := buildQueryFromDNSMessage(req, reqCtx, r.domain, r.altDomain, remoteAddress) if err != nil { r.logger.Error("error building discovery query from DNS request", "error", err) return nil, query, err @@ -303,6 +306,13 @@ func (r *Router) getQueryResults(req *dns.Msg, reqCtx Context, reqType requestTy results, err := r.processor.QueryByName(query, discovery.Context{Token: reqCtx.Token}) if err != nil { r.logger.Error("error processing discovery query", "error", err) + switch err.Error() { + case errNameNotFound.Error(): + return nil, query, errNameNotFound + case errQueryNotFound.Error(): + return nil, query, errQueryNotFound + } + return nil, query, err } return results, query, 
nil @@ -376,8 +386,8 @@ const ( // it will return true for needRecurse. The logic is based on miekg/dns.ServeDNS matcher. // The implementation assumes that the only valid domains are "consul." and the alternative domain, and // that DS query types are not supported. -func (r *Router) parseDomain(req *dns.Msg) (string, bool) { - target := dns.CanonicalName(req.Question[0].Name) +func (r *Router) parseDomain(questionName string) (string, bool) { + target := dns.CanonicalName(questionName) target, _ = stripSuffix(target) for offset, overflow := 0, false; !overflow; offset, overflow = dns.NextLabel(target, offset) { @@ -786,8 +796,10 @@ func buildAddressResults(req *dns.Msg) ([]*discovery.Result, error) { return []*discovery.Result{ { - Address: ip.String(), - Type: discovery.ResultTypeNode, // We choose node by convention since we do not know the origin of the IP + Node: &discovery.Location{ + Address: ip.String(), + }, + Type: discovery.ResultTypeNode, // We choose node by convention since we do not know the origin of the IP }, }, nil } @@ -796,8 +808,14 @@ func buildAddressResults(req *dns.Msg) ([]*discovery.Result, error) { func (r *Router) getAnswerExtraAndNs(result *discovery.Result, req *dns.Msg, reqCtx Context, query *discovery.Query, cfg *RouterDynamicConfig, domain string, remoteAddress net.Addr, maxRecursionLevel int) (answer []dns.RR, extra []dns.RR, ns []dns.RR) { - target := newDNSAddress(result.Target) - address := newDNSAddress(result.Address) + serviceAddress := newDNSAddress("") + if result.Service != nil { + serviceAddress = newDNSAddress(result.Service.Address) + } + nodeAddress := newDNSAddress("") + if result.Node != nil { + nodeAddress = newDNSAddress(result.Node.Address) + } qName := req.Question[0].Name ttlLookupName := qName if query != nil { @@ -812,118 +830,183 @@ func (r *Router) getAnswerExtraAndNs(result *discovery.Result, req *dns.Msg, req switch { // PTR requests are first since they are a special case of domain overriding question 
type case parseRequestType(req) == requestTypeIP: + ptrTarget := "" + if result.Type == discovery.ResultTypeNode { + ptrTarget = result.Node.Name + } else if result.Type == discovery.ResultTypeService { + ptrTarget = result.Service.Name + } + ptr := &dns.PTR{ Hdr: dns.RR_Header{Name: qName, Rrtype: dns.TypePTR, Class: dns.ClassINET, Ttl: 0}, - Ptr: canonicalNameForResult(result, domain), + Ptr: canonicalNameForResult(result.Type, ptrTarget, domain, result.Tenancy, result.PortName), } answer = append(answer, ptr) case qType == dns.TypeNS: // TODO (v2-dns): fqdn in V1 has the datacenter included, this would need to be added to discovery.Result - fqdn := canonicalNameForResult(result, domain) - extraRecord := makeIPBasedRecord(fqdn, address, ttl) // TODO (v2-dns): this is not sufficient, because recursion and CNAMES are supported + fqdn := canonicalNameForResult(result.Type, result.Node.Name, domain, result.Tenancy, result.PortName) + extraRecord := makeIPBasedRecord(fqdn, nodeAddress, ttl) // TODO (v2-dns): this is not sufficient, because recursion and CNAMES are supported answer = append(answer, makeNSRecord(domain, fqdn, ttl)) extra = append(extra, extraRecord) case qType == dns.TypeSOA: // TODO (v2-dns): fqdn in V1 has the datacenter included, this would need to be added to discovery.Result // to be returned in the result. 
- fqdn := canonicalNameForResult(result, domain) - extraRecord := makeIPBasedRecord(fqdn, address, ttl) // TODO (v2-dns): this is not sufficient, because recursion and CNAMES are supported + fqdn := canonicalNameForResult(result.Type, result.Node.Name, domain, result.Tenancy, result.PortName) + extraRecord := makeIPBasedRecord(fqdn, nodeAddress, ttl) // TODO (v2-dns): this is not sufficient, because recursion and CNAMES are supported ns = append(ns, makeNSRecord(domain, fqdn, ttl)) extra = append(extra, extraRecord) case qType == dns.TypeSRV: // We put A/AAAA/CNAME records in the additional section for SRV requests - a, e := r.getAnswerExtrasForAddressAndTarget(address, target, req, reqCtx, - result, ttl, remoteAddress, cfg, maxRecursionLevel) + a, e := r.getAnswerExtrasForAddressAndTarget(nodeAddress, serviceAddress, req, reqCtx, + result, ttl, remoteAddress, cfg, domain, maxRecursionLevel) answer = append(answer, a...) extra = append(extra, e...) - if cfg.NodeMetaTXT { - name := target.FQDN() - if !target.IsInternalFQDN(r.domain) && !target.IsExternalFQDN(r.domain) { - name = canonicalNameForResult(result, r.domain) - } - extra = append(extra, makeTXTRecord(name, result, ttl)...) - } default: - a, e := r.getAnswerExtrasForAddressAndTarget(address, target, req, reqCtx, - result, ttl, remoteAddress, cfg, maxRecursionLevel) + a, e := r.getAnswerExtrasForAddressAndTarget(nodeAddress, serviceAddress, req, reqCtx, + result, ttl, remoteAddress, cfg, domain, maxRecursionLevel) answer = append(answer, a...) extra = append(extra, e...) } + + a, e := getAnswerAndExtraTXT(req, cfg, qName, result, ttl, domain) + answer = append(answer, a...) + extra = append(extra, e...) return } -// getAnswerExtrasForAddressAndTarget creates the dns answer and extra from address and target dnsAddress pairs. 
-func (r *Router) getAnswerExtrasForAddressAndTarget(address *dnsAddress, target *dnsAddress, req *dns.Msg, +// getAnswerExtrasForAddressAndTarget creates the dns answer and extra from nodeAddress and serviceAddress dnsAddress pairs. +func (r *Router) getAnswerExtrasForAddressAndTarget(nodeAddress *dnsAddress, serviceAddress *dnsAddress, req *dns.Msg, reqCtx Context, result *discovery.Result, ttl uint32, remoteAddress net.Addr, - cfg *RouterDynamicConfig, maxRecursionLevel int) (answer []dns.RR, extra []dns.RR) { + cfg *RouterDynamicConfig, domain string, maxRecursionLevel int) (answer []dns.RR, extra []dns.RR) { qName := req.Question[0].Name reqType := parseRequestType(req) switch { - // Virtual IPs and Address requests - // both return IPs with empty targets case (reqType == requestTypeAddress || result.Type == discovery.ResultTypeVirtual) && - target.IsEmptyString() && address.IsIP(): - a, e := getAnswerExtrasForIP(qName, address, req.Question[0], reqType, - result, ttl) - answer = append(a, answer...) - extra = append(e, extra...) - - // Address is a FQDN and requires a CNAME lookup. - case address.IsFQDN(): - a, e := r.makeRecordFromFQDN(address.FQDN(), result, req, reqCtx, - cfg, ttl, remoteAddress, maxRecursionLevel) - answer = append(a, answer...) - extra = append(e, extra...) - - // Target is FQDN that point to IP - case target.IsFQDN() && address.IsIP(): - var a, e []dns.RR - if result.Type == discovery.ResultTypeNode || result.Type == discovery.ResultTypeWorkload { - // if it is a node record it means the service address pointed to a node - // and the node address was used. So we create an A record for the node address, - // as well as a CNAME for the service to node mapping. 
- name := target.FQDN() - if !target.IsInternalFQDN(r.domain) && !target.IsExternalFQDN(r.domain) { - name = canonicalNameForResult(result, r.domain) - } else if target.IsInternalFQDN(r.domain) { - answer = append(answer, makeCNAMERecord(qName, canonicalNameForResult(result, r.domain), ttl)) - } - a, e = getAnswerExtrasForIP(name, address, req.Question[0], reqType, - result, ttl) - } else { - // if it is a service record, it means that the service address had the IP directly - // and there was not a need for an intermediate CNAME. - a, e = getAnswerExtrasForIP(qName, address, req.Question[0], reqType, - result, ttl) - } + serviceAddress.IsEmptyString() && nodeAddress.IsIP(): + a, e := getAnswerExtrasForIP(qName, nodeAddress, req.Question[0], reqType, + result, ttl, domain) answer = append(answer, a...) extra = append(extra, e...) - // The target is a FQDN (internal or external service name) - default: - a, e := r.makeRecordFromFQDN(target.FQDN(), result, req, reqCtx, cfg, + case result.Type == discovery.ResultTypeNode && nodeAddress.IsIP(): + canonicalNodeName := canonicalNameForResult(result.Type, result.Node.Name, domain, result.Tenancy, result.PortName) + a, e := getAnswerExtrasForIP(canonicalNodeName, nodeAddress, req.Question[0], reqType, + result, ttl, domain) + answer = append(answer, a...) + extra = append(extra, e...) + + case result.Type == discovery.ResultTypeNode && !nodeAddress.IsIP(): + a, e := r.makeRecordFromFQDN(serviceAddress.FQDN(), result, req, reqCtx, cfg, ttl, remoteAddress, maxRecursionLevel) - answer = append(a, answer...) - extra = append(e, extra...) + answer = append(answer, a...) + extra = append(extra, e...) 
+ + case serviceAddress.IsEmptyString() && nodeAddress.IsEmptyString(): + return nil, nil + + // There is no service address and the node address is an IP + case serviceAddress.IsEmptyString() && nodeAddress.IsIP(): + canonicalNodeName := canonicalNameForResult(discovery.ResultTypeNode, result.Node.Name, domain, result.Tenancy, result.PortName) + a, e := getAnswerExtrasForIP(canonicalNodeName, nodeAddress, req.Question[0], reqType, + result, ttl, domain) + answer = append(answer, a...) + extra = append(extra, e...) + + // There is no service address and the node address is a FQDN (external service) + case serviceAddress.IsEmptyString(): + a, e := r.makeRecordFromFQDN(nodeAddress.FQDN(), result, req, reqCtx, cfg, + ttl, remoteAddress, maxRecursionLevel) + answer = append(answer, a...) + extra = append(extra, e...) + + // The service address is an IP + case serviceAddress.IsIP(): + canonicalServiceName := canonicalNameForResult(discovery.ResultTypeService, result.Service.Name, domain, result.Tenancy, result.PortName) + a, e := getAnswerExtrasForIP(canonicalServiceName, serviceAddress, req.Question[0], reqType, + result, ttl, domain) + answer = append(answer, a...) + extra = append(extra, e...) + + // If the service address is a CNAME for the service we are looking + // for then use the node address. + case serviceAddress.FQDN() == req.Question[0].Name && nodeAddress.IsIP(): + canonicalNodeName := canonicalNameForResult(discovery.ResultTypeNode, result.Node.Name, domain, result.Tenancy, result.PortName) + a, e := getAnswerExtrasForIP(canonicalNodeName, nodeAddress, req.Question[0], reqType, + result, ttl, domain) + answer = append(answer, a...) + extra = append(extra, e...) + + // The service address is a FQDN (internal or external service name) + default: + a, e := r.makeRecordFromFQDN(serviceAddress.FQDN(), result, req, reqCtx, cfg, + ttl, remoteAddress, maxRecursionLevel) + answer = append(answer, a...) + extra = append(extra, e...) 
} + return } +// getAnswerAndExtraTXT determines whether a TXT needs to be created and then +// returns the TXT record in the answer or extra depending on the question type. +func getAnswerAndExtraTXT(req *dns.Msg, cfg *RouterDynamicConfig, qName string, + result *discovery.Result, ttl uint32, domain string) (answer []dns.RR, extra []dns.RR) { + recordHeaderName := qName + serviceAddress := newDNSAddress("") + if result.Service != nil { + serviceAddress = newDNSAddress(result.Service.Address) + } + if result.Type != discovery.ResultTypeNode && + result.Type != discovery.ResultTypeVirtual && + !serviceAddress.IsInternalFQDN(domain) && + !serviceAddress.IsExternalFQDN(domain) { + recordHeaderName = canonicalNameForResult(discovery.ResultTypeNode, result.Node.Name, + domain, result.Tenancy, result.PortName) + } + qType := req.Question[0].Qtype + generateMeta := false + metaInAnswer := false + if qType == dns.TypeANY || qType == dns.TypeTXT { + generateMeta = true + metaInAnswer = true + } else if cfg.NodeMetaTXT { + generateMeta = true + } + + // Do not generate txt records if we don't have to: https://github.com/hashicorp/consul/pull/5272 + if generateMeta { + meta := makeTXTRecord(recordHeaderName, result, ttl) + if metaInAnswer { + answer = append(answer, meta...) + } else { + extra = append(extra, meta...) + } + } + return answer, extra +} + // getAnswerExtrasForIP creates the dns answer and extra from IP dnsAddress pairs. func getAnswerExtrasForIP(name string, addr *dnsAddress, question dns.Question, - reqType requestType, result *discovery.Result, ttl uint32) (answer []dns.RR, extra []dns.RR) { + reqType requestType, result *discovery.Result, ttl uint32, _ string) (answer []dns.RR, extra []dns.RR) { qType := question.Qtype // Have to pass original question name here even if the system has recursed // and stripped off the domain suffix.
recHdrName := question.Name if qType == dns.TypeSRV { - recHdrName = name + nameSplit := strings.Split(name, ".") + if len(nameSplit) > 1 && nameSplit[1] == addrLabel { + recHdrName = name + } else { + recHdrName = name + } + name = question.Name } + record := makeIPBasedRecord(recHdrName, addr, ttl) isARecordWhenNotExplicitlyQueried := record.Header().Rrtype == dns.TypeA && qType != dns.TypeA && qType != dns.TypeANY @@ -938,12 +1021,28 @@ func getAnswerExtrasForIP(name string, addr *dnsAddress, question dns.Question, } if reqType != requestTypeAddress && qType == dns.TypeSRV { - srv := makeSRVRecord(name, name, result, ttl) + srv := makeSRVRecord(name, recHdrName, result, ttl) answer = append(answer, srv) } return } +// encodeIPAsFqdn hex-encodes an IP address into an FQDN of the form hex.addr.datacenter.responseDomain; IPv4 addresses keep only their last four bytes and the datacenter label is left out when the result has a peer name. +func encodeIPAsFqdn(result *discovery.Result, ip net.IP, responseDomain string) string { + ipv4 := ip.To4() + ipStr := hex.EncodeToString(ip) + if ipv4 != nil { + ipStr = ipStr[len(ipStr)-(net.IPv4len*2):] + } + if result.Tenancy.PeerName != "" { + // Exclude the datacenter from the FQDN on the addr for peers. + // This technically makes no difference, since the addr endpoint ignores the DC + // component of the request, but do it anyway for a less confusing experience.
+ return fmt.Sprintf("%s.addr.%s", ipStr, responseDomain) + } + return fmt.Sprintf("%s.addr.%s.%s", ipStr, result.Tenancy.Datacenter, responseDomain) +} + func makeSOARecord(domain string, cfg *RouterDynamicConfig) dns.RR { return &dns.SOA{ Hdr: dns.RR_Header{ @@ -1016,7 +1115,7 @@ func (r *Router) makeRecordFromFQDN(fqdn string, result *discovery.Result, MORE_REC: for _, rr := range more { switch rr.Header().Rrtype { - case dns.TypeCNAME, dns.TypeA, dns.TypeAAAA: + case dns.TypeCNAME, dns.TypeA, dns.TypeAAAA, dns.TypeTXT: // set the TTL manually rr.Header().Ttl = ttl additional = append(additional, rr) @@ -1035,8 +1134,15 @@ MORE_REC: return answers, additional } + address := "" + if result.Service != nil && result.Service.Address != "" { + address = result.Service.Address + } else if result.Node != nil { + address = result.Node.Address + } + answers := []dns.RR{ - makeCNAMERecord(q.Name, result.Target, ttl), + makeCNAMERecord(q.Name, address, ttl), } answers = append(answers, additional...) diff --git a/agent/dns/router_ce.go b/agent/dns/router_ce.go index 3a44ca1cdc..67cab00490 100644 --- a/agent/dns/router_ce.go +++ b/agent/dns/router_ce.go @@ -12,26 +12,27 @@ import ( ) // canonicalNameForResult returns the canonical name for a discovery result. -func canonicalNameForResult(result *discovery.Result, domain string) string { - switch result.Type { +func canonicalNameForResult(resultType discovery.ResultType, target, domain string, + tenancy discovery.ResultTenancy, portName string) string { + switch resultType { case discovery.ResultTypeService: - return fmt.Sprintf("%s.%s.%s.%s", result.Target, "service", result.Tenancy.Datacenter, domain) + return fmt.Sprintf("%s.%s.%s.%s", target, "service", tenancy.Datacenter, domain) case discovery.ResultTypeNode: - if result.Tenancy.PeerName != "" { + if tenancy.PeerName != "" { // We must return a more-specific DNS name for peering so // that there is no ambiguity with lookups.
return fmt.Sprintf("%s.node.%s.peer.%s", - result.Target, - result.Tenancy.PeerName, + target, + tenancy.PeerName, domain) } // Return a simpler format for non-peering nodes. - return fmt.Sprintf("%s.node.%s.%s", result.Target, result.Tenancy.Datacenter, domain) + return fmt.Sprintf("%s.node.%s.%s", target, tenancy.Datacenter, domain) case discovery.ResultTypeWorkload: - if result.PortName != "" { - return fmt.Sprintf("%s.port.%s.workload.%s", result.PortName, result.Target, domain) + if portName != "" { + return fmt.Sprintf("%s.port.%s.workload.%s", portName, target, domain) } - return fmt.Sprintf("%s.workload.%s", result.Target, domain) + return fmt.Sprintf("%s.workload.%s", target, domain) } return "" } diff --git a/agent/dns/router_ce_test.go b/agent/dns/router_ce_test.go index 69f73e2dbf..3dd63eeee6 100644 --- a/agent/dns/router_ce_test.go +++ b/agent/dns/router_ce_test.go @@ -37,9 +37,9 @@ func getAdditionalTestCases(t *testing.T) []HandleTestCase { configureDataFetcher: func(fetcher discovery.CatalogDataFetcher) { results := []*discovery.Result{ { - Address: "1.2.3.4", + Node: &discovery.Location{Name: "foo", Address: "1.2.3.4"}, Type: discovery.ResultTypeNode, - Target: "foo", + Service: &discovery.Location{Name: "foo", Address: "foo"}, Tenancy: discovery.ResultTenancy{ Datacenter: "dc2", PeerName: "peer1", @@ -100,9 +100,9 @@ func getAdditionalTestCases(t *testing.T) []HandleTestCase { configureDataFetcher: func(fetcher discovery.CatalogDataFetcher) { results := []*discovery.Result{ { - Address: "1.2.3.4", + Node: &discovery.Location{Name: "foo", Address: "1.2.3.4"}, + Service: &discovery.Location{Name: "foo", Address: "foo"}, Type: discovery.ResultTypeService, - Target: "foo", Tenancy: discovery.ResultTenancy{ Datacenter: "dc2", }, diff --git a/agent/dns/router_query.go b/agent/dns/router_query.go index 13a4935be0..6576b8724b 100644 --- a/agent/dns/router_query.go +++ b/agent/dns/router_query.go @@ -4,6 +4,7 @@ package dns import ( + "net" "strings" 
"github.com/miekg/dns" @@ -12,7 +13,8 @@ import ( ) // buildQueryFromDNSMessage returns a discovery.Query from a DNS message. -func buildQueryFromDNSMessage(req *dns.Msg, reqCtx Context, domain, altDomain string) (*discovery.Query, error) { +func buildQueryFromDNSMessage(req *dns.Msg, reqCtx Context, domain, altDomain string, + remoteAddress net.Addr) (*discovery.Query, error) { queryType, queryParts, querySuffixes := getQueryTypePartsAndSuffixesFromDNSMessage(req, domain, altDomain) queryTenancy, err := getQueryTenancy(reqCtx, queryType, querySuffixes) @@ -36,16 +38,20 @@ func buildQueryFromDNSMessage(req *dns.Msg, reqCtx Context, domain, altDomain st Tenancy: queryTenancy, Tag: tag, PortName: portName, - //RemoteAddr: nil, // TODO (v2-dns): Prepared Queries for V1 Catalog + SourceIP: getSourceIP(req, queryType, remoteAddress), }, }, nil } // getQueryNameAndTagFromParts returns the query name and tag from the query parts that are taken from the original dns question. func getQueryNameAndTagFromParts(queryType discovery.QueryType, queryParts []string) (string, string) { + n := len(queryParts) + if n == 0 { + return "", "" + } + switch queryType { case discovery.QueryTypeService: - n := len(queryParts) // Support RFC 2782 style syntax if n == 2 && strings.HasPrefix(queryParts[1], "_") && strings.HasPrefix(queryParts[0], "_") { // Grab the tag since we make nuke it if it's tcp @@ -60,9 +66,9 @@ func getQueryNameAndTagFromParts(queryType discovery.QueryType, queryParts []str // _name._tag.service.consul return name, tag } - return queryParts[len(queryParts)-1], "" + return queryParts[n-1], "" } - return queryParts[len(queryParts)-1], "" + return queryParts[n-1], "" } // getQueryTenancy returns a discovery.QueryTenancy from a DNS message. @@ -177,3 +183,24 @@ func getQueryTypeFromLabels(label string) discovery.QueryType { return discovery.QueryTypeInvalid } } + +// getSourceIP returns the source IP of a prepared query request: the address of the subnet reported by ednsSubnetForRequest when there is one, otherwise the IP of remoteAddr if it is a UDP, TCP or IP address. It returns nil for every other query type.
+func getSourceIP(req *dns.Msg, queryType discovery.QueryType, remoteAddr net.Addr) (sourceIP net.IP) { + if queryType == discovery.QueryTypePreparedQuery { + subnet := ednsSubnetForRequest(req) + + if subnet != nil { + sourceIP = subnet.Address + } else { + switch v := remoteAddr.(type) { + case *net.UDPAddr: + sourceIP = v.IP + case *net.TCPAddr: + sourceIP = v.IP + case *net.IPAddr: + sourceIP = v.IP + } + } + } + return sourceIP +} diff --git a/agent/dns/router_query_test.go b/agent/dns/router_query_test.go index dc4ea6592e..94182de9e0 100644 --- a/agent/dns/router_query_test.go +++ b/agent/dns/router_query_test.go @@ -206,7 +206,7 @@ func Test_buildQueryFromDNSMessage(t *testing.T) { if context == nil { context = &Context{} } - query, err := buildQueryFromDNSMessage(tc.request, *context, "consul.", ".") + query, err := buildQueryFromDNSMessage(tc.request, *context, "consul.", ".", nil) require.NoError(t, err) assert.Equal(t, tc.expectedQuery, query) }) diff --git a/agent/dns/router_test.go b/agent/dns/router_test.go index aa38d91ef3..220ae27f38 100644 --- a/agent/dns/router_test.go +++ b/agent/dns/router_test.go @@ -21,8 +21,6 @@ import ( "github.com/hashicorp/consul/agent/structs" ) -// TODO (v2-dns) - // TBD Test Cases // 1. Reload the configuration (e.g. SOA) // 2.
Something to check the token makes it through to the data fetcher @@ -717,8 +715,8 @@ func Test_HandleRequest(t *testing.T) { configureDataFetcher: func(fetcher discovery.CatalogDataFetcher) { fetcher.(*discovery.MockCatalogDataFetcher).On("FetchVirtualIP", mock.Anything, mock.Anything).Return(&discovery.Result{ - Address: "240.0.0.2", - Type: discovery.ResultTypeVirtual, + Node: &discovery.Location{Address: "240.0.0.2"}, + Type: discovery.ResultTypeVirtual, }, nil) }, validateAndNormalizeExpected: true, @@ -767,8 +765,8 @@ func Test_HandleRequest(t *testing.T) { configureDataFetcher: func(fetcher discovery.CatalogDataFetcher) { fetcher.(*discovery.MockCatalogDataFetcher).On("FetchVirtualIP", mock.Anything, mock.Anything).Return(&discovery.Result{ - Address: "2001:db8:1:2:cafe::1337", - Type: discovery.ResultTypeVirtual, + Node: &discovery.Location{Address: "2001:db8:1:2:cafe::1337"}, + Type: discovery.ResultTypeVirtual, }, nil) }, validateAndNormalizeExpected: true, @@ -819,14 +817,14 @@ func Test_HandleRequest(t *testing.T) { On("FetchEndpoints", mock.Anything, mock.Anything, mock.Anything). Return([]*discovery.Result{ { - Address: "1.2.3.4", + Node: &discovery.Location{Name: "server-one", Address: "1.2.3.4"}, + Service: &discovery.Location{Name: "service-one", Address: "server-one"}, Type: discovery.ResultTypeWorkload, - Target: "server-one", // This would correlate to the workload name }, { - Address: "4.5.6.7", + Node: &discovery.Location{Name: "server-two", Address: "4.5.6.7"}, + Service: &discovery.Location{Name: "service-one", Address: "server-two"}, Type: discovery.ResultTypeWorkload, - Target: "server-two", // This would correlate to the workload name }, }, nil). 
Run(func(args mock.Arguments) { @@ -835,6 +833,7 @@ func Test_HandleRequest(t *testing.T) { require.Equal(t, discovery.LookupTypeService, reqType) require.Equal(t, structs.ConsulServiceName, req.Name) + require.Equal(t, 3, req.Limit) }) }, validateAndNormalizeExpected: true, @@ -941,14 +940,14 @@ func Test_HandleRequest(t *testing.T) { On("FetchEndpoints", mock.Anything, mock.Anything, mock.Anything). Return([]*discovery.Result{ { - Address: "1.2.3.4", + Node: &discovery.Location{Name: "server-one", Address: "1.2.3.4"}, + Service: &discovery.Location{Name: "service-one", Address: "server-one"}, Type: discovery.ResultTypeWorkload, - Target: "server-one", // This would correlate to the workload name }, { - Address: "4.5.6.7", + Node: &discovery.Location{Name: "server-two", Address: "4.5.6.7"}, + Service: &discovery.Location{Name: "service-two", Address: "server-two"}, Type: discovery.ResultTypeWorkload, - Target: "server-two", // This would correlate to the workload name }, }, nil). Run(func(args mock.Arguments) { @@ -957,6 +956,7 @@ func Test_HandleRequest(t *testing.T) { require.Equal(t, discovery.LookupTypeService, reqType) require.Equal(t, structs.ConsulServiceName, req.Name) + require.Equal(t, 3, req.Limit) }) }, validateAndNormalizeExpected: true, @@ -1033,6 +1033,204 @@ func Test_HandleRequest(t *testing.T) { }, }, }, + // NS Queries + { + name: "vanilla NS query", + request: &dns.Msg{ + MsgHdr: dns.MsgHdr{ + Opcode: dns.OpcodeQuery, + }, + Question: []dns.Question{ + { + Name: "consul.", + Qtype: dns.TypeNS, + Qclass: dns.ClassINET, + }, + }, + }, + configureDataFetcher: func(fetcher discovery.CatalogDataFetcher) { + fetcher.(*discovery.MockCatalogDataFetcher). + On("FetchEndpoints", mock.Anything, mock.Anything, mock.Anything). 
+ Return([]*discovery.Result{ + { + Node: &discovery.Location{Name: "server-one", Address: "1.2.3.4"}, + Type: discovery.ResultTypeWorkload, + }, + { + Node: &discovery.Location{Name: "server-two", Address: "4.5.6.7"}, + Type: discovery.ResultTypeWorkload, + }, + }, nil). + Run(func(args mock.Arguments) { + req := args.Get(1).(*discovery.QueryPayload) + reqType := args.Get(2).(discovery.LookupType) + + require.Equal(t, discovery.LookupTypeService, reqType) + require.Equal(t, structs.ConsulServiceName, req.Name) + require.Equal(t, 3, req.Limit) + }) + }, + validateAndNormalizeExpected: true, + response: &dns.Msg{ + MsgHdr: dns.MsgHdr{ + Opcode: dns.OpcodeQuery, + Response: true, + Authoritative: true, + }, + Compress: true, + Question: []dns.Question{ + { + Name: "consul.", + Qtype: dns.TypeNS, + Qclass: dns.ClassINET, + }, + }, + Answer: []dns.RR{ + &dns.NS{ + Hdr: dns.RR_Header{ + Name: "consul.", + Rrtype: dns.TypeNS, + Class: dns.ClassINET, + Ttl: 123, + }, + Ns: "server-one.workload.consul.", // TODO (v2-dns): this format needs to be consistent with other workloads + }, + &dns.NS{ + Hdr: dns.RR_Header{ + Name: "consul.", + Rrtype: dns.TypeNS, + Class: dns.ClassINET, + Ttl: 123, + }, + Ns: "server-two.workload.consul.", + }, + }, + Extra: []dns.RR{ + &dns.A{ + Hdr: dns.RR_Header{ + Name: "server-one.workload.consul.", + Rrtype: dns.TypeA, + Class: dns.ClassINET, + Ttl: 123, + }, + A: net.ParseIP("1.2.3.4"), + }, + &dns.A{ + Hdr: dns.RR_Header{ + Name: "server-two.workload.consul.", + Rrtype: dns.TypeA, + Class: dns.ClassINET, + Ttl: 123, + }, + A: net.ParseIP("4.5.6.7"), + }, + }, + }, + }, + { + name: "NS query against alternate domain", + request: &dns.Msg{ + MsgHdr: dns.MsgHdr{ + Opcode: dns.OpcodeQuery, + }, + Question: []dns.Question{ + { + Name: "testdomain.", + Qtype: dns.TypeNS, + Qclass: dns.ClassINET, + }, + }, + }, + agentConfig: &config.RuntimeConfig{ + DNSDomain: "consul", + DNSAltDomain: "testdomain", + DNSNodeTTL: 123 * time.Second, + DNSSOA: 
config.RuntimeSOAConfig{ + Refresh: 1, + Retry: 2, + Expire: 3, + Minttl: 4, + }, + }, + configureDataFetcher: func(fetcher discovery.CatalogDataFetcher) { + fetcher.(*discovery.MockCatalogDataFetcher). + On("FetchEndpoints", mock.Anything, mock.Anything, mock.Anything). + Return([]*discovery.Result{ + { + Node: &discovery.Location{Name: "server-one", Address: "1.2.3.4"}, + Type: discovery.ResultTypeWorkload, + }, + { + Node: &discovery.Location{Name: "server-two", Address: "4.5.6.7"}, + Type: discovery.ResultTypeWorkload, + }, + }, nil). + Run(func(args mock.Arguments) { + req := args.Get(1).(*discovery.QueryPayload) + reqType := args.Get(2).(discovery.LookupType) + + require.Equal(t, discovery.LookupTypeService, reqType) + require.Equal(t, structs.ConsulServiceName, req.Name) + require.Equal(t, 3, req.Limit) + }) + }, + validateAndNormalizeExpected: true, + response: &dns.Msg{ + MsgHdr: dns.MsgHdr{ + Opcode: dns.OpcodeQuery, + Response: true, + Authoritative: true, + }, + Compress: true, + Question: []dns.Question{ + { + Name: "testdomain.", + Qtype: dns.TypeNS, + Qclass: dns.ClassINET, + }, + }, + Answer: []dns.RR{ + &dns.NS{ + Hdr: dns.RR_Header{ + Name: "testdomain.", + Rrtype: dns.TypeNS, + Class: dns.ClassINET, + Ttl: 123, + }, + Ns: "server-one.workload.testdomain.", + }, + &dns.NS{ + Hdr: dns.RR_Header{ + Name: "testdomain.", + Rrtype: dns.TypeNS, + Class: dns.ClassINET, + Ttl: 123, + }, + Ns: "server-two.workload.testdomain.", + }, + }, + Extra: []dns.RR{ + &dns.A{ + Hdr: dns.RR_Header{ + Name: "server-one.workload.testdomain.", + Rrtype: dns.TypeA, + Class: dns.ClassINET, + Ttl: 123, + }, + A: net.ParseIP("1.2.3.4"), + }, + &dns.A{ + Hdr: dns.RR_Header{ + Name: "server-two.workload.testdomain.", + Rrtype: dns.TypeA, + Class: dns.ClassINET, + Ttl: 123, + }, + A: net.ParseIP("4.5.6.7"), + }, + }, + }, + }, // PTR Lookups { name: "PTR lookup for node, query type is ANY", @@ -1051,9 +1249,9 @@ func Test_HandleRequest(t *testing.T) { configureDataFetcher: 
func(fetcher discovery.CatalogDataFetcher) { results := []*discovery.Result{ { - Address: "1.2.3.4", + Node: &discovery.Location{Name: "foo", Address: "1.2.3.4"}, + Service: &discovery.Location{Name: "bar", Address: "foo"}, Type: discovery.ResultTypeNode, - Target: "foo", Tenancy: discovery.ResultTenancy{ Datacenter: "dc2", }, @@ -1113,9 +1311,9 @@ func Test_HandleRequest(t *testing.T) { configureDataFetcher: func(fetcher discovery.CatalogDataFetcher) { results := []*discovery.Result{ { - Address: "2001:db8::567:89ab", + Node: &discovery.Location{Name: "foo", Address: "2001:db8::567:89ab"}, + Service: &discovery.Location{Name: "web", Address: "foo"}, Type: discovery.ResultTypeNode, - Target: "foo", Tenancy: discovery.ResultTenancy{ Datacenter: "dc2", }, @@ -1315,6 +1513,85 @@ func Test_HandleRequest(t *testing.T) { }, }, }, + { + // TestDNS_ExternalServiceToConsulCNAMELookup + name: "req type: service / question type: SRV / CNAME required: no", + request: &dns.Msg{ + MsgHdr: dns.MsgHdr{ + Opcode: dns.OpcodeQuery, + }, + Question: []dns.Question{ + { + Name: "alias.service.consul.", + Qtype: dns.TypeSRV, + }, + }, + }, + configureDataFetcher: func(fetcher discovery.CatalogDataFetcher) { + fetcher.(*discovery.MockCatalogDataFetcher). + On("FetchEndpoints", mock.Anything, + &discovery.QueryPayload{ + Name: "alias", + Tenancy: discovery.QueryTenancy{}, + }, discovery.LookupTypeService). + Return([]*discovery.Result{ + { + Type: discovery.ResultTypeVirtual, + Service: &discovery.Location{Name: "alias", Address: "web.service.consul"}, + Node: &discovery.Location{Name: "web", Address: "web.service.consul"}, + }, + }, + nil).On("FetchEndpoints", mock.Anything, + &discovery.QueryPayload{ + Name: "web", + Tenancy: discovery.QueryTenancy{}, + }, discovery.LookupTypeService). 
+ Return([]*discovery.Result{ + { + Type: discovery.ResultTypeNode, + Service: &discovery.Location{Name: "web", Address: "webnode"}, + Node: &discovery.Location{Name: "webnode", Address: "127.0.0.2"}, + }, + }, nil).On("ValidateRequest", mock.Anything, + mock.Anything).Return(nil).On("NormalizeRequest", mock.Anything) + }, + response: &dns.Msg{ + MsgHdr: dns.MsgHdr{ + Response: true, + Authoritative: true, + }, + Compress: true, + Question: []dns.Question{ + { + Name: "alias.service.consul.", + Qtype: dns.TypeSRV, + }, + }, + Answer: []dns.RR{ + &dns.SRV{ + Hdr: dns.RR_Header{ + Name: "alias.service.consul.", + Rrtype: dns.TypeSRV, + Class: dns.ClassINET, + Ttl: 123, + }, + Target: "web.service.consul.", + Priority: 1, + }, + }, + Extra: []dns.RR{ + &dns.A{ + Hdr: dns.RR_Header{ + Name: "web.service.consul.", + Rrtype: dns.TypeA, + Class: dns.ClassINET, + Ttl: 123, + }, + A: net.ParseIP("127.0.0.2"), + }, + }, + }, + }, // TODO (v2-dns): add a test to make sure only 3 records are returned // V2 Workload Lookup { @@ -1333,12 +1610,12 @@ func Test_HandleRequest(t *testing.T) { }, configureDataFetcher: func(fetcher discovery.CatalogDataFetcher) { result := &discovery.Result{ - Address: "1.2.3.4", + Node: &discovery.Location{Address: "1.2.3.4"}, Type: discovery.ResultTypeWorkload, Tenancy: discovery.ResultTenancy{}, PortName: "api", PortNumber: 5678, - Target: "foo", + Service: &discovery.Location{Name: "foo"}, } fetcher.(*discovery.MockCatalogDataFetcher). @@ -1394,10 +1671,10 @@ func Test_HandleRequest(t *testing.T) { }, configureDataFetcher: func(fetcher discovery.CatalogDataFetcher) { result := &discovery.Result{ - Address: "1.2.3.4", + Node: &discovery.Location{Address: "1.2.3.4"}, Type: discovery.ResultTypeWorkload, Tenancy: discovery.ResultTenancy{}, - Target: "foo", + Service: &discovery.Location{Name: "foo"}, } fetcher.(*discovery.MockCatalogDataFetcher). 
@@ -1453,14 +1730,14 @@ func Test_HandleRequest(t *testing.T) { }, configureDataFetcher: func(fetcher discovery.CatalogDataFetcher) { result := &discovery.Result{ - Address: "1.2.3.4", - Type: discovery.ResultTypeWorkload, + Node: &discovery.Location{Address: "1.2.3.4"}, + Type: discovery.ResultTypeWorkload, Tenancy: discovery.ResultTenancy{ Namespace: "bar", Partition: "baz", Datacenter: "dc3", }, - Target: "foo", + Service: &discovery.Location{Name: "foo"}, } fetcher.(*discovery.MockCatalogDataFetcher). @@ -1502,7 +1779,7 @@ func Test_HandleRequest(t *testing.T) { }, } - //testCases = append(testCases, getAdditionalTestCases(t)...) + testCases = append(testCases, getAdditionalTestCases(t)...) run := func(t *testing.T, tc HandleTestCase) { cdf := discovery.NewMockCatalogDataFetcher(t) diff --git a/agent/dns_node_lookup_test.go b/agent/dns_node_lookup_test.go index 1e507187fc..198df6f7a3 100644 --- a/agent/dns_node_lookup_test.go +++ b/agent/dns_node_lookup_test.go @@ -271,7 +271,7 @@ func TestDNS_NodeLookup_CNAME(t *testing.T) { }) defer recursor.Shutdown() - for name, experimentsHCL := range getVersionHCL(false) { + for name, experimentsHCL := range getVersionHCL(true) { t.Run(name, func(t *testing.T) { a := NewTestAgent(t, ` recursors = ["`+recursor.Addr+`"] @@ -588,7 +588,7 @@ func TestDNS_NodeLookup_TTL(t *testing.T) { }) defer recursor.Shutdown() - for name, experimentsHCL := range getVersionHCL(false) { + for name, experimentsHCL := range getVersionHCL(true) { t.Run(name, func(t *testing.T) { a := NewTestAgent(t, ` recursors = ["`+recursor.Addr+`"] diff --git a/agent/dns_service_lookup_test.go b/agent/dns_service_lookup_test.go index 9e021824ae..4a238b899d 100644 --- a/agent/dns_service_lookup_test.go +++ b/agent/dns_service_lookup_test.go @@ -1339,7 +1339,7 @@ func TestDNS_AltDomain_ServiceLookup_ServiceAddress_A(t *testing.T) { } } -// TODO (v2-dns): this requires a prepared query +// TODO (v2-dns): NET-7632 - Fix node and prepared query lookups when 
question name has a period in it func TestDNS_ServiceLookup_ServiceAddress_SRV(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short")