NET-7165 - v2 - add service questions (#20390)

* NET-7165 - v2 - add service  questions

* removing extraneous copied over code from autogen PR script.

* fixing license checking
pull/20379/head
John Murret 2024-01-29 15:33:45 -07:00 committed by GitHub
parent 3b9bb8d6f9
commit 7c6a3c83f2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
19 changed files with 1617 additions and 326 deletions

View File

@ -1106,7 +1106,14 @@ func (a *Agent) listenAndServeV2DNS() error {
if a.baseDeps.UseV2Resources() { if a.baseDeps.UseV2Resources() {
a.catalogDataFetcher = discovery.NewV2DataFetcher(a.config) a.catalogDataFetcher = discovery.NewV2DataFetcher(a.config)
} else { } else {
a.catalogDataFetcher = discovery.NewV1DataFetcher(a.config, a.AgentEnterpriseMeta(), a.RPC, a.logger.Named("catalog-data-fetcher")) a.catalogDataFetcher = discovery.NewV1DataFetcher(a.config,
a.AgentEnterpriseMeta(),
a.cache.Get,
a.RPC,
a.rpcClientHealth.ServiceNodes,
a.rpcClientConfigEntry.GetSamenessGroup,
a.TranslateServicePort,
a.logger.Named("catalog-data-fetcher"))
} }
// Generate a Query Processor with the appropriate data fetcher // Generate a Query Processor with the appropriate data fetcher

View File

@ -110,7 +110,7 @@ type Result struct {
Address string // A/AAAA/CNAME records - could be used in the Extra section. CNAME is required to handle hostname addresses in workloads & nodes. Address string // A/AAAA/CNAME records - could be used in the Extra section. CNAME is required to handle hostname addresses in workloads & nodes.
Weight uint32 // SRV queries Weight uint32 // SRV queries
Port uint32 // SRV queries Port uint32 // SRV queries
Metadata []string // Used to collect metadata into TXT Records Metadata map[string]string // Used to collect metadata into TXT Records
Type ResultType // Used to reconstruct the fqdn name of the resource Type ResultType // Used to reconstruct the fqdn name of the resource
// Used in SRV & PTR queries to point at an A/AAAA Record. // Used in SRV & PTR queries to point at an A/AAAA Record.
@ -176,6 +176,7 @@ func NewQueryProcessor(dataFetcher CatalogDataFetcher) *QueryProcessor {
} }
} }
// QueryByName is used to look up a service, node, workload, or prepared query.
func (p *QueryProcessor) QueryByName(query *Query, ctx Context) ([]*Result, error) { func (p *QueryProcessor) QueryByName(query *Query, ctx Context) ([]*Result, error) {
switch query.QueryType { switch query.QueryType {
case QueryTypeNode: case QueryTypeNode:

View File

@ -5,27 +5,33 @@ package discovery
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"net" "net"
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/armon/go-metrics"
cachetype "github.com/hashicorp/consul/agent/cache-types"
"github.com/hashicorp/consul/api"
"github.com/hashicorp/go-hclog" "github.com/hashicorp/go-hclog"
"github.com/hashicorp/consul/acl" "github.com/hashicorp/consul/acl"
"github.com/hashicorp/consul/agent/cache"
"github.com/hashicorp/consul/agent/config" "github.com/hashicorp/consul/agent/config"
"github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/agent/structs"
) )
const ( const (
// TODO (v2-dns): can we move the recursion into the data fetcher? // Increment a counter when requests staler than this are served
maxRecursionLevelDefault = 3 // This field comes from the V1 DNS server and affects V1 catalog lookups staleCounterThreshold = 5 * time.Second
maxRecurseRecords = 5
) )
// v1DataFetcherDynamicConfig is used to store the dynamic configuration of the V1 data fetcher. // v1DataFetcherDynamicConfig is used to store the dynamic configuration of the V1 data fetcher.
type v1DataFetcherDynamicConfig struct { type v1DataFetcherDynamicConfig struct {
// Default request tenancy // Default request tenancy
defaultEntMeta acl.EnterpriseMeta
datacenter string datacenter string
// Catalog configuration // Catalog configuration
@ -34,25 +40,39 @@ type v1DataFetcherDynamicConfig struct {
useCache bool useCache bool
cacheMaxAge time.Duration cacheMaxAge time.Duration
onlyPassing bool onlyPassing bool
enterpriseDNSConfig EnterpriseDNSConfig
} }
// V1DataFetcher is used to fetch data from the V1 catalog. // V1DataFetcher is used to fetch data from the V1 catalog.
type V1DataFetcher struct { type V1DataFetcher struct {
// TODO(v2-dns): store this in the config.
defaultEnterpriseMeta acl.EnterpriseMeta defaultEnterpriseMeta acl.EnterpriseMeta
dynamicConfig atomic.Value dynamicConfig atomic.Value
logger hclog.Logger logger hclog.Logger
getFromCacheFunc func(ctx context.Context, t string, r cache.Request) (interface{}, cache.ResultMeta, error)
rpcFunc func(ctx context.Context, method string, args interface{}, reply interface{}) error rpcFunc func(ctx context.Context, method string, args interface{}, reply interface{}) error
rpcFuncForServiceNodes func(ctx context.Context, req structs.ServiceSpecificRequest) (structs.IndexedCheckServiceNodes, cache.ResultMeta, error)
rpcFuncForSamenessGroup func(ctx context.Context, req *structs.ConfigEntryQuery) (structs.SamenessGroupConfigEntry, cache.ResultMeta, error)
translateServicePortFunc func(dc string, port int, taggedAddresses map[string]structs.ServiceAddress) int
} }
// NewV1DataFetcher creates a new V1 data fetcher. // NewV1DataFetcher creates a new V1 data fetcher.
func NewV1DataFetcher(config *config.RuntimeConfig, func NewV1DataFetcher(config *config.RuntimeConfig,
entMeta *acl.EnterpriseMeta, entMeta *acl.EnterpriseMeta,
getFromCacheFunc func(ctx context.Context, t string, r cache.Request) (interface{}, cache.ResultMeta, error),
rpcFunc func(ctx context.Context, method string, args interface{}, reply interface{}) error, rpcFunc func(ctx context.Context, method string, args interface{}, reply interface{}) error,
rpcFuncForServiceNodes func(ctx context.Context, req structs.ServiceSpecificRequest) (structs.IndexedCheckServiceNodes, cache.ResultMeta, error),
rpcFuncForSamenessGroup func(ctx context.Context, req *structs.ConfigEntryQuery) (structs.SamenessGroupConfigEntry, cache.ResultMeta, error),
translateServicePortFunc func(dc string, port int, taggedAddresses map[string]structs.ServiceAddress) int,
logger hclog.Logger) *V1DataFetcher { logger hclog.Logger) *V1DataFetcher {
f := &V1DataFetcher{ f := &V1DataFetcher{
defaultEnterpriseMeta: *entMeta, defaultEnterpriseMeta: *entMeta,
getFromCacheFunc: getFromCacheFunc,
rpcFunc: rpcFunc, rpcFunc: rpcFunc,
rpcFuncForServiceNodes: rpcFuncForServiceNodes,
rpcFuncForSamenessGroup: rpcFuncForSamenessGroup,
translateServicePortFunc: translateServicePortFunc,
logger: logger, logger: logger,
} }
f.LoadConfig(config) f.LoadConfig(config)
@ -62,26 +82,65 @@ func NewV1DataFetcher(config *config.RuntimeConfig,
// LoadConfig loads the configuration for the V1 data fetcher. // LoadConfig loads the configuration for the V1 data fetcher.
func (f *V1DataFetcher) LoadConfig(config *config.RuntimeConfig) { func (f *V1DataFetcher) LoadConfig(config *config.RuntimeConfig) {
dynamicConfig := &v1DataFetcherDynamicConfig{ dynamicConfig := &v1DataFetcherDynamicConfig{
datacenter: config.Datacenter,
allowStale: config.DNSAllowStale, allowStale: config.DNSAllowStale,
maxStale: config.DNSMaxStale, maxStale: config.DNSMaxStale,
useCache: config.DNSUseCache, useCache: config.DNSUseCache,
cacheMaxAge: config.DNSCacheMaxAge, cacheMaxAge: config.DNSCacheMaxAge,
onlyPassing: config.DNSOnlyPassing, onlyPassing: config.DNSOnlyPassing,
enterpriseDNSConfig: GetEnterpriseDNSConfig(config),
datacenter: config.Datacenter,
// TODO (v2-dns): make this work
//defaultEntMeta: config.EnterpriseRuntimeConfig.DefaultEntMeta,
} }
f.dynamicConfig.Store(dynamicConfig) f.dynamicConfig.Store(dynamicConfig)
} }
// TODO (v2-dns): Implementation of the V1 data fetcher
// FetchNodes fetches A/AAAA/CNAME // FetchNodes fetches A/AAAA/CNAME
func (f *V1DataFetcher) FetchNodes(ctx Context, req *QueryPayload) ([]*Result, error) { func (f *V1DataFetcher) FetchNodes(ctx Context, req *QueryPayload) ([]*Result, error) {
return nil, nil cfg := f.dynamicConfig.Load().(*v1DataFetcherDynamicConfig)
// Make an RPC request
args := &structs.NodeSpecificRequest{
Datacenter: req.Tenancy.Datacenter,
PeerName: req.Tenancy.Peer,
Node: req.Name,
QueryOptions: structs.QueryOptions{
Token: ctx.Token,
AllowStale: cfg.allowStale,
},
EnterpriseMeta: req.Tenancy.EnterpriseMeta,
}
out, err := f.fetchNode(cfg, args)
if err != nil {
return nil, fmt.Errorf("failed rpc request: %w", err)
}
// If we have no out.NodeServices.Nodeaddress, return not found!
if out.NodeServices == nil {
return nil, errors.New("no nodes found")
}
results := make([]*Result, 0, 1)
node := out.NodeServices.Node
results = append(results, &Result{
Address: node.Address,
Type: ResultTypeNode,
Metadata: node.Meta,
Target: node.Node,
})
return results, nil
} }
// FetchEndpoints fetches records for A/AAAA/CNAME or SRV requests for services // FetchEndpoints fetches records for A/AAAA/CNAME or SRV requests for services
func (f *V1DataFetcher) FetchEndpoints(ctx Context, req *QueryPayload, lookupType LookupType) ([]*Result, error) { func (f *V1DataFetcher) FetchEndpoints(ctx Context, req *QueryPayload, lookupType LookupType) ([]*Result, error) {
return nil, nil f.logger.Debug(fmt.Sprintf("FetchEndpoints - req: %+v / lookupType: %+v", req, lookupType))
cfg := f.dynamicConfig.Load().(*v1DataFetcherDynamicConfig)
if lookupType == LookupTypeService {
return f.fetchService(ctx, req, cfg)
}
return nil, errors.New(fmt.Sprintf("unsupported lookup type: %s", lookupType))
} }
// FetchVirtualIP fetches A/AAAA records for virtual IPs // FetchVirtualIP fetches A/AAAA records for virtual IPs
@ -193,3 +252,182 @@ func (f *V1DataFetcher) FetchWorkload(ctx Context, req *QueryPayload) (*Result,
func (f *V1DataFetcher) FetchPreparedQuery(ctx Context, req *QueryPayload) ([]*Result, error) { func (f *V1DataFetcher) FetchPreparedQuery(ctx Context, req *QueryPayload) ([]*Result, error) {
return nil, nil return nil, nil
} }
// fetchNode is used to look up a node in the Consul catalog within NodeServices.
// If the config is set to UseCache, it will get the record from the agent cache.
func (f *V1DataFetcher) fetchNode(cfg *v1DataFetcherDynamicConfig, args *structs.NodeSpecificRequest) (*structs.IndexedNodeServices, error) {
var out structs.IndexedNodeServices
useCache := cfg.useCache
RPC:
if useCache {
raw, _, err := f.getFromCacheFunc(context.TODO(), cachetype.NodeServicesName, args)
if err != nil {
return nil, err
}
reply, ok := raw.(*structs.IndexedNodeServices)
if !ok {
// This should never happen, but we want to protect against panics
return nil, fmt.Errorf("internal error: response type not correct")
}
out = *reply
} else {
if err := f.rpcFunc(context.Background(), "Catalog.NodeServices", &args, &out); err != nil {
return nil, err
}
}
// Verify that request is not too stale, redo the request
if args.AllowStale {
if out.LastContact > cfg.maxStale {
args.AllowStale = false
useCache = false
f.logger.Warn("Query results too stale, re-requesting")
goto RPC
} else if out.LastContact > staleCounterThreshold {
metrics.IncrCounter([]string{"dns", "stale_queries"}, 1)
}
}
return &out, nil
}
func (f *V1DataFetcher) fetchService(ctx Context, req *QueryPayload, cfg *v1DataFetcherDynamicConfig) ([]*Result, error) {
f.logger.Debug("fetchService", "req", req)
if req.Tenancy.SamenessGroup == "" {
return f.fetchServiceBasedOnTenancy(ctx, req, cfg)
}
return f.fetchServiceFromSamenessGroup(ctx, req, cfg)
}
// fetchServiceBasedOnTenancy is used to look up a service in the Consul catalog based on its tenancy or default tenancy.
func (f *V1DataFetcher) fetchServiceBasedOnTenancy(ctx Context, req *QueryPayload, cfg *v1DataFetcherDynamicConfig) ([]*Result, error) {
f.logger.Debug(fmt.Sprintf("fetchServiceBasedOnTenancy - req: %+v", req))
if req.Tenancy.SamenessGroup != "" {
return nil, errors.New("sameness groups are not allowed for service lookups based on tenancy")
}
datacenter := req.Tenancy.Datacenter
if req.Tenancy.Peer != "" {
datacenter = ""
}
serviceTags := []string{}
if req.Tag != "" {
serviceTags = []string{req.Tag}
}
args := structs.ServiceSpecificRequest{
PeerName: req.Tenancy.Peer,
Connect: false,
Ingress: false,
Datacenter: datacenter,
ServiceName: req.Name,
ServiceTags: serviceTags,
TagFilter: req.Tag != "",
QueryOptions: structs.QueryOptions{
Token: ctx.Token,
AllowStale: cfg.allowStale,
MaxAge: cfg.cacheMaxAge,
UseCache: cfg.useCache,
MaxStaleDuration: cfg.maxStale,
},
EnterpriseMeta: req.Tenancy.EnterpriseMeta,
}
out, _, err := f.rpcFuncForServiceNodes(context.TODO(), args)
if err != nil {
return nil, err
}
// Filter out any service nodes due to health checks
// We copy the slice to avoid modifying the result if it comes from the cache
nodes := make(structs.CheckServiceNodes, len(out.Nodes))
copy(nodes, out.Nodes)
out.Nodes = nodes.Filter(cfg.onlyPassing)
if err != nil {
return nil, fmt.Errorf("rpc request failed: %w", err)
}
// If we have no nodes, return not found!
if len(out.Nodes) == 0 {
return nil, ErrNoData
}
// Perform a random shuffle
out.Nodes.Shuffle()
results := make([]*Result, 0, len(out.Nodes))
for _, node := range out.Nodes {
target := node.Service.Address
resultType := ResultTypeService
// TODO (v2-dns): IMPORTANT!!!!: this needs to be revisited in how dns v1 utilizes
// the nodeaddress when the service address is an empty string. Need to figure out
// if this can be removed and dns recursion and process can work with only the
// address set to the node.address and the target set to the service.address.
// We may have to look at modifying the discovery result if more metadata is needed to send along.
if target == "" {
target = node.Node.Node
resultType = ResultTypeNode
}
results = append(results, &Result{
Address: node.Node.Address,
Type: resultType,
Target: target,
Weight: uint32(findWeight(node)),
Port: uint32(f.translateServicePortFunc(node.Node.Datacenter, node.Service.Port, node.Service.TaggedAddresses)),
Metadata: node.Node.Meta,
Tenancy: ResultTenancy{
EnterpriseMeta: cfg.defaultEntMeta,
Datacenter: cfg.datacenter,
},
})
}
return results, nil
}
// findWeight returns the weight of a service node.
func findWeight(node structs.CheckServiceNode) int {
// By default, when only_passing is false, warning and passing nodes are returned
// Those values will be used if using a client with support while server has no
// support for weights
weightPassing := 1
weightWarning := 1
if node.Service.Weights != nil {
weightPassing = node.Service.Weights.Passing
weightWarning = node.Service.Weights.Warning
}
serviceChecks := make(api.HealthChecks, 0, len(node.Checks))
for _, c := range node.Checks {
if c.ServiceName == node.Service.Service || c.ServiceName == "" {
healthCheck := &api.HealthCheck{
Node: c.Node,
CheckID: string(c.CheckID),
Name: c.Name,
Status: c.Status,
Notes: c.Notes,
Output: c.Output,
ServiceID: c.ServiceID,
ServiceName: c.ServiceName,
ServiceTags: c.ServiceTags,
}
serviceChecks = append(serviceChecks, healthCheck)
}
}
status := serviceChecks.AggregatedStatus()
switch status {
case api.HealthWarning:
return weightWarning
case api.HealthPassing:
return weightPassing
case api.HealthMaint:
// Not used in theory
return 0
case api.HealthCritical:
// Should not happen since already filtered
return 0
default:
// When non-standard status, return 1
return 1
}
}

View File

@ -0,0 +1,20 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
//go:build !consulent
package discovery
import (
"errors"
"fmt"
)
// fetchServiceFromSamenessGroup fetches a service from a sameness group.
func (f *V1DataFetcher) fetchServiceFromSamenessGroup(ctx Context, req *QueryPayload, cfg *v1DataFetcherDynamicConfig) ([]*Result, error) {
f.logger.Debug(fmt.Sprintf("fetchServiceFromSamenessGroup - req: %+v", req))
if req.Tenancy.SamenessGroup == "" {
return nil, errors.New("sameness groups must be provided for service lookups")
}
return f.fetchServiceBasedOnTenancy(ctx, req, cfg)
}

View File

@ -4,10 +4,13 @@
package discovery package discovery
import ( import (
"context"
"errors" "errors"
"testing" "testing"
"time" "time"
"github.com/hashicorp/consul/agent/cache"
"github.com/stretchr/testify/mock" "github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -96,7 +99,19 @@ func Test_FetchVirtualIP(t *testing.T) {
*reply = tc.expectedResult.Address *reply = tc.expectedResult.Address
} }
}) })
df := NewV1DataFetcher(rc, acl.DefaultEnterpriseMeta(), mockRPC.RPC, logger) // TODO (v2-dns): mock these properly
translateServicePortFunc := func(dc string, port int, taggedAddresses map[string]structs.ServiceAddress) int { return 0 }
rpcFuncForServiceNodes := func(ctx context.Context, req structs.ServiceSpecificRequest) (structs.IndexedCheckServiceNodes, cache.ResultMeta, error) {
return structs.IndexedCheckServiceNodes{}, cache.ResultMeta{}, nil
}
rpcFuncForSamenessGroup := func(ctx context.Context, req *structs.ConfigEntryQuery) (structs.SamenessGroupConfigEntry, cache.ResultMeta, error) {
return structs.SamenessGroupConfigEntry{}, cache.ResultMeta{}, nil
}
getFromCacheFunc := func(ctx context.Context, t string, r cache.Request) (interface{}, cache.ResultMeta, error) {
return nil, cache.ResultMeta{}, nil
}
df := NewV1DataFetcher(rc, acl.DefaultEnterpriseMeta(), getFromCacheFunc, mockRPC.RPC, rpcFuncForServiceNodes, rpcFuncForSamenessGroup, translateServicePortFunc, logger)
result, err := df.FetchVirtualIP(tc.context, tc.queryPayload) result, err := df.FetchVirtualIP(tc.context, tc.queryPayload)
require.Equal(t, tc.expectedErr, err) require.Equal(t, tc.expectedErr, err)

View File

@ -0,0 +1,61 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package discovery
import "github.com/hashicorp/consul/acl"
// QueryLocality is the locality parsed from a DNS query.
type QueryLocality struct {
// Datacenter is the datacenter parsed from a label that has an explicit datacenter part.
// Example query: <service>.virtual.<namespace>.ns.<partition>.ap.<datacenter>.dc.consul
Datacenter string
// Peer is the peer name parsed from a label that has explicit parts.
// Example query: <service>.virtual.<namespace>.ns.<peer>.peer.<partition>.ap.consul
Peer string
// PeerOrDatacenter is parsed from DNS queries where the datacenter and peer name are
// specified in the same query part.
// Example query: <service>.virtual.<peerOrDatacenter>.consul
//
// Note that this field should only be a "peer" for virtual queries, since virtual IPs should
// not be shared between datacenters. In all other cases, it should be considered a DC.
PeerOrDatacenter string
acl.EnterpriseMeta
}
// EffectiveDatacenter returns the datacenter parsed from a query, or a default
// value if none is specified.
func (l QueryLocality) EffectiveDatacenter(defaultDC string) string {
// Prefer the value parsed from a query with explicit parts: <namespace>.ns.<partition>.ap.<datacenter>.dc
if l.Datacenter != "" {
return l.Datacenter
}
// Fall back to the ambiguously parsed DC or Peer.
if l.PeerOrDatacenter != "" {
return l.PeerOrDatacenter
}
// If all are empty, use a default value.
return defaultDC
}
// GetQueryTenancyBasedOnLocality returns a discovery.QueryTenancy from a DNS message.
func GetQueryTenancyBasedOnLocality(locality QueryLocality, defaultDatacenter string) (QueryTenancy, error) {
datacenter := locality.EffectiveDatacenter(defaultDatacenter)
// Only one of dc or peer can be used.
if locality.Peer != "" {
datacenter = ""
}
return QueryTenancy{
EnterpriseMeta: locality.EnterpriseMeta,
// The datacenter of the request is not specified because cross-datacenter virtual IP
// queries are not supported. This guard rail is in place because virtual IPs are allocated
// within a DC, therefore their uniqueness is not guaranteed globally.
Peer: locality.Peer,
Datacenter: datacenter,
SamenessGroup: "", // this should be nil since the single locality was directly used to configure tenancy.
}, nil
}

View File

@ -0,0 +1,57 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
//go:build !consulent
package discovery
import (
"github.com/hashicorp/consul/acl"
"github.com/hashicorp/consul/agent/config"
)
// ParseLocality can parse peer name or datacenter from a DNS query's labels.
// Peer name is parsed from the same query part that datacenter is, so given this ambiguity
// we parse a "peerOrDatacenter". The caller or RPC handler are responsible for disambiguating.
func ParseLocality(labels []string, defaultEnterpriseMeta acl.EnterpriseMeta, _ EnterpriseDNSConfig) (QueryLocality, bool) {
locality := QueryLocality{
EnterpriseMeta: defaultEnterpriseMeta,
}
switch len(labels) {
case 2, 4:
// Support the following formats:
// - [.<datacenter>.dc]
// - [.<peer>.peer]
for i := 0; i < len(labels); i += 2 {
switch labels[i+1] {
case "dc":
locality.Datacenter = labels[i]
case "peer":
locality.Peer = labels[i]
default:
return QueryLocality{}, false
}
}
// Return error when both datacenter and peer are specified.
if locality.Datacenter != "" && locality.Peer != "" {
return QueryLocality{}, false
}
return locality, true
case 1:
return QueryLocality{PeerOrDatacenter: labels[0]}, true
case 0:
return QueryLocality{}, true
}
return QueryLocality{}, false
}
// EnterpriseDNSConfig is the configuration for enterprise DNS.
type EnterpriseDNSConfig struct{}
// GetEnterpriseDNSConfig returns the enterprise DNS configuration.
func GetEnterpriseDNSConfig(conf *config.RuntimeConfig) EnterpriseDNSConfig {
return EnterpriseDNSConfig{}
}

View File

@ -0,0 +1,60 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
//go:build !consulent
package discovery
import (
"github.com/hashicorp/consul/acl"
)
func getTestCases() []testCaseParseLocality {
testCases := []testCaseParseLocality{
{
name: "test [.<datacenter>.dc]",
labels: []string{"test-dc", "dc"},
enterpriseDNSConfig: EnterpriseDNSConfig{},
expectedResult: QueryLocality{
EnterpriseMeta: acl.EnterpriseMeta{},
Datacenter: "test-dc",
},
expectedOK: true,
},
{
name: "test [.<peer>.peer]",
labels: []string{"test-peer", "peer"},
enterpriseDNSConfig: EnterpriseDNSConfig{},
expectedResult: QueryLocality{
EnterpriseMeta: acl.EnterpriseMeta{},
Peer: "test-peer",
},
expectedOK: true,
},
{
name: "test 1 label",
labels: []string{"test-peer"},
enterpriseDNSConfig: EnterpriseDNSConfig{},
expectedResult: QueryLocality{
EnterpriseMeta: acl.EnterpriseMeta{},
PeerOrDatacenter: "test-peer",
},
expectedOK: true,
},
{
name: "test 0 labels",
labels: []string{},
enterpriseDNSConfig: EnterpriseDNSConfig{},
expectedResult: QueryLocality{},
expectedOK: true,
},
{
name: "test 3 labels returns not found",
labels: []string{"test-dc", "dc", "test-blah"},
enterpriseDNSConfig: EnterpriseDNSConfig{},
expectedResult: QueryLocality{},
expectedOK: false,
},
}
return testCases
}

View File

@ -0,0 +1,73 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package discovery
import (
"testing"
"github.com/hashicorp/consul/acl"
"github.com/stretchr/testify/require"
)
type testCaseParseLocality struct {
name string
labels []string
defaultMeta acl.EnterpriseMeta
enterpriseDNSConfig EnterpriseDNSConfig
expectedResult QueryLocality
expectedOK bool
}
func Test_parseLocality(t *testing.T) {
testCases := getTestCases()
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
actualResult, actualOK := ParseLocality(tc.labels, tc.defaultMeta, tc.enterpriseDNSConfig)
require.Equal(t, tc.expectedOK, actualOK)
require.Equal(t, tc.expectedResult, actualResult)
})
}
}
func Test_effectiveDatacenter(t *testing.T) {
type testCase struct {
name string
QueryLocality QueryLocality
defaultDC string
expected string
}
testCases := []testCase{
{
name: "return Datacenter first",
QueryLocality: QueryLocality{
Datacenter: "test-dc",
PeerOrDatacenter: "test-peer",
},
defaultDC: "default-dc",
expected: "test-dc",
},
{
name: "return PeerOrDatacenter second",
QueryLocality: QueryLocality{
PeerOrDatacenter: "test-peer",
},
defaultDC: "default-dc",
expected: "test-peer",
},
{
name: "return defaultDC as fallback",
QueryLocality: QueryLocality{},
defaultDC: "default-dc",
expected: "default-dc",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
got := tc.QueryLocality.EffectiveDatacenter(tc.defaultDC)
require.Equal(t, tc.expected, got)
})
}
}

87
agent/dns/dns_address.go Normal file
View File

@ -0,0 +1,87 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package dns
import (
"github.com/miekg/dns"
"net"
"strings"
)
func newDNSAddress(addr string) *dnsAddress {
a := &dnsAddress{}
a.SetAddress(addr)
return a
}
// dnsAddress is a wrapper around a string that represents a DNS address and
// provides helper methods for determining whether it is an IP or FQDN and
// whether it is internal or external to the domain.
type dnsAddress struct {
addr string
// store an IP so helpers don't have to parse it multiple times
ip net.IP
}
// SetAddress sets the address field and the ip field if the string is an IP.
func (a *dnsAddress) SetAddress(addr string) {
a.addr = addr
a.ip = net.ParseIP(addr)
}
// IP returns the IP address if the address is an IP.
func (a *dnsAddress) IP() net.IP {
return a.ip
}
// IsIP returns true if the address is an IP.
func (a *dnsAddress) IsIP() bool {
return a.IP() != nil
}
// IsIPV4 returns true if the address is an IPv4 address.
func (a *dnsAddress) IsIPV4() bool {
if a.IP() == nil {
return false
}
return a.IP().To4() != nil
}
// FQDN returns the FQDN if the address is not an IP.
func (a *dnsAddress) FQDN() string {
if !a.IsEmptyString() && !a.IsIP() {
return dns.Fqdn(a.addr)
}
return ""
}
// IsFQDN returns true if the address is a FQDN and not an IP.
func (a *dnsAddress) IsFQDN() bool {
return !a.IsEmptyString() && !a.IsIP() && dns.IsFqdn(a.FQDN())
}
// String returns the address as a string.
func (a *dnsAddress) String() string {
return a.addr
}
// IsEmptyString returns true if the address is an empty string.
func (a *dnsAddress) IsEmptyString() bool {
return a.addr == ""
}
// IsInternalFQDN returns true if the address is a FQDN and is internal to the domain.
func (a *dnsAddress) IsInternalFQDN(domain string) bool {
return !a.IsIP() && a.IsFQDN() && strings.HasSuffix(a.FQDN(), domain)
}
// IsInternalFQDNOrIP returns true if the address is an IP or a FQDN and is internal to the domain.
func (a *dnsAddress) IsInternalFQDNOrIP(domain string) bool {
return a.IsIP() || a.IsInternalFQDN(domain)
}
// IsExternalFQDN returns true if the address is a FQDN and is external to the domain.
func (a *dnsAddress) IsExternalFQDN(domain string) bool {
return !a.IsIP() && a.IsFQDN() && !strings.HasSuffix(a.FQDN(), domain)
}

View File

@ -0,0 +1,154 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package dns
import (
"github.com/stretchr/testify/assert"
"testing"
)
func Test_dnsAddress(t *testing.T) {
const domain = "consul."
type expectedResults struct {
isIp bool
stringResult string
fqdn string
isFQDN bool
isEmptyString bool
isExternalFQDN bool
isInternalFQDN bool
isInternalFQDNOrIP bool
}
type testCase struct {
name string
input string
expectedResults expectedResults
}
testCases := []testCase{
{
name: "empty string",
input: "",
expectedResults: expectedResults{
isIp: false,
stringResult: "",
fqdn: "",
isFQDN: false,
isEmptyString: true,
isExternalFQDN: false,
isInternalFQDN: false,
isInternalFQDNOrIP: false,
},
},
{
name: "ipv4 address",
input: "127.0.0.1",
expectedResults: expectedResults{
isIp: true,
stringResult: "127.0.0.1",
fqdn: "",
isFQDN: false,
isEmptyString: false,
isExternalFQDN: false,
isInternalFQDN: false,
isInternalFQDNOrIP: true,
},
},
{
name: "ipv6 address",
input: "2001:db8:1:2:cafe::1337",
expectedResults: expectedResults{
isIp: true,
stringResult: "2001:db8:1:2:cafe::1337",
fqdn: "",
isFQDN: false,
isEmptyString: false,
isExternalFQDN: false,
isInternalFQDN: false,
isInternalFQDNOrIP: true,
},
},
{
name: "internal FQDN without trailing period",
input: "web.service.consul",
expectedResults: expectedResults{
isIp: false,
stringResult: "web.service.consul",
fqdn: "web.service.consul.",
isFQDN: true,
isEmptyString: false,
isExternalFQDN: false,
isInternalFQDN: true,
isInternalFQDNOrIP: true,
},
},
{
name: "internal FQDN with period",
input: "web.service.consul.",
expectedResults: expectedResults{
isIp: false,
stringResult: "web.service.consul.",
fqdn: "web.service.consul.",
isFQDN: true,
isEmptyString: false,
isExternalFQDN: false,
isInternalFQDN: true,
isInternalFQDNOrIP: true,
},
},
{
name: "external FQDN without trailing period",
input: "web.service.vault",
expectedResults: expectedResults{
isIp: false,
stringResult: "web.service.vault",
fqdn: "web.service.vault.",
isFQDN: true,
isEmptyString: false,
isExternalFQDN: true,
isInternalFQDN: false,
isInternalFQDNOrIP: false,
},
},
{
name: "external FQDN with trailing period",
input: "web.service.vault.",
expectedResults: expectedResults{
isIp: false,
stringResult: "web.service.vault.",
fqdn: "web.service.vault.",
isFQDN: true,
isEmptyString: false,
isExternalFQDN: true,
isInternalFQDN: false,
isInternalFQDNOrIP: false,
},
},
{
name: "another external FQDN",
input: "www.google.com",
expectedResults: expectedResults{
isIp: false,
stringResult: "www.google.com",
fqdn: "www.google.com.",
isFQDN: true,
isEmptyString: false,
isExternalFQDN: true,
isInternalFQDN: false,
isInternalFQDNOrIP: false,
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
dnsAddress := newDNSAddress(tc.input)
assert.Equal(t, tc.expectedResults.isIp, dnsAddress.IsIP())
assert.Equal(t, tc.expectedResults.stringResult, dnsAddress.String())
assert.Equal(t, tc.expectedResults.isFQDN, dnsAddress.IsFQDN())
assert.Equal(t, tc.expectedResults.isEmptyString, dnsAddress.IsEmptyString())
assert.Equal(t, tc.expectedResults.isExternalFQDN, dnsAddress.IsExternalFQDN(domain))
assert.Equal(t, tc.expectedResults.isInternalFQDN, dnsAddress.IsInternalFQDN(domain))
assert.Equal(t, tc.expectedResults.isInternalFQDNOrIP, dnsAddress.IsInternalFQDNOrIP(domain))
})
}
}

View File

@ -1,4 +1,4 @@
// Code generated by mockery v2.20.0. DO NOT EDIT. // Code generated by mockery v2.32.4. DO NOT EDIT.
package dns package dns
@ -40,13 +40,12 @@ func (_m *mockDnsRecursor) handle(req *miekgdns.Msg, cfgCtx *RouterDynamicConfig
return r0, r1 return r0, r1
} }
type mockConstructorTestingTnewMockDnsRecursor interface { // newMockDnsRecursor creates a new instance of mockDnsRecursor. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func newMockDnsRecursor(t interface {
mock.TestingT mock.TestingT
Cleanup(func()) Cleanup(func())
} }) *mockDnsRecursor {
// newMockDnsRecursor creates a new instance of mockDnsRecursor. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
func newMockDnsRecursor(t mockConstructorTestingTnewMockDnsRecursor) *mockDnsRecursor {
mock := &mockDnsRecursor{} mock := &mockDnsRecursor{}
mock.Mock.Test(t) mock.Mock.Test(t)

View File

@ -8,6 +8,8 @@ import (
"errors" "errors"
"fmt" "fmt"
"net" "net"
"regexp"
"strings"
"sync/atomic" "sync/atomic"
"time" "time"
@ -31,12 +33,16 @@ const (
suffixFailover = "failover." suffixFailover = "failover."
suffixNoFailover = "no-failover." suffixNoFailover = "no-failover."
maxRecursionLevelDefault = 3 // This field comes from the V1 DNS server and affects V1 catalog lookups
maxRecurseRecords = 5
) )
var ( var (
errInvalidQuestion = fmt.Errorf("invalid question") errInvalidQuestion = fmt.Errorf("invalid question")
errNameNotFound = fmt.Errorf("name not found") errNameNotFound = fmt.Errorf("name not found")
errRecursionFailed = fmt.Errorf("recursion failed") errRecursionFailed = fmt.Errorf("recursion failed")
trailingSpacesRE = regexp.MustCompile(" +$")
) )
// TODO (v2-dns): metrics // TODO (v2-dns): metrics
@ -59,7 +65,25 @@ type RouterDynamicConfig struct {
TTLStrict map[string]time.Duration TTLStrict map[string]time.Duration
UDPAnswerLimit int UDPAnswerLimit int
enterpriseDNSConfig discovery.EnterpriseDNSConfig
}
// GetTTLForService Find the TTL for a given service.
// return ttl, true if found, 0, false otherwise
func (cfg *RouterDynamicConfig) GetTTLForService(service string) (time.Duration, bool) {
if cfg.TTLStrict != nil {
ttl, ok := cfg.TTLStrict[service]
if ok {
return ttl, true
}
}
if cfg.TTLRadix != nil {
_, ttlRaw, ok := cfg.TTLRadix.LongestPrefix(service)
if ok {
return ttlRaw.(time.Duration), true
}
}
return 0, false
} }
type SOAConfig struct { type SOAConfig struct {
@ -135,6 +159,13 @@ func NewRouter(cfg Config) (*Router, error) {
// HandleRequest is used to process an individual DNS request. It returns a message in success or fail cases. // HandleRequest is used to process an individual DNS request. It returns a message in success or fail cases.
func (r *Router) HandleRequest(req *dns.Msg, reqCtx discovery.Context, remoteAddress net.Addr) *dns.Msg { func (r *Router) HandleRequest(req *dns.Msg, reqCtx discovery.Context, remoteAddress net.Addr) *dns.Msg {
return r.handleRequestRecursively(req, reqCtx, remoteAddress, maxRecursionLevelDefault)
}
// handleRequestRecursively is used to process an individual DNS request. It will recurse as needed
// a maximum number of times and returns a message in success or fail cases.
func (r *Router) handleRequestRecursively(req *dns.Msg, reqCtx discovery.Context,
remoteAddress net.Addr, maxRecursionLevel int) *dns.Msg {
configCtx := r.dynamicConfig.Load().(*RouterDynamicConfig) configCtx := r.dynamicConfig.Load().(*RouterDynamicConfig)
err := validateAndNormalizeRequest(req) err := validateAndNormalizeRequest(req)
@ -165,7 +196,7 @@ func (r *Router) HandleRequest(req *dns.Msg, reqCtx discovery.Context, remoteAdd
} }
reqType := parseRequestType(req) reqType := parseRequestType(req)
results, err := r.getQueryResults(req, reqCtx, reqType, configCtx) results, query, err := r.getQueryResults(req, reqCtx, reqType, configCtx)
switch { switch {
case errors.Is(err, errNameNotFound): case errors.Is(err, errNameNotFound):
r.logger.Error("name not found", "name", req.Question[0].Name) r.logger.Error("name not found", "name", req.Question[0].Name)
@ -185,7 +216,7 @@ func (r *Router) HandleRequest(req *dns.Msg, reqCtx discovery.Context, remoteAdd
// This needs the question information because it affects the serialization format. // This needs the question information because it affects the serialization format.
// e.g., the Consul service has the same "results" for both NS and A/AAAA queries, but the serialization differs. // e.g., the Consul service has the same "results" for both NS and A/AAAA queries, but the serialization differs.
resp, err := r.serializeQueryResults(req, results, configCtx, responseDomain) resp, err := r.serializeQueryResults(req, reqCtx, query, results, configCtx, responseDomain, remoteAddress, maxRecursionLevel)
if err != nil { if err != nil {
r.logger.Error("error serializing DNS results", "error", err) r.logger.Error("error serializing DNS results", "error", err)
return createServerFailureResponse(req, configCtx, false) return createServerFailureResponse(req, configCtx, false)
@ -193,8 +224,27 @@ func (r *Router) HandleRequest(req *dns.Msg, reqCtx discovery.Context, remoteAdd
return resp return resp
} }
// getTTLForResult returns the TTL for a given result.
func getTTLForResult(name string, query *discovery.Query, cfg *RouterDynamicConfig) uint32 {
switch {
// TODO (v2-dns): currently have to do this related to the results type being changed to node whe
// the v1 data fetcher encounters a blank service address and uses the node address instead.
// we will revisiting this when look at modifying the discovery result struct to
// possibly include additional metadata like the node address.
case query != nil && query.QueryType == discovery.QueryTypeService:
ttl, ok := cfg.GetTTLForService(name)
if ok {
return uint32(ttl / time.Second)
}
fallthrough
default:
return uint32(cfg.NodeTTL / time.Second)
}
}
// getQueryResults returns a discovery.Result from a DNS message. // getQueryResults returns a discovery.Result from a DNS message.
func (r *Router) getQueryResults(req *dns.Msg, reqCtx discovery.Context, reqType requestType, cfgCtx *RouterDynamicConfig) ([]*discovery.Result, error) { func (r *Router) getQueryResults(req *dns.Msg, reqCtx discovery.Context, reqType requestType, cfg *RouterDynamicConfig) ([]*discovery.Result, *discovery.Query, error) {
var query *discovery.Query
switch reqType { switch reqType {
case requestTypeConsul: case requestTypeConsul:
// This is a special case of discovery.QueryByName where we know that we need to query the consul service // This is a special case of discovery.QueryByName where we know that we need to query the consul service
@ -206,25 +256,38 @@ func (r *Router) getQueryResults(req *dns.Msg, reqCtx discovery.Context, reqType
}, },
Limit: 3, // TODO (v2-dns): need to thread this through to the backend and make sure we shuffle the results Limit: 3, // TODO (v2-dns): need to thread this through to the backend and make sure we shuffle the results
} }
return r.processor.QueryByName(query, reqCtx)
results, err := r.processor.QueryByName(query, reqCtx)
return results, query, err
case requestTypeName: case requestTypeName:
query, err := buildQueryFromDNSMessage(req, r.domain, r.altDomain, cfgCtx, r.defaultEntMeta) query, err := buildQueryFromDNSMessage(req, r.domain, r.altDomain, cfg, r.defaultEntMeta, r.datacenter)
if err != nil { if err != nil {
r.logger.Error("error building discovery query from DNS request", "error", err) r.logger.Error("error building discovery query from DNS request", "error", err)
return nil, err return nil, query, err
} }
return r.processor.QueryByName(query, reqCtx) results, err := r.processor.QueryByName(query, reqCtx)
if err != nil {
r.logger.Error("error processing discovery query", "error", err)
return nil, query, err
}
return results, query, nil
case requestTypeIP: case requestTypeIP:
ip := dnsutil.IPFromARPA(req.Question[0].Name) ip := dnsutil.IPFromARPA(req.Question[0].Name)
if ip == nil { if ip == nil {
r.logger.Error("error building IP from DNS request", "name", req.Question[0].Name) r.logger.Error("error building IP from DNS request", "name", req.Question[0].Name)
return nil, errNameNotFound return nil, nil, errNameNotFound
} }
return r.processor.QueryByIP(ip, reqCtx) results, err := r.processor.QueryByIP(ip, reqCtx)
return results, query, err
case requestTypeAddress: case requestTypeAddress:
return buildAddressResults(req) results, err := buildAddressResults(req)
if err != nil {
r.logger.Error("error processing discovery query", "error", err)
return nil, query, err
} }
return nil, errors.New("invalid request type") return results, query, nil
}
return nil, query, errors.New("invalid request type")
} }
// ServeDNS implements the miekg/dns.Handler interface. // ServeDNS implements the miekg/dns.Handler interface.
@ -304,23 +367,99 @@ func parseRequestType(req *dns.Msg) requestType {
} }
// serializeQueryResults converts a discovery.Result into a DNS message. // serializeQueryResults converts a discovery.Result into a DNS message.
func (r *Router) serializeQueryResults(req *dns.Msg, results []*discovery.Result, cfg *RouterDynamicConfig, responseDomain string) (*dns.Msg, error) { func (r *Router) serializeQueryResults(req *dns.Msg, reqCtx discovery.Context,
query *discovery.Query, results []*discovery.Result, cfg *RouterDynamicConfig,
responseDomain string, remoteAddress net.Addr, maxRecursionLevel int) (*dns.Msg, error) {
resp := new(dns.Msg) resp := new(dns.Msg)
resp.SetReply(req) resp.SetReply(req)
resp.Compress = !cfg.DisableCompression resp.Compress = !cfg.DisableCompression
resp.Authoritative = true resp.Authoritative = true
resp.RecursionAvailable = canRecurse(cfg) resp.RecursionAvailable = canRecurse(cfg)
qType := req.Question[0].Qtype
reqType := parseRequestType(req)
// Always add the SOA record if requested.
switch {
case qType == dns.TypeSOA:
resp.Answer = append(resp.Answer, makeSOARecord(responseDomain, cfg))
for _, result := range results {
ans, ex, ns := r.getAnswerExtraAndNs(result, req, reqCtx, query, cfg, responseDomain, remoteAddress, maxRecursionLevel)
resp.Answer = append(resp.Answer, ans...)
resp.Extra = append(resp.Extra, ex...)
resp.Ns = append(resp.Ns, ns...)
}
case qType == dns.TypeSRV, reqType == requestTypeAddress:
for _, result := range results {
ans, ex, ns := r.getAnswerExtraAndNs(result, req, reqCtx, query, cfg, responseDomain, remoteAddress, maxRecursionLevel)
resp.Answer = append(resp.Answer, ans...)
resp.Extra = append(resp.Extra, ex...)
resp.Ns = append(resp.Ns, ns...)
}
default:
// default will send it to where it does some de-duping while it calls getAnswerExtraAndNs and recurses.
r.appendResultsToDNSResponse(req, reqCtx, query, resp, results, cfg, responseDomain, remoteAddress, maxRecursionLevel)
}
return resp, nil
}
// appendResultsToDNSResponse builds dns message from the discovery results and
// appends them to the dns response.
func (r *Router) appendResultsToDNSResponse(req *dns.Msg, reqCtx discovery.Context,
query *discovery.Query, resp *dns.Msg, results []*discovery.Result, cfg *RouterDynamicConfig,
responseDomain string, remoteAddress net.Addr, maxRecursionLevel int) {
// Always add the SOA record if requested. // Always add the SOA record if requested.
if req.Question[0].Qtype == dns.TypeSOA { if req.Question[0].Qtype == dns.TypeSOA {
resp.Answer = append(resp.Answer, makeSOARecord(responseDomain, cfg)) resp.Answer = append(resp.Answer, makeSOARecord(responseDomain, cfg))
} }
handled := make(map[string]struct{})
var answerCNAME []dns.RR = nil
count := 0
for _, result := range results { for _, result := range results {
appendResultToDNSResponse(result, req, resp, responseDomain, cfg) // Add the node record
had_answer := false
ans, extra, _ := r.getAnswerExtraAndNs(result, req, reqCtx, query, cfg, responseDomain, remoteAddress, maxRecursionLevel)
resp.Extra = append(resp.Extra, extra...)
if len(ans) == 0 {
continue
} }
return resp, nil // Avoid duplicate entries, possible if a node has
// the same service on multiple ports, etc.
if _, ok := handled[ans[0].String()]; ok {
continue
}
handled[ans[0].String()] = struct{}{}
switch ans[0].(type) {
case *dns.CNAME:
// keep track of the first CNAME + associated RRs but don't add to the resp.Answer yet
// this will only be added if no non-CNAME RRs are found
if len(answerCNAME) == 0 {
answerCNAME = ans
}
default:
resp.Answer = append(resp.Answer, ans...)
had_answer = true
}
if had_answer {
count++
if count == cfg.ARecordLimit {
// We stop only if greater than 0 or we reached the limit
return
}
}
}
if len(resp.Answer) == 0 && len(answerCNAME) > 0 {
resp.Answer = answerCNAME
}
} }
// defaultAgentDNSRequestContext returns a default request context based on the agent's config. // defaultAgentDNSRequestContext returns a default request context based on the agent's config.
@ -332,6 +471,46 @@ func (r *Router) defaultAgentDNSRequestContext() discovery.Context {
} }
} }
// resolveCNAME is used to recursively resolve CNAME records
func (r *Router) resolveCNAME(cfg *RouterDynamicConfig, name string, reqCtx discovery.Context,
remoteAddress net.Addr, maxRecursionLevel int) []dns.RR {
// If the CNAME record points to a Consul address, resolve it internally
// Convert query to lowercase because DNS is case insensitive; d.domain and
// d.altDomain are already converted
if ln := strings.ToLower(name); strings.HasSuffix(ln, "."+r.domain) || strings.HasSuffix(ln, "."+r.altDomain) {
if maxRecursionLevel < 1 {
//d.logger.Error("Infinite recursion detected for name, won't perform any CNAME resolution.", "name", name)
return nil
}
req := &dns.Msg{}
req.SetQuestion(name, dns.TypeANY)
// TODO: handle error response
resp := r.handleRequestRecursively(req, reqCtx, nil, maxRecursionLevel-1)
return resp.Answer
}
// Do nothing if we don't have a recursor
if !canRecurse(cfg) {
return nil
}
// Ask for any A records
m := new(dns.Msg)
m.SetQuestion(name, dns.TypeA)
// Make a DNS lookup request
recursorResponse, err := r.recursor.handle(m, cfg, remoteAddress)
if err == nil {
return recursorResponse.Answer
}
r.logger.Error("all resolvers failed for name", "name", name)
return nil
}
// validateAndNormalizeRequest validates the DNS request and normalizes the request name. // validateAndNormalizeRequest validates the DNS request and normalizes the request name.
func validateAndNormalizeRequest(req *dns.Msg) error { func validateAndNormalizeRequest(req *dns.Msg) error {
// like upstream miekg/dns, we require at least one question, // like upstream miekg/dns, we require at least one question,
@ -406,10 +585,26 @@ func getDynamicRouterConfig(conf *config.RuntimeConfig) (*RouterDynamicConfig, e
Refresh: conf.DNSSOA.Refresh, Refresh: conf.DNSSOA.Refresh,
Retry: conf.DNSSOA.Retry, Retry: conf.DNSSOA.Retry,
}, },
enterpriseDNSConfig: getEnterpriseDNSConfig(conf), EnterpriseDNSConfig: discovery.GetEnterpriseDNSConfig(conf),
} }
// TODO (v2-dns): add service TTL recalculation if conf.DNSServiceTTL != nil {
cfg.TTLRadix = radix.New()
cfg.TTLStrict = make(map[string]time.Duration)
for key, ttl := range conf.DNSServiceTTL {
// All suffix with '*' are put in radix
// This include '*' that will match anything
if strings.HasSuffix(key, "*") {
cfg.TTLRadix.Insert(key[:len(key)-1], ttl)
} else {
cfg.TTLStrict[key] = ttl
}
}
} else {
cfg.TTLRadix = nil
cfg.TTLStrict = nil
}
for _, r := range conf.DNSRecursors { for _, r := range conf.DNSRecursors {
ra, err := formatRecursorAddress(r) ra, err := formatRecursorAddress(r)
@ -545,30 +740,18 @@ func buildAddressResults(req *dns.Msg) ([]*discovery.Result, error) {
}, nil }, nil
} }
// buildQueryFromDNSMessage appends the discovery result to the dns message. // getAnswerAndExtra creates the dns answer and extra from discovery results.
func appendResultToDNSResponse(result *discovery.Result, req *dns.Msg, resp *dns.Msg, domain string, cfg *RouterDynamicConfig) { func (r *Router) getAnswerExtraAndNs(result *discovery.Result, req *dns.Msg, reqCtx discovery.Context,
ip, ok := convertToIp(result) query *discovery.Query, cfg *RouterDynamicConfig, domain string, remoteAddress net.Addr, maxRecursionLevel int) (answer []dns.RR, extra []dns.RR, ns []dns.RR) {
address, target := getAddressAndTargetFromDiscoveryResult(result, r.domain)
// if the result is not an IP, we can try to recurse on the hostname. qName := req.Question[0].Name
// TODO (v2-dns): hostnames are valid for workloads in V2, do we just want to return the CNAME? ttlLookupName := qName
if !ok { if query != nil {
// TODO (v2-dns): recurse on HandleRequest() ttlLookupName = query.QueryPayload.Name
panic("not implemented")
} }
ttl := getTTLForResult(ttlLookupName, query, cfg)
var ttl uint32
switch result.Type {
case discovery.ResultTypeNode, discovery.ResultTypeVirtual, discovery.ResultTypeWorkload:
ttl = uint32(cfg.NodeTTL / time.Second)
case discovery.ResultTypeService:
// TODO (v2-dns): implement service TTL using the radix tree
}
qName := dns.CanonicalName(req.Question[0].Name)
qType := req.Question[0].Qtype qType := req.Question[0].Qtype
record, isIPV4 := makeRecord(qName, ip, ttl)
// TODO (v2-dns): skip records that refer to a workload/node that don't have a valid DNS name. // TODO (v2-dns): skip records that refer to a workload/node that don't have a valid DNS name.
// Special case responses // Special case responses
@ -579,54 +762,120 @@ func appendResultToDNSResponse(result *discovery.Result, req *dns.Msg, resp *dns
Hdr: dns.RR_Header{Name: qName, Rrtype: dns.TypePTR, Class: dns.ClassINET, Ttl: 0}, Hdr: dns.RR_Header{Name: qName, Rrtype: dns.TypePTR, Class: dns.ClassINET, Ttl: 0},
Ptr: canonicalNameForResult(result, domain), Ptr: canonicalNameForResult(result, domain),
} }
resp.Answer = append(resp.Answer, ptr) answer = append(answer, ptr)
return
case qType == dns.TypeNS: case qType == dns.TypeNS:
// TODO (v2-dns): fqdn in V1 has the datacenter included, this would need to be added to discovery.Result // TODO (v2-dns): fqdn in V1 has the datacenter included, this would need to be added to discovery.Result
fqdn := canonicalNameForResult(result, domain) fqdn := canonicalNameForResult(result, domain)
extraRecord, _ := makeRecord(fqdn, ip, ttl) // TODO (v2-dns): this is not sufficient, because recursion and CNAMES are supported extraRecord := makeIPBasedRecord(fqdn, address, ttl) // TODO (v2-dns): this is not sufficient, because recursion and CNAMES are supported
resp.Answer = append(resp.Ns, makeNSRecord(domain, fqdn, ttl))
resp.Extra = append(resp.Extra, extraRecord)
return
answer = append(answer, makeNSRecord(domain, fqdn, ttl))
extra = append(extra, extraRecord)
case qType == dns.TypeSOA: case qType == dns.TypeSOA:
// TODO (v2-dns): fqdn in V1 has the datacenter included, this would need to be added to discovery.Result // TODO (v2-dns): fqdn in V1 has the datacenter included, this would need to be added to discovery.Result
// to be returned in the result. // to be returned in the result.
fqdn := canonicalNameForResult(result, domain) fqdn := canonicalNameForResult(result, domain)
extraRecord, _ := makeRecord(fqdn, ip, ttl) // TODO (v2-dns): this is not sufficient, because recursion and CNAMES are supported extraRecord := makeIPBasedRecord(fqdn, address, ttl) // TODO (v2-dns): this is not sufficient, because recursion and CNAMES are supported
resp.Ns = append(resp.Ns, makeNSRecord(domain, fqdn, ttl)) ns = append(ns, makeNSRecord(domain, fqdn, ttl))
resp.Extra = append(resp.Extra, extraRecord) extra = append(extra, extraRecord)
return
case qType == dns.TypeSRV: case qType == dns.TypeSRV:
// We put A/AAAA/CNAME records in the additional section for SRV requests // We put A/AAAA/CNAME records in the additional section for SRV requests
resp.Extra = append(resp.Extra, record) a, e := r.getAnswerExtrasForAddressAndTarget(address, target, req, reqCtx,
result, ttl, remoteAddress, maxRecursionLevel)
answer = append(answer, a...)
extra = append(extra, e...)
// TODO (v2-dns): implement SRV records for the answer section cfg := r.dynamicConfig.Load().(*RouterDynamicConfig)
return if cfg.NodeMetaTXT {
extra = append(extra, makeTXTRecord(target.FQDN(), result, ttl)...)
} }
default:
// For explicit A/AAAA queries, we must only return those records in the answer section. a, e := r.getAnswerExtrasForAddressAndTarget(address, target, req, reqCtx,
if isIPV4 && qType != dns.TypeA && qType != dns.TypeANY { result, ttl, remoteAddress, maxRecursionLevel)
resp.Extra = append(resp.Extra, record) answer = append(answer, a...)
return extra = append(extra, e...)
} }
if !isIPV4 && qType != dns.TypeAAAA && qType != dns.TypeANY {
resp.Extra = append(resp.Extra, record)
return return
}
resp.Answer = append(resp.Answer, record)
} }
// convertToIp converts a discovery.Result to a net.IP. // getAnswerExtrasForAddressAndTarget creates the dns answer and extra from address and target dnsAddress pairs.
func convertToIp(result *discovery.Result) (net.IP, bool) { func (r *Router) getAnswerExtrasForAddressAndTarget(address *dnsAddress, target *dnsAddress, req *dns.Msg,
ip := net.ParseIP(result.Address) reqCtx discovery.Context, result *discovery.Result, ttl uint32, remoteAddress net.Addr,
if ip == nil { maxRecursionLevel int) (answer []dns.RR, extra []dns.RR) {
return nil, false qName := req.Question[0].Name
reqType := parseRequestType(req)
cfg := r.dynamicConfig.Load().(*RouterDynamicConfig)
switch {
// There is no target and the address is a FQDN (external service)
case address.IsFQDN():
a, e := r.makeRecordFromFQDN(address.FQDN(), result, req, reqCtx,
cfg, ttl, remoteAddress, maxRecursionLevel)
answer = append(a, answer...)
extra = append(e, extra...)
// The target is a FQDN (internal or external service name)
case result.Type != discovery.ResultTypeNode && target.IsFQDN():
a, e := r.makeRecordFromFQDN(target.FQDN(), result, req, reqCtx,
cfg, ttl, remoteAddress, maxRecursionLevel)
answer = append(answer, a...)
extra = append(extra, e...)
// There is no target and the address is an IP
case address.IsIP():
// TODO (v2-dns): Do not CNAME node address in case of WAN address.
ipRecordName := target.FQDN()
if maxRecursionLevel < maxRecursionLevelDefault || ipRecordName == "" {
ipRecordName = qName
} }
return ip, true a, e := getAnswerExtrasForIP(ipRecordName, address, req.Question[0], reqType, result, ttl)
answer = append(answer, a...)
extra = append(extra, e...)
// The target is an IP
case target.IsIP():
a, e := getAnswerExtrasForIP(qName, target, req.Question[0], reqType, result, ttl)
answer = append(answer, a...)
extra = append(extra, e...)
// The target is a CNAME for the service we are looking
// for. So we use the address.
case target.FQDN() == req.Question[0].Name && address.IsIP():
a, e := getAnswerExtrasForIP(qName, address, req.Question[0], reqType, result, ttl)
answer = append(answer, a...)
extra = append(extra, e...)
// The target is a FQDN (internal or external service name)
default:
a, e := r.makeRecordFromFQDN(target.FQDN(), result, req, reqCtx, cfg, ttl, remoteAddress, maxRecursionLevel)
answer = append(a, answer...)
extra = append(e, extra...)
}
return
}
// getAddressAndTargetFromDiscoveryResult returns the address and target from a discovery result.
func getAnswerExtrasForIP(name string, addr *dnsAddress, question dns.Question, reqType requestType, result *discovery.Result, ttl uint32) (answer []dns.RR, extra []dns.RR) {
record := makeIPBasedRecord(name, addr, ttl)
qType := question.Qtype
isARecordWhenNotExplicitlyQueried := record.Header().Rrtype == dns.TypeA && qType != dns.TypeA && qType != dns.TypeANY
isAAAARecordWhenNotExplicitlyQueried := record.Header().Rrtype == dns.TypeAAAA && qType != dns.TypeAAAA && qType != dns.TypeANY
// For explicit A/AAAA queries, we must only return those records in the answer section.
if isARecordWhenNotExplicitlyQueried ||
isAAAARecordWhenNotExplicitlyQueried {
extra = append(extra, record)
} else {
answer = append(answer, record)
}
if reqType != requestTypeAddress && qType == dns.TypeSRV {
srv := makeSRVRecord(name, name, result, ttl)
answer = append(answer, srv)
}
return
} }
func makeSOARecord(domain string, cfg *RouterDynamicConfig) dns.RR { func makeSOARecord(domain string, cfg *RouterDynamicConfig) dns.RR {
@ -660,13 +909,12 @@ func makeNSRecord(domain, fqdn string, ttl uint32) dns.RR {
} }
} }
// makeRecord an A or AAAA record for the given name and IP. // makeIPBasedRecord an A or AAAA record for the given name and IP.
// Note: we might want to pass in the Query Name here, which is used in addr. and virtual. queries // Note: we might want to pass in the Query Name here, which is used in addr. and virtual. queries
// since there is only ever one result. Right now choosing to leave it off for simplification. // since there is only ever one result. Right now choosing to leave it off for simplification.
func makeRecord(name string, ip net.IP, ttl uint32) (dns.RR, bool) { func makeIPBasedRecord(name string, addr *dnsAddress, ttl uint32) dns.RR {
isIPV4 := ip.To4() != nil
if isIPV4 { if addr.IsIPV4() {
// check if the query type is A for IPv4 or ANY // check if the query type is A for IPv4 or ANY
return &dns.A{ return &dns.A{
Hdr: dns.RR_Header{ Hdr: dns.RR_Header{
@ -675,8 +923,8 @@ func makeRecord(name string, ip net.IP, ttl uint32) (dns.RR, bool) {
Class: dns.ClassINET, Class: dns.ClassINET,
Ttl: ttl, Ttl: ttl,
}, },
A: ip, A: addr.IP(),
}, true }
} }
return &dns.AAAA{ return &dns.AAAA{
@ -686,6 +934,126 @@ func makeRecord(name string, ip net.IP, ttl uint32) (dns.RR, bool) {
Class: dns.ClassINET, Class: dns.ClassINET,
Ttl: ttl, Ttl: ttl,
}, },
AAAA: ip, AAAA: addr.IP(),
}, false }
}
func (r *Router) makeRecordFromFQDN(fqdn string, result *discovery.Result,
req *dns.Msg, reqCtx discovery.Context, cfg *RouterDynamicConfig, ttl uint32,
remoteAddress net.Addr, maxRecursionLevel int) ([]dns.RR, []dns.RR) {
edns := req.IsEdns0() != nil
q := req.Question[0]
more := r.resolveCNAME(cfg, dns.Fqdn(fqdn), reqCtx, remoteAddress, maxRecursionLevel)
var additional []dns.RR
extra := 0
MORE_REC:
for _, rr := range more {
switch rr.Header().Rrtype {
case dns.TypeCNAME, dns.TypeA, dns.TypeAAAA:
// set the TTL manually
rr.Header().Ttl = ttl
additional = append(additional, rr)
extra++
if extra == maxRecurseRecords && !edns {
break MORE_REC
}
}
}
if q.Qtype == dns.TypeSRV {
answers := []dns.RR{
makeSRVRecord(q.Name, fqdn, result, ttl),
}
return answers, additional
}
answers := []dns.RR{
makeCNAMERecord(result, q.Name, ttl),
}
answers = append(answers, additional...)
return answers, nil
}
// makeCNAMERecord returns a CNAME record for the given name and target.
func makeCNAMERecord(result *discovery.Result, qName string, ttl uint32) *dns.CNAME {
return &dns.CNAME{
Hdr: dns.RR_Header{
Name: qName,
Rrtype: dns.TypeCNAME,
Class: dns.ClassINET,
Ttl: ttl,
},
Target: dns.Fqdn(result.Target),
}
}
// func makeSRVRecord returns an SRV record for the given name and target.
func makeSRVRecord(name, target string, result *discovery.Result, ttl uint32) *dns.SRV {
return &dns.SRV{
Hdr: dns.RR_Header{
Name: name,
Rrtype: dns.TypeSRV,
Class: dns.ClassINET,
Ttl: ttl,
},
Priority: 1,
Weight: uint16(result.Weight),
Port: uint16(result.Port),
Target: target,
}
}
// encodeKVasRFC1464 encodes a key-value pair according to RFC1464
func encodeKVasRFC1464(key, value string) (txt string) {
// For details on these replacements c.f. https://www.ietf.org/rfc/rfc1464.txt
key = strings.Replace(key, "`", "``", -1)
key = strings.Replace(key, "=", "`=", -1)
// Backquote the leading spaces
leadingSpacesRE := regexp.MustCompile("^ +")
numLeadingSpaces := len(leadingSpacesRE.FindString(key))
key = leadingSpacesRE.ReplaceAllString(key, strings.Repeat("` ", numLeadingSpaces))
// Backquote the trailing spaces
numTrailingSpaces := len(trailingSpacesRE.FindString(key))
key = trailingSpacesRE.ReplaceAllString(key, strings.Repeat("` ", numTrailingSpaces))
value = strings.Replace(value, "`", "``", -1)
return key + "=" + value
}
// makeTXTRecord returns a TXT record for the given name and result metadata.
func makeTXTRecord(name string, result *discovery.Result, ttl uint32) []dns.RR {
extra := make([]dns.RR, 0, len(result.Metadata))
for key, value := range result.Metadata {
txt := value
if !strings.HasPrefix(strings.ToLower(key), "rfc1035-") {
txt = encodeKVasRFC1464(key, value)
}
extra = append(extra, &dns.TXT{
Hdr: dns.RR_Header{
Name: name,
Rrtype: dns.TypeTXT,
Class: dns.ClassINET,
Ttl: ttl,
},
Txt: []string{txt},
})
}
return extra
}
// getAddressAndTargetFromCheckServiceNode returns the address and target for a given discovery.Result
func getAddressAndTargetFromDiscoveryResult(result *discovery.Result, domain string) (*dnsAddress, *dnsAddress) {
target := newDNSAddress(result.Target)
if !target.IsEmptyString() && !target.IsInternalFQDNOrIP(domain) {
target.SetAddress(canonicalNameForResult(result, domain))
}
address := newDNSAddress(result.Address)
return address, target
} }

View File

@ -11,6 +11,7 @@ import (
"github.com/hashicorp/consul/agent/discovery" "github.com/hashicorp/consul/agent/discovery"
) )
// canonicalNameForResult returns the canonical name for a discovery result.
func canonicalNameForResult(result *discovery.Result, domain string) string { func canonicalNameForResult(result *discovery.Result, domain string) string {
switch result.Type { switch result.Type {
case discovery.ResultTypeService: case discovery.ResultTypeService:

View File

@ -14,44 +14,77 @@ import (
) )
// buildQueryFromDNSMessage returns a discovery.Query from a DNS message. // buildQueryFromDNSMessage returns a discovery.Query from a DNS message.
func buildQueryFromDNSMessage(req *dns.Msg, domain, altDomain string, cfgCtx *RouterDynamicConfig, defaultEntMeta acl.EnterpriseMeta) (*discovery.Query, error) { func buildQueryFromDNSMessage(req *dns.Msg, domain, altDomain string,
cfg *RouterDynamicConfig, defaultEntMeta acl.EnterpriseMeta, defaultDatacenter string) (*discovery.Query, error) {
queryType, queryParts, querySuffixes := getQueryTypePartsAndSuffixesFromDNSMessage(req, domain, altDomain) queryType, queryParts, querySuffixes := getQueryTypePartsAndSuffixesFromDNSMessage(req, domain, altDomain)
locality, ok := ParseLocality(querySuffixes, defaultEntMeta, cfgCtx.enterpriseDNSConfig) queryTenancy, err := getQueryTenancy(queryType, querySuffixes, defaultEntMeta, cfg, defaultDatacenter)
if !ok { if err != nil {
return nil, errors.New("invalid locality") return nil, err
} }
// TODO(v2-dns): This needs to be deprecated. name, tag := getQueryNameAndTagFromParts(queryType, queryParts)
peerName := locality.peer
if peerName == "" {
// If the peer name was not explicitly defined, fall back to the ambiguously-parsed version.
peerName = locality.peerOrDatacenter
}
return &discovery.Query{ return &discovery.Query{
QueryType: queryType, QueryType: queryType,
QueryPayload: discovery.QueryPayload{ QueryPayload: discovery.QueryPayload{
Name: queryParts[len(queryParts)-1], Name: name,
Tenancy: discovery.QueryTenancy{ Tenancy: queryTenancy,
EnterpriseMeta: locality.EnterpriseMeta, Tag: tag,
// v2-dns: revisit if we need this after the rest of this works. // TODO (v2-dns): what should these be?
// SamenessGroup: "",
// The datacenter of the request is not specified because cross-datacenter virtual IP
// queries are not supported. This guard rail is in place because virtual IPs are allocated
// within a DC, therefore their uniqueness is not guaranteed globally.
Peer: peerName,
Datacenter: locality.datacenter,
},
// TODO(v2-dns): what should these be?
//PortName: "", //PortName: "",
//Tag: "",
//RemoteAddr: nil, //RemoteAddr: nil,
//DisableFailover: false, //DisableFailover: false,
}, },
}, nil }, nil
} }
// getQueryNameAndTagFromParts returns the query name and tag from the query parts that are taken from the original dns question.
func getQueryNameAndTagFromParts(queryType discovery.QueryType, queryParts []string) (string, string) {
switch queryType {
case discovery.QueryTypeService:
n := len(queryParts)
// Support RFC 2782 style syntax
if n == 2 && strings.HasPrefix(queryParts[1], "_") && strings.HasPrefix(queryParts[0], "_") {
// Grab the tag since we make nuke it if it's tcp
tag := queryParts[1][1:]
// Treat _name._tcp.service.consul as a default, no need to filter on that tag
if tag == "tcp" {
tag = ""
}
name := queryParts[0][1:]
// _name._tag.service.consul
return name, tag
}
return queryParts[len(queryParts)-1], ""
}
return queryParts[len(queryParts)-1], ""
}
// getQueryTenancy returns a discovery.QueryTenancy from a DNS message.
func getQueryTenancy(queryType discovery.QueryType, querySuffixes []string,
defaultEntMeta acl.EnterpriseMeta, cfg *RouterDynamicConfig, defaultDatacenter string) (discovery.QueryTenancy, error) {
if queryType == discovery.QueryTypeService {
return getQueryTenancyForService(querySuffixes, defaultEntMeta, cfg, defaultDatacenter)
}
locality, ok := discovery.ParseLocality(querySuffixes, defaultEntMeta, cfg.EnterpriseDNSConfig)
if !ok {
return discovery.QueryTenancy{}, errors.New("invalid locality")
}
if queryType == discovery.QueryTypeVirtual {
if locality.Peer == "" {
// If the peer name was not explicitly defined, fall back to the ambiguously-parsed version.
locality.Peer = locality.PeerOrDatacenter
}
}
return discovery.GetQueryTenancyBasedOnLocality(locality, defaultDatacenter)
}
// getQueryTypePartsAndSuffixesFromDNSMessage returns the query type, the parts, and suffixes of the query name. // getQueryTypePartsAndSuffixesFromDNSMessage returns the query type, the parts, and suffixes of the query name.
func getQueryTypePartsAndSuffixesFromDNSMessage(req *dns.Msg, domain, altDomain string) (queryType discovery.QueryType, parts []string, suffixes []string) { func getQueryTypePartsAndSuffixesFromDNSMessage(req *dns.Msg, domain, altDomain string) (queryType discovery.QueryType, parts []string, suffixes []string) {
// Get the QName without the domain suffix // Get the QName without the domain suffix
@ -64,18 +97,19 @@ func getQueryTypePartsAndSuffixesFromDNSMessage(req *dns.Msg, domain, altDomain
for i := len(labels) - 1; i >= 0 && !done; i-- { for i := len(labels) - 1; i >= 0 && !done; i-- {
queryType = getQueryTypeFromLabels(labels[i]) queryType = getQueryTypeFromLabels(labels[i])
switch queryType { switch queryType {
case discovery.QueryTypeInvalid:
// If we don't recognize the query type, we keep going until we find one we do.
case discovery.QueryTypeService, case discovery.QueryTypeService,
discovery.QueryTypeConnect, discovery.QueryTypeVirtual, discovery.QueryTypeIngress, discovery.QueryTypeConnect, discovery.QueryTypeVirtual, discovery.QueryTypeIngress,
discovery.QueryTypeNode, discovery.QueryTypePreparedQuery: discovery.QueryTypeNode, discovery.QueryTypePreparedQuery:
parts = labels[:i] parts = labels[:i]
suffixes = labels[i+1:] suffixes = labels[i+1:]
done = true done = true
case discovery.QueryTypeInvalid:
fallthrough
default: default:
// If this is a SRV query the "service" label is optional, we add it back to use the // If this is a SRV query the "service" label is optional, we add it back to use the
// existing code-path. // existing code-path.
if req.Question[0].Qtype == dns.TypeSRV && strings.HasPrefix(labels[i], "_") { if req.Question[0].Qtype == dns.TypeSRV && strings.HasPrefix(labels[i], "_") {
queryType = discovery.QueryTypeService
parts = labels[:i+1] parts = labels[:i+1]
suffixes = labels[i+1:] suffixes = labels[i+1:]
done = true done = true

View File

@ -0,0 +1,24 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
//go:build !consulent
package dns
import (
"errors"
"github.com/hashicorp/consul/acl"
"github.com/hashicorp/consul/agent/discovery"
)
// getQueryTenancy returns a discovery.QueryTenancy from a DNS message.
func getQueryTenancyForService(querySuffixes []string,
defaultEntMeta acl.EnterpriseMeta, cfg *RouterDynamicConfig, defaultDatacenter string) (discovery.QueryTenancy, error) {
locality, ok := discovery.ParseLocality(querySuffixes, defaultEntMeta, cfg.EnterpriseDNSConfig)
if !ok {
return discovery.QueryTenancy{}, errors.New("invalid locality")
}
return discovery.GetQueryTenancyBasedOnLocality(locality, defaultDatacenter)
}

View File

@ -29,7 +29,7 @@ func Test_buildQueryFromDNSMessage(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
query, err := buildQueryFromDNSMessage(tc.request, "domain", "altDomain", &RouterDynamicConfig{}, acl.EnterpriseMeta{}) query, err := buildQueryFromDNSMessage(tc.request, "domain", "altDomain", &RouterDynamicConfig{}, acl.EnterpriseMeta{}, "defaultDatacenter")
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, tc.expectedQuery, query) assert.Equal(t, tc.expectedQuery, query)
}) })

View File

@ -706,7 +706,7 @@ func Test_HandleRequest(t *testing.T) {
}, },
Question: []dns.Question{ Question: []dns.Question{
{ {
Name: "c000020a.virtual.consul", // "intentionally missing the trailing dot" Name: "c000020a.virtual.dc1.consul", // "intentionally missing the trailing dot"
Qtype: dns.TypeA, Qtype: dns.TypeA,
Qclass: dns.ClassINET, Qclass: dns.ClassINET,
}, },
@ -728,7 +728,7 @@ func Test_HandleRequest(t *testing.T) {
Compress: true, Compress: true,
Question: []dns.Question{ Question: []dns.Question{
{ {
Name: "c000020a.virtual.consul.", Name: "c000020a.virtual.dc1.consul.",
Qtype: dns.TypeA, Qtype: dns.TypeA,
Qclass: dns.ClassINET, Qclass: dns.ClassINET,
}, },
@ -736,7 +736,7 @@ func Test_HandleRequest(t *testing.T) {
Answer: []dns.RR{ Answer: []dns.RR{
&dns.A{ &dns.A{
Hdr: dns.RR_Header{ Hdr: dns.RR_Header{
Name: "c000020a.virtual.consul.", Name: "c000020a.virtual.dc1.consul.",
Rrtype: dns.TypeA, Rrtype: dns.TypeA,
Class: dns.ClassINET, Class: dns.ClassINET,
Ttl: 123, Ttl: 123,
@ -1345,6 +1345,58 @@ func Test_HandleRequest(t *testing.T) {
} }
func TestRouterDynamicConfig_GetTTLForService(t *testing.T) {
type testCase struct {
name string
inputKey string
shouldMatch bool
expectedDuration time.Duration
}
testCases := []testCase{
{
name: "strict match",
inputKey: "foo",
shouldMatch: true,
expectedDuration: 1 * time.Second,
},
{
name: "wildcard match",
inputKey: "bar",
shouldMatch: true,
expectedDuration: 2 * time.Second,
},
{
name: "wildcard match 2",
inputKey: "bart",
shouldMatch: true,
expectedDuration: 2 * time.Second,
},
{
name: "no match",
inputKey: "homer",
shouldMatch: false,
expectedDuration: 0 * time.Second,
},
}
rtCfg := &config.RuntimeConfig{
DNSServiceTTL: map[string]time.Duration{
"foo": 1 * time.Second,
"bar*": 2 * time.Second,
},
}
cfg, err := getDynamicRouterConfig(rtCfg)
require.NoError(t, err)
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
actual, ok := cfg.GetTTLForService(tc.inputKey)
require.Equal(t, tc.shouldMatch, ok)
require.Equal(t, tc.expectedDuration, actual)
})
}
}
func buildDNSConfig(agentConfig *config.RuntimeConfig, cdf discovery.CatalogDataFetcher, _ error) Config { func buildDNSConfig(agentConfig *config.RuntimeConfig, cdf discovery.CatalogDataFetcher, _ error) Config {
cfg := Config{ cfg := Config{
AgentConfig: &config.RuntimeConfig{ AgentConfig: &config.RuntimeConfig{

View File

@ -20,6 +20,7 @@ import (
"github.com/hashicorp/consul/testrpc" "github.com/hashicorp/consul/testrpc"
) )
// TODO (v2-dns): requires PTR implementation
func TestDNS_ServiceReverseLookup(t *testing.T) { func TestDNS_ServiceReverseLookup(t *testing.T) {
if testing.Short() { if testing.Short() {
t.Skip("too slow for testing.Short") t.Skip("too slow for testing.Short")
@ -76,6 +77,7 @@ func TestDNS_ServiceReverseLookup(t *testing.T) {
} }
} }
// TODO (v2-dns): requires PTR implementation
func TestDNS_ServiceReverseLookup_IPV6(t *testing.T) { func TestDNS_ServiceReverseLookup_IPV6(t *testing.T) {
if testing.Short() { if testing.Short() {
t.Skip("too slow for testing.Short") t.Skip("too slow for testing.Short")
@ -132,6 +134,7 @@ func TestDNS_ServiceReverseLookup_IPV6(t *testing.T) {
} }
} }
// TODO (v2-dns): requires PTR implementation
func TestDNS_ServiceReverseLookup_CustomDomain(t *testing.T) { func TestDNS_ServiceReverseLookup_CustomDomain(t *testing.T) {
if testing.Short() { if testing.Short() {
t.Skip("too slow for testing.Short") t.Skip("too slow for testing.Short")
@ -190,6 +193,7 @@ func TestDNS_ServiceReverseLookup_CustomDomain(t *testing.T) {
} }
} }
// TODO (v2-dns): requires PTR implementation
func TestDNS_ServiceReverseLookupNodeAddress(t *testing.T) { func TestDNS_ServiceReverseLookupNodeAddress(t *testing.T) {
if testing.Short() { if testing.Short() {
t.Skip("too slow for testing.Short") t.Skip("too slow for testing.Short")
@ -252,7 +256,7 @@ func TestDNS_ServiceLookupNoMultiCNAME(t *testing.T) {
} }
t.Parallel() t.Parallel()
for name, experimentsHCL := range getVersionHCL(false) { for name, experimentsHCL := range getVersionHCL(true) {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
a := NewTestAgent(t, experimentsHCL) a := NewTestAgent(t, experimentsHCL)
defer a.Shutdown() defer a.Shutdown()
@ -315,7 +319,7 @@ func TestDNS_ServiceLookupPreferNoCNAME(t *testing.T) {
} }
t.Parallel() t.Parallel()
for name, experimentsHCL := range getVersionHCL(false) { for name, experimentsHCL := range getVersionHCL(true) {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
a := NewTestAgent(t, experimentsHCL) a := NewTestAgent(t, experimentsHCL)
defer a.Shutdown() defer a.Shutdown()
@ -364,7 +368,7 @@ func TestDNS_ServiceLookupPreferNoCNAME(t *testing.T) {
in, _, err := c.Exchange(m, a.DNSAddr()) in, _, err := c.Exchange(m, a.DNSAddr())
require.NoError(t, err) require.NoError(t, err)
// expect a CNAME and an A RR // expect an A RR
require.Len(t, in.Answer, 1) require.Len(t, in.Answer, 1)
aRec, ok := in.Answer[0].(*dns.A) aRec, ok := in.Answer[0].(*dns.A)
require.Truef(t, ok, "Not an A RR") require.Truef(t, ok, "Not an A RR")
@ -381,7 +385,7 @@ func TestDNS_ServiceLookupMultiAddrNoCNAME(t *testing.T) {
} }
t.Parallel() t.Parallel()
for name, experimentsHCL := range getVersionHCL(false) { for name, experimentsHCL := range getVersionHCL(true) {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
a := NewTestAgent(t, experimentsHCL) a := NewTestAgent(t, experimentsHCL)
defer a.Shutdown() defer a.Shutdown()
@ -457,6 +461,7 @@ func TestDNS_ServiceLookupMultiAddrNoCNAME(t *testing.T) {
} }
} }
// TODO (v2-dns): this requires a prepared query
func TestDNS_ServiceLookup(t *testing.T) { func TestDNS_ServiceLookup(t *testing.T) {
if testing.Short() { if testing.Short() {
t.Skip("too slow for testing.Short") t.Skip("too slow for testing.Short")
@ -585,6 +590,8 @@ func TestDNS_ServiceLookup(t *testing.T) {
} }
} }
// TODO (v2-dns): this is formulating the correct response
// but failing with an I/O timeout on the dns client Exchange() call
func TestDNS_ServiceLookupWithInternalServiceAddress(t *testing.T) { func TestDNS_ServiceLookupWithInternalServiceAddress(t *testing.T) {
if testing.Short() { if testing.Short() {
t.Skip("too slow for testing.Short") t.Skip("too slow for testing.Short")
@ -820,7 +827,7 @@ func TestDNS_ExternalServiceLookup(t *testing.T) {
} }
t.Parallel() t.Parallel()
for name, experimentsHCL := range getVersionHCL(false) { for name, experimentsHCL := range getVersionHCL(true) {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
a := NewTestAgent(t, experimentsHCL) a := NewTestAgent(t, experimentsHCL)
defer a.Shutdown() defer a.Shutdown()
@ -858,7 +865,7 @@ func TestDNS_ExternalServiceLookup(t *testing.T) {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
if len(in.Answer) != 1 { if len(in.Answer) != 1 || len(in.Extra) > 0 {
t.Fatalf("Bad: %#v", in) t.Fatalf("Bad: %#v", in)
} }
@ -886,7 +893,7 @@ func TestDNS_ExternalServiceToConsulCNAMELookup(t *testing.T) {
} }
t.Parallel() t.Parallel()
for name, experimentsHCL := range getVersionHCL(false) { for name, experimentsHCL := range getVersionHCL(true) {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
a := NewTestAgent(t, ` a := NewTestAgent(t, `
domain = "CONSUL." domain = "CONSUL."
@ -1121,6 +1128,7 @@ func TestDNS_ExternalServiceToConsulCNAMENestedLookup(t *testing.T) {
} }
} }
// TODO (v2-dns): this requires a prepared query
func TestDNS_ServiceLookup_ServiceAddress_A(t *testing.T) { func TestDNS_ServiceLookup_ServiceAddress_A(t *testing.T) {
if testing.Short() { if testing.Short() {
t.Skip("too slow for testing.Short") t.Skip("too slow for testing.Short")
@ -1222,6 +1230,7 @@ func TestDNS_ServiceLookup_ServiceAddress_A(t *testing.T) {
} }
} }
// TODO (v2-dns): this requires a prepared query
func TestDNS_AltDomain_ServiceLookup_ServiceAddress_A(t *testing.T) { func TestDNS_AltDomain_ServiceLookup_ServiceAddress_A(t *testing.T) {
if testing.Short() { if testing.Short() {
t.Skip("too slow for testing.Short") t.Skip("too slow for testing.Short")
@ -1330,6 +1339,7 @@ func TestDNS_AltDomain_ServiceLookup_ServiceAddress_A(t *testing.T) {
} }
} }
// TODO (v2-dns): this requires a prepared query
func TestDNS_ServiceLookup_ServiceAddress_SRV(t *testing.T) { func TestDNS_ServiceLookup_ServiceAddress_SRV(t *testing.T) {
if testing.Short() { if testing.Short() {
t.Skip("too slow for testing.Short") t.Skip("too slow for testing.Short")
@ -1451,6 +1461,7 @@ func TestDNS_ServiceLookup_ServiceAddress_SRV(t *testing.T) {
} }
} }
// TODO (v2-dns): this requires a prepared query
func TestDNS_ServiceLookup_ServiceAddressIPV6(t *testing.T) { func TestDNS_ServiceLookup_ServiceAddressIPV6(t *testing.T) {
if testing.Short() { if testing.Short() {
t.Skip("too slow for testing.Short") t.Skip("too slow for testing.Short")
@ -1552,6 +1563,7 @@ func TestDNS_ServiceLookup_ServiceAddressIPV6(t *testing.T) {
} }
} }
// TODO (v2-dns): this requires a prepared query
func TestDNS_AltDomain_ServiceLookup_ServiceAddressIPV6(t *testing.T) { func TestDNS_AltDomain_ServiceLookup_ServiceAddressIPV6(t *testing.T) {
if testing.Short() { if testing.Short() {
t.Skip("too slow for testing.Short") t.Skip("too slow for testing.Short")
@ -1660,6 +1672,7 @@ func TestDNS_AltDomain_ServiceLookup_ServiceAddressIPV6(t *testing.T) {
} }
} }
// TODO (v2-dns): this requires a prepared query
func TestDNS_ServiceLookup_WanTranslation(t *testing.T) { func TestDNS_ServiceLookup_WanTranslation(t *testing.T) {
if testing.Short() { if testing.Short() {
t.Skip("too slow for testing.Short") t.Skip("too slow for testing.Short")
@ -1876,6 +1889,7 @@ func TestDNS_ServiceLookup_WanTranslation(t *testing.T) {
} }
} }
// TODO (v2-dns): this requires a prepared query
func TestDNS_CaseInsensitiveServiceLookup(t *testing.T) { func TestDNS_CaseInsensitiveServiceLookup(t *testing.T) {
if testing.Short() { if testing.Short() {
t.Skip("too slow for testing.Short") t.Skip("too slow for testing.Short")
@ -1978,6 +1992,7 @@ func TestDNS_CaseInsensitiveServiceLookup(t *testing.T) {
} }
} }
// TODO (v2-dns): this returns a response where the answer is an SOA record
func TestDNS_ServiceLookup_TagPeriod(t *testing.T) { func TestDNS_ServiceLookup_TagPeriod(t *testing.T) {
if testing.Short() { if testing.Short() {
t.Skip("too slow for testing.Short") t.Skip("too slow for testing.Short")
@ -2058,6 +2073,7 @@ func TestDNS_ServiceLookup_TagPeriod(t *testing.T) {
} }
} }
// TODO (v2-dns): this returns a response where the answer is an SOA record
func TestDNS_ServiceLookup_PreparedQueryNamePeriod(t *testing.T) { func TestDNS_ServiceLookup_PreparedQueryNamePeriod(t *testing.T) {
if testing.Short() { if testing.Short() {
t.Skip("too slow for testing.Short") t.Skip("too slow for testing.Short")
@ -2145,6 +2161,7 @@ func TestDNS_ServiceLookup_PreparedQueryNamePeriod(t *testing.T) {
} }
} }
// TODO (v2-dns): this requires a prepared query
func TestDNS_ServiceLookup_Dedup(t *testing.T) { func TestDNS_ServiceLookup_Dedup(t *testing.T) {
if testing.Short() { if testing.Short() {
t.Skip("too slow for testing.Short") t.Skip("too slow for testing.Short")
@ -2256,6 +2273,7 @@ func TestDNS_ServiceLookup_Dedup(t *testing.T) {
} }
} }
// TODO (v2-dns): this requires a prepared query
func TestDNS_ServiceLookup_Dedup_SRV(t *testing.T) { func TestDNS_ServiceLookup_Dedup_SRV(t *testing.T) {
if testing.Short() { if testing.Short() {
t.Skip("too slow for testing.Short") t.Skip("too slow for testing.Short")
@ -2831,6 +2849,7 @@ func TestDNS_ServiceLookup_OnlyPassing(t *testing.T) {
} }
} }
// TODO (v2-dns): this requires a prepared query
func TestDNS_ServiceLookup_Randomize(t *testing.T) { func TestDNS_ServiceLookup_Randomize(t *testing.T) {
if testing.Short() { if testing.Short() {
t.Skip("too slow for testing.Short") t.Skip("too slow for testing.Short")
@ -2930,6 +2949,7 @@ func TestDNS_ServiceLookup_Randomize(t *testing.T) {
} }
} }
// TODO (v2-dns): this requires a prepared query
func TestDNS_ServiceLookup_Truncate(t *testing.T) { func TestDNS_ServiceLookup_Truncate(t *testing.T) {
if testing.Short() { if testing.Short() {
t.Skip("too slow for testing.Short") t.Skip("too slow for testing.Short")
@ -3007,6 +3027,7 @@ func TestDNS_ServiceLookup_Truncate(t *testing.T) {
} }
} }
// TODO (v2-dns): this requires a prepared query
func TestDNS_ServiceLookup_LargeResponses(t *testing.T) { func TestDNS_ServiceLookup_LargeResponses(t *testing.T) {
if testing.Short() { if testing.Short() {
t.Skip("too slow for testing.Short") t.Skip("too slow for testing.Short")
@ -3386,7 +3407,6 @@ func TestDNS_ServiceLookup_ARecordLimits(t *testing.T) {
} }
} }
// TODO(jmurret):
func TestDNS_ServiceLookup_AnswerLimits(t *testing.T) { func TestDNS_ServiceLookup_AnswerLimits(t *testing.T) {
if testing.Short() { if testing.Short() {
t.Skip("too slow for testing.Short") t.Skip("too slow for testing.Short")
@ -3462,6 +3482,7 @@ func TestDNS_ServiceLookup_AnswerLimits(t *testing.T) {
} }
} }
// TODO (v2-dns): this requires a prepared query
func TestDNS_ServiceLookup_CNAME(t *testing.T) { func TestDNS_ServiceLookup_CNAME(t *testing.T) {
if testing.Short() { if testing.Short() {
t.Skip("too slow for testing.Short") t.Skip("too slow for testing.Short")
@ -3567,6 +3588,7 @@ func TestDNS_ServiceLookup_CNAME(t *testing.T) {
} }
} }
// TODO (v2-dns): this requires a prepared query
func TestDNS_ServiceLookup_ServiceAddress_CNAME(t *testing.T) { func TestDNS_ServiceLookup_ServiceAddress_CNAME(t *testing.T) {
if testing.Short() { if testing.Short() {
t.Skip("too slow for testing.Short") t.Skip("too slow for testing.Short")
@ -3679,7 +3701,7 @@ func TestDNS_ServiceLookup_TTL(t *testing.T) {
} }
t.Parallel() t.Parallel()
for name, experimentsHCL := range getVersionHCL(false) { for name, experimentsHCL := range getVersionHCL(true) {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
a := NewTestAgent(t, ` a := NewTestAgent(t, `
dns_config { dns_config {
@ -3765,7 +3787,7 @@ func TestDNS_ServiceLookup_SRV_RFC(t *testing.T) {
} }
t.Parallel() t.Parallel()
for name, experimentsHCL := range getVersionHCL(false) { for name, experimentsHCL := range getVersionHCL(true) {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
a := NewTestAgent(t, experimentsHCL) a := NewTestAgent(t, experimentsHCL)
defer a.Shutdown() defer a.Shutdown()
@ -3847,7 +3869,9 @@ func TestDNS_ServiceLookup_SRV_RFC_TCP_Default(t *testing.T) {
} }
t.Parallel() t.Parallel()
a := NewTestAgent(t, "") for name, experimentsHCL := range getVersionHCL(true) {
t.Run(name, func(t *testing.T) {
a := NewTestAgent(t, experimentsHCL)
defer a.Shutdown() defer a.Shutdown()
testrpc.WaitForLeader(t, a.RPC, "dc1") testrpc.WaitForLeader(t, a.RPC, "dc1")
@ -3876,6 +3900,7 @@ func TestDNS_ServiceLookup_SRV_RFC_TCP_Default(t *testing.T) {
} }
for _, question := range questions { for _, question := range questions {
t.Run(question, func(t *testing.T) {
m := new(dns.Msg) m := new(dns.Msg)
m.SetQuestion(question, dns.TypeSRV) m.SetQuestion(question, dns.TypeSRV)
@ -3916,6 +3941,9 @@ func TestDNS_ServiceLookup_SRV_RFC_TCP_Default(t *testing.T) {
if aRec.Hdr.Ttl != 0 { if aRec.Hdr.Ttl != 0 {
t.Fatalf("Bad: %#v", in.Extra[0]) t.Fatalf("Bad: %#v", in.Extra[0])
} }
})
}
})
} }
} }
@ -3982,7 +4010,9 @@ func TestDNS_ServiceLookup_FilterACL(t *testing.T) {
} }
` `
a := NewTestAgent(t, hcl) for name, experimentsHCL := range getVersionHCL(false) {
t.Run(name, func(t *testing.T) {
a := NewTestAgent(t, hcl+experimentsHCL)
defer a.Shutdown() defer a.Shutdown()
testrpc.WaitForLeader(t, a.RPC, "dc1") testrpc.WaitForLeader(t, a.RPC, "dc1")
@ -4020,6 +4050,8 @@ func TestDNS_ServiceLookup_FilterACL(t *testing.T) {
} }
}) })
} }
})
}
} }
func TestDNS_ServiceLookup_MetaTXT(t *testing.T) { func TestDNS_ServiceLookup_MetaTXT(t *testing.T) {
@ -4027,7 +4059,9 @@ func TestDNS_ServiceLookup_MetaTXT(t *testing.T) {
t.Skip("too slow for testing.Short") t.Skip("too slow for testing.Short")
} }
a := NewTestAgent(t, `dns_config = { enable_additional_node_meta_txt = true }`) for name, experimentsHCL := range getVersionHCL(true) {
t.Run(name, func(t *testing.T) {
a := NewTestAgent(t, `dns_config = { enable_additional_node_meta_txt = true } `+experimentsHCL)
defer a.Shutdown() defer a.Shutdown()
testrpc.WaitForLeader(t, a.RPC, "dc1") testrpc.WaitForLeader(t, a.RPC, "dc1")
@ -4070,6 +4104,8 @@ func TestDNS_ServiceLookup_MetaTXT(t *testing.T) {
}, },
} }
require.Equal(t, wantAdditional, in.Extra) require.Equal(t, wantAdditional, in.Extra)
})
}
} }
func TestDNS_ServiceLookup_SuppressTXT(t *testing.T) { func TestDNS_ServiceLookup_SuppressTXT(t *testing.T) {
@ -4077,7 +4113,9 @@ func TestDNS_ServiceLookup_SuppressTXT(t *testing.T) {
t.Skip("too slow for testing.Short") t.Skip("too slow for testing.Short")
} }
a := NewTestAgent(t, `dns_config = { enable_additional_node_meta_txt = false }`) for name, experimentsHCL := range getVersionHCL(true) {
t.Run(name, func(t *testing.T) {
a := NewTestAgent(t, `dns_config = { enable_additional_node_meta_txt = false } `+experimentsHCL)
defer a.Shutdown() defer a.Shutdown()
testrpc.WaitForLeader(t, a.RPC, "dc1") testrpc.WaitForLeader(t, a.RPC, "dc1")
@ -4117,4 +4155,6 @@ func TestDNS_ServiceLookup_SuppressTXT(t *testing.T) {
}, },
} }
require.Equal(t, wantAdditional, in.Extra) require.Equal(t, wantAdditional, in.Extra)
})
}
} }