diff --git a/agent/dns/mock_dnsRecursor.go b/agent/dns/mock_dnsRecursor.go new file mode 100644 index 0000000000..83f41a30ed --- /dev/null +++ b/agent/dns/mock_dnsRecursor.go @@ -0,0 +1,56 @@ +// Code generated by mockery v2.20.0. DO NOT EDIT. + +package dns + +import ( + miekgdns "github.com/miekg/dns" + mock "github.com/stretchr/testify/mock" + + net "net" +) + +// mockDnsRecursor is an autogenerated mock type for the dnsRecursor type +type mockDnsRecursor struct { + mock.Mock +} + +// handle provides a mock function with given fields: req, cfgCtx, remoteAddr +func (_m *mockDnsRecursor) handle(req *miekgdns.Msg, cfgCtx *RouterDynamicConfig, remoteAddr net.Addr) (*miekgdns.Msg, error) { + ret := _m.Called(req, cfgCtx, remoteAddr) + + var r0 *miekgdns.Msg + var r1 error + if rf, ok := ret.Get(0).(func(*miekgdns.Msg, *RouterDynamicConfig, net.Addr) (*miekgdns.Msg, error)); ok { + return rf(req, cfgCtx, remoteAddr) + } + if rf, ok := ret.Get(0).(func(*miekgdns.Msg, *RouterDynamicConfig, net.Addr) *miekgdns.Msg); ok { + r0 = rf(req, cfgCtx, remoteAddr) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*miekgdns.Msg) + } + } + + if rf, ok := ret.Get(1).(func(*miekgdns.Msg, *RouterDynamicConfig, net.Addr) error); ok { + r1 = rf(req, cfgCtx, remoteAddr) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +type mockConstructorTestingTnewMockDnsRecursor interface { + mock.TestingT + Cleanup(func()) +} + +// newMockDnsRecursor creates a new instance of mockDnsRecursor. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +func newMockDnsRecursor(t mockConstructorTestingTnewMockDnsRecursor) *mockDnsRecursor { + mock := &mockDnsRecursor{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/agent/dns/recursor.go b/agent/dns/recursor.go new file mode 100644 index 0000000000..55b922f710 --- /dev/null +++ b/agent/dns/recursor.go @@ -0,0 +1,123 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package dns + +import ( + "errors" + "net" + "time" + + "github.com/hashicorp/go-hclog" + "github.com/miekg/dns" + + "github.com/hashicorp/consul/ipaddr" + "github.com/hashicorp/consul/logging" +) + +type recursor struct { + logger hclog.Logger +} + +func newRecursor(logger hclog.Logger) *recursor { + return &recursor{ + logger: logger.Named(logging.DNS), + } +} + +// handle is used to process DNS queries for externally configured servers +func (r *recursor) handle(req *dns.Msg, cfgCtx *RouterDynamicConfig, remoteAddr net.Addr) (*dns.Msg, error) { + q := req.Question[0] + + network := "udp" + defer func(s time.Time) { + r.logger.Debug("request served from client", + "question", q, + "network", network, + "latency", time.Since(s).String(), + "client", remoteAddr.String(), + "client_network", remoteAddr.Network(), + ) + }(time.Now()) + + // Switch to TCP if the client is + if _, ok := remoteAddr.(*net.TCPAddr); ok { + network = "tcp" + } + + // Recursively resolve + c := &dns.Client{Net: network, Timeout: cfgCtx.RecursorTimeout} + var resp *dns.Msg + var rtt time.Duration + var err error + for _, idx := range cfgCtx.RecursorStrategy.Indexes(len(cfgCtx.Recursors)) { + recurseAddr := cfgCtx.Recursors[idx] + resp, rtt, err = c.Exchange(req, recurseAddr) + // Check if the response is valid and has the desired Response code + if resp != nil && (resp.Rcode != dns.RcodeSuccess && resp.Rcode != dns.RcodeNameError) { + r.logger.Debug("recurse failed for question", + "question", q, + "rtt", rtt, + "recursor", recurseAddr, + "rcode", dns.RcodeToString[resp.Rcode], + ) + // If we still have recursors to forward the query to, + // we move forward onto the next one else the loop ends + continue + } else if err == nil || (resp != nil && resp.Truncated) { + // Compress the response; we don't know if the incoming + // response was compressed or not, so by not compressing + // we might generate an invalid packet on the way out. + resp.Compress = !cfgCtx.DisableCompression + + // Forward the response + r.logger.Debug("recurse succeeded for question", + "question", q, + "rtt", rtt, + "recursor", recurseAddr, + ) + return resp, nil + } + r.logger.Error("recurse failed", "error", err) + } + + // If all resolvers fail, return a SERVFAIL message + r.logger.Error("all resolvers failed for question from client", + "question", q, + "client", remoteAddr.String(), + "client_network", remoteAddr.Network(), + ) + + return nil, errRecursionFailed +} + +// formatRecursorAddress is used to add a port to the recursor if omitted. +func formatRecursorAddress(recursor string) (string, error) { + _, _, err := net.SplitHostPort(recursor) + var ae *net.AddrError + if errors.As(err, &ae) { + switch ae.Err { + case "missing port in address": + recursor = ipaddr.FormatAddressPort(recursor, 53) + case "too many colons in address": + if ip := net.ParseIP(recursor); ip != nil && ip.To4() == nil { + recursor = ipaddr.FormatAddressPort(recursor, 53) + break + } + fallthrough + default: + return "", err + } + } else if err != nil { + return "", err + } + + // Get the address + addr, err := net.ResolveTCPAddr("tcp", recursor) + if err != nil { + return "", err + } + + // Return string + return addr.String(), nil +} diff --git a/agent/dns/recursor_test.go b/agent/dns/recursor_test.go new file mode 100644 index 0000000000..69514e508e --- /dev/null +++ b/agent/dns/recursor_test.go @@ -0,0 +1,39 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package dns + +import ( + "strings" + "testing" +) + +// Test_handle cases are covered by the integration tests in agent/dns_test.go. +// They should be moved here when the V1 DNS server is deprecated. +//func Test_handle(t *testing.T) { + +func Test_formatRecursorAddress(t *testing.T) { + t.Parallel() + addr, err := formatRecursorAddress("8.8.8.8") + if err != nil { + t.Fatalf("err: %v", err) + } + if addr != "8.8.8.8:53" { + t.Fatalf("bad: %v", addr) + } + addr, err = formatRecursorAddress("2001:4860:4860::8888") + if err != nil { + t.Fatalf("err: %v", err) + } + if addr != "[2001:4860:4860::8888]:53" { + t.Fatalf("bad: %v", addr) + } + _, err = formatRecursorAddress("1.2.3.4::53") + if err == nil || !strings.Contains(err.Error(), "too many colons in address") { + t.Fatalf("err: %v", err) + } + _, err = formatRecursorAddress("2001:4860:4860::8888:::53") + if err == nil || !strings.Contains(err.Error(), "too many colons in address") { + t.Fatalf("err: %v", err) + } +} diff --git a/agent/dns/router.go b/agent/dns/router.go index bf6b29f077..2a86592a9c 100644 --- a/agent/dns/router.go +++ b/agent/dns/router.go @@ -33,7 +33,8 @@ const ( var ( errInvalidQuestion = fmt.Errorf("invalid question") - errNameNotFound = fmt.Errorf("invalid question") + errNameNotFound = fmt.Errorf("name not found") + errRecursionFailed = fmt.Errorf("recursion failed") ) // TODO (v2-dns): metrics @@ -74,10 +75,18 @@ type DiscoveryQueryProcessor interface { QueryByIP(net.IP, discovery.Context) ([]*discovery.Result, error) } +// dnsRecursor is an interface that can be used to mock calls to external DNS servers for unit testing. +// +//go:generate mockery --name dnsRecursor --inpackage +type dnsRecursor interface { + handle(req *dns.Msg, cfgCtx *RouterDynamicConfig, remoteAddr net.Addr) (*dns.Msg, error) +} + // Router replaces miekg/dns.ServeMux with a simpler router that only checks for the 2-3 valid domains // that Consul supports and forwards to a single DiscoveryQueryProcessor handler. If there is no match, it will recurse. type Router struct { processor DiscoveryQueryProcessor + recursor dnsRecursor domain string altDomain string datacenter string @@ -102,11 +111,15 @@ func NewRouter(cfg Config) (*Router, error) { altDomain := dns.CanonicalName(cfg.AgentConfig.DNSAltDomain) // TODO (v2-dns): need to figure out tenancy information here in a way that work for V2 and V1 + + logger := cfg.Logger.Named(logging.DNS) + router := &Router{ processor: cfg.Processor, + recursor: newRecursor(logger), domain: domain, altDomain: altDomain, - logger: cfg.Logger.Named(logging.DNS), + logger: logger, tokenFunc: cfg.TokenFunc, defaultEntMeta: cfg.EntMeta, } @@ -119,7 +132,7 @@ func NewRouter(cfg Config) (*Router, error) { // HandleRequest is used to process an individual DNS request. It returns a message in success or fail cases. func (r *Router) HandleRequest(req *dns.Msg, reqCtx discovery.Context, remoteAddress net.Addr) *dns.Msg { - cfg := r.dynamicConfig.Load().(*RouterDynamicConfig) + configCtx := r.dynamicConfig.Load().(*RouterDynamicConfig) err := validateAndNormalizeRequest(req) if err != nil { @@ -127,43 +140,52 @@ func (r *Router) HandleRequest(req *dns.Msg, reqCtx discovery.Context, remoteAdd if errors.Is(err, errInvalidQuestion) { return createRefusedResponse(req) } - return createServerFailureResponse(req, cfg, false) + return createServerFailureResponse(req, configCtx, false) } reqType, responseDomain, needRecurse := r.parseDomain(req) + if needRecurse && !canRecurse(configCtx) { + return createServerFailureResponse(req, configCtx, true) + } - if needRecurse && canRecurse(cfg) { - // TODO (v2-dns): handle recursion - r.logger.Error("recursion not implemented") - return createServerFailureResponse(req, cfg, false) + if needRecurse { + // This assumes `canRecurse(configCtx)` is true above + resp, err := r.recursor.handle(req, configCtx, remoteAddress) + if err != nil && !errors.Is(err, errRecursionFailed) { + r.logger.Error("unhandled error recursing DNS query", "error", err) + } + if err != nil { + return createServerFailureResponse(req, configCtx, true) + } + return resp } - results, err := r.getQueryResults(req, reqCtx, reqType, cfg) + results, err := r.getQueryResults(req, reqCtx, reqType, configCtx) if err != nil && errors.Is(err, errNameNotFound) { r.logger.Error("name not found", "name", req.Question[0].Name) - return createNameErrorResponse(req, cfg, responseDomain) + return createNameErrorResponse(req, configCtx, responseDomain) } if err != nil { r.logger.Error("error processing discovery query", "error", err) - return createServerFailureResponse(req, cfg, false) + return createServerFailureResponse(req, configCtx, false) } // This needs the question information because it affects the serialization format. // e.g., the Consul service has the same "results" for both NS and A/AAAA queries, but the serialization differs. - resp, err := r.serializeQueryResults(req, results, cfg, responseDomain) + resp, err := r.serializeQueryResults(req, results, configCtx, responseDomain) if err != nil { r.logger.Error("error serializing DNS results", "error", err) - return createServerFailureResponse(req, cfg, false) + return createServerFailureResponse(req, configCtx, false) } return resp } // getQueryResults returns a discovery.Result from a DNS message. -func (r *Router) getQueryResults(req *dns.Msg, reqCtx discovery.Context, reqType requestType, cfg *RouterDynamicConfig) ([]*discovery.Result, error) { +func (r *Router) getQueryResults(req *dns.Msg, reqCtx discovery.Context, reqType requestType, cfgCtx *RouterDynamicConfig) ([]*discovery.Result, error) { switch reqType { case requestTypeName: - query, err := buildQueryFromDNSMessage(req, r.domain, r.altDomain, cfg, r.defaultEntMeta) + query, err := buildQueryFromDNSMessage(req, r.domain, r.altDomain, cfgCtx, r.defaultEntMeta) if err != nil { r.logger.Error("error building discovery query from DNS request", "error", err) return nil, err @@ -197,29 +219,6 @@ func (r *Router) ReloadConfig(newCfg *config.RuntimeConfig) error { return nil } -// defaultAgentDNSRequestContext returns a default request context based on the agent's config. -func (r *Router) defaultAgentDNSRequestContext() discovery.Context { - return discovery.Context{ - Token: r.tokenFunc(), - // TODO (v2-dns): tenancy information; maybe we choose not to specify and use the default - // attached to the Router (from the agent's config) - } -} - -// validateAndNormalizeRequest validates the DNS request and normalizes the request name. -func validateAndNormalizeRequest(req *dns.Msg) error { - // like upstream miekg/dns, we require at least one question, - // but we will only answer the first. - if len(req.Question) == 0 { - return errInvalidQuestion - } - - // We mutate the request name to respond with the canonical name. - // This is Consul convention. - req.Question[0].Name = dns.CanonicalName(req.Question[0].Name) - return nil -} - // Request type is similar to miekg/dns.Type, but correlates to the different query processors we might need to invoke. type requestType string @@ -281,6 +280,29 @@ func (r *Router) serializeQueryResults(req *dns.Msg, results []*discovery.Result return resp, nil } +// defaultAgentDNSRequestContext returns a default request context based on the agent's config. +func (r *Router) defaultAgentDNSRequestContext() discovery.Context { + return discovery.Context{ + Token: r.tokenFunc(), + // TODO (v2-dns): tenancy information; maybe we choose not to specify and use the default + // attached to the Router (from the agent's config) + } +} + +// validateAndNormalizeRequest validates the DNS request and normalizes the request name. +func validateAndNormalizeRequest(req *dns.Msg) error { + // like upstream miekg/dns, we require at least one question, + // but we will only answer the first. + if len(req.Question) == 0 { + return errInvalidQuestion + } + + // We mutate the request name to respond with the canonical name. + // This is Consul convention. + req.Question[0].Name = dns.CanonicalName(req.Question[0].Name) + return nil +} + // stripSuffix strips off the suffixes that may have been added to the request name. func stripSuffix(target string) (string, bool) { enableFailover := false @@ -333,7 +355,14 @@ func getDynamicRouterConfig(conf *config.RuntimeConfig) (*RouterDynamicConfig, e // TODO (v2-dns): add service TTL recalculation - // TODO (v2-dns): add recursor address formatting + for _, r := range conf.DNSRecursors { + ra, err := formatRecursorAddress(r) + if err != nil { + return nil, fmt.Errorf("invalid recursor address: %w", err) + } + cfg.Recursors = append(cfg.Recursors, ra) + } + return cfg, nil } @@ -349,11 +378,67 @@ func createServerFailureResponse(req *dns.Msg, cfg *RouterDynamicConfig, recursi m.SetReply(req) m.Compress = !cfg.DisableCompression m.SetRcode(req, dns.RcodeServerFailure) - // TODO (2-dns): set EDNS m.RecursionAvailable = recursionAvailable + if edns := req.IsEdns0(); edns != nil { + setEDNS(req, m, true) + } return m } +// setEDNS is used to set the responses EDNS size headers and +// possibly the ECS headers as well if they were present in the +// original request +func setEDNS(request *dns.Msg, response *dns.Msg, ecsGlobal bool) { + edns := request.IsEdns0() + if edns == nil { + return + } + + // cannot just use the SetEdns0 function as we need to embed + // the ECS option as well + ednsResp := new(dns.OPT) + ednsResp.Hdr.Name = "." + ednsResp.Hdr.Rrtype = dns.TypeOPT + ednsResp.SetUDPSize(edns.UDPSize()) + + // Setup the ECS option if present + if subnet := ednsSubnetForRequest(request); subnet != nil { + subOp := new(dns.EDNS0_SUBNET) + subOp.Code = dns.EDNS0SUBNET + subOp.Family = subnet.Family + subOp.Address = subnet.Address + subOp.SourceNetmask = subnet.SourceNetmask + if c := response.Rcode; ecsGlobal || c == dns.RcodeNameError || c == dns.RcodeServerFailure || c == dns.RcodeRefused || c == dns.RcodeNotImplemented { + // reply is globally valid and should be cached accordingly + subOp.SourceScope = 0 + } else { + // reply is only valid for the subnet it was queried with + subOp.SourceScope = subnet.SourceNetmask + } + ednsResp.Option = append(ednsResp.Option, subOp) + } + + response.Extra = append(response.Extra, ednsResp) +} + +// ednsSubnetForRequest looks through the request to find any EDS subnet options +func ednsSubnetForRequest(req *dns.Msg) *dns.EDNS0_SUBNET { + // IsEdns0 returns the EDNS RR if present or nil otherwise + edns := req.IsEdns0() + + if edns == nil { + return nil + } + + for _, o := range edns.Option { + if subnet, ok := o.(*dns.EDNS0_SUBNET); ok { + return subnet + } + } + + return nil +} + // createRefusedResponse returns a REFUSED message. func createRefusedResponse(req *dns.Msg) *dns.Msg { // Return a REFUSED message diff --git a/agent/dns/router_query.go b/agent/dns/router_query.go index 9815e65aaa..5f46413681 100644 --- a/agent/dns/router_query.go +++ b/agent/dns/router_query.go @@ -14,10 +14,10 @@ import ( ) // buildQueryFromDNSMessage returns a discovery.Query from a DNS message. -func buildQueryFromDNSMessage(req *dns.Msg, domain, altDomain string, cfg *RouterDynamicConfig, defaultEntMeta acl.EnterpriseMeta) (*discovery.Query, error) { +func buildQueryFromDNSMessage(req *dns.Msg, domain, altDomain string, cfgCtx *RouterDynamicConfig, defaultEntMeta acl.EnterpriseMeta) (*discovery.Query, error) { queryType, queryParts, querySuffixes := getQueryTypePartsAndSuffixesFromDNSMessage(req, domain, altDomain) - locality, ok := ParseLocality(querySuffixes, defaultEntMeta, cfg.enterpriseDNSConfig) + locality, ok := ParseLocality(querySuffixes, defaultEntMeta, cfgCtx.enterpriseDNSConfig) if !ok { return nil, errors.New("invalid locality") } diff --git a/agent/dns/router_test.go b/agent/dns/router_test.go index 2f1dc133e6..32af2b6c17 100644 --- a/agent/dns/router_test.go +++ b/agent/dns/router_test.go @@ -4,6 +4,7 @@ package dns import ( + "errors" "net" "testing" "time" @@ -31,6 +32,7 @@ func Test_HandleRequest(t *testing.T) { name string agentConfig *config.RuntimeConfig // This will override the default test Router Config configureDataFetcher func(fetcher discovery.CatalogDataFetcher) + configureRecursor func(recursor dnsRecursor) mockProcessorError error request *dns.Msg requestContext *discovery.Context @@ -39,6 +41,191 @@ func Test_HandleRequest(t *testing.T) { } testCases := []testCase{ + // recursor queries + { + name: "recursors not configured, non-matching domain", + request: &dns.Msg{ + MsgHdr: dns.MsgHdr{ + Opcode: dns.OpcodeQuery, + }, + Question: []dns.Question{ + { + Name: "google.com", + Qtype: dns.TypeA, + Qclass: dns.ClassINET, + }, + }, + }, + // configureRecursor: call not expected. + response: &dns.Msg{ + MsgHdr: dns.MsgHdr{ + Opcode: dns.OpcodeQuery, + Response: true, + Authoritative: false, + Rcode: dns.RcodeServerFailure, + RecursionAvailable: true, + }, + Compress: true, + Question: []dns.Question{ + { + Name: "google.com.", + Qtype: dns.TypeA, + Qclass: dns.ClassINET, + }, + }, + }, + }, + { + name: "recursors configured, matching domain", + request: &dns.Msg{ + MsgHdr: dns.MsgHdr{ + Opcode: dns.OpcodeQuery, + }, + Question: []dns.Question{ + { + Name: "google.com", + Qtype: dns.TypeA, + Qclass: dns.ClassINET, + }, + }, + }, + agentConfig: &config.RuntimeConfig{ + DNSRecursors: []string{"8.8.8.8"}, + }, + configureRecursor: func(recursor dnsRecursor) { + resp := &dns.Msg{ + MsgHdr: dns.MsgHdr{ + Opcode: dns.OpcodeQuery, + Response: true, + Authoritative: true, + Rcode: dns.RcodeSuccess, + }, + Question: []dns.Question{ + { + Name: "google.com.", + Qtype: dns.TypeA, + Qclass: dns.ClassINET, + }, + }, + Answer: []dns.RR{ + &dns.A{ + Hdr: dns.RR_Header{ + Name: "google.com.", + Rrtype: dns.TypeA, + Class: dns.ClassINET, + }, + A: net.ParseIP("1.2.3.4"), + }, + }, + } + recursor.(*mockDnsRecursor).On("handle", + mock.Anything, mock.Anything, mock.Anything).Return(resp, nil) + }, + response: &dns.Msg{ + MsgHdr: dns.MsgHdr{ + Opcode: dns.OpcodeQuery, + Response: true, + Authoritative: true, + Rcode: dns.RcodeSuccess, + }, + Question: []dns.Question{ + { + Name: "google.com.", + Qtype: dns.TypeA, + Qclass: dns.ClassINET, + }, + }, + Answer: []dns.RR{ + &dns.A{ + Hdr: dns.RR_Header{ + Name: "google.com.", + Rrtype: dns.TypeA, + Class: dns.ClassINET, + }, + A: net.ParseIP("1.2.3.4"), + }, + }, + }, + }, + { + name: "recursors configured, matching domain", + request: &dns.Msg{ + MsgHdr: dns.MsgHdr{ + Opcode: dns.OpcodeQuery, + }, + Question: []dns.Question{ + { + Name: "google.com", + Qtype: dns.TypeA, + Qclass: dns.ClassINET, + }, + }, + }, + agentConfig: &config.RuntimeConfig{ + DNSRecursors: []string{"8.8.8.8"}, + }, + configureRecursor: func(recursor dnsRecursor) { + recursor.(*mockDnsRecursor).On("handle", mock.Anything, mock.Anything, mock.Anything). + Return(nil, errRecursionFailed) + }, + response: &dns.Msg{ + MsgHdr: dns.MsgHdr{ + Opcode: dns.OpcodeQuery, + Response: true, + Authoritative: false, + Rcode: dns.RcodeServerFailure, + RecursionAvailable: true, + }, + Compress: true, + Question: []dns.Question{ + { + Name: "google.com.", + Qtype: dns.TypeA, + Qclass: dns.ClassINET, + }, + }, + }, + }, + { + name: "recursors configured, unhandled error calling recursors", + request: &dns.Msg{ + MsgHdr: dns.MsgHdr{ + Opcode: dns.OpcodeQuery, + }, + Question: []dns.Question{ + { + Name: "google.com", + Qtype: dns.TypeA, + Qclass: dns.ClassINET, + }, + }, + }, + agentConfig: &config.RuntimeConfig{ + DNSRecursors: []string{"8.8.8.8"}, + }, + configureRecursor: func(recursor dnsRecursor) { + err := errors.New("ahhhhh!!!!") + recursor.(*mockDnsRecursor).On("handle", mock.Anything, mock.Anything, mock.Anything). + Return(nil, err) + }, + response: &dns.Msg{ + MsgHdr: dns.MsgHdr{ + Opcode: dns.OpcodeQuery, + Response: true, + Authoritative: false, + Rcode: dns.RcodeServerFailure, + RecursionAvailable: true, + }, + Compress: true, + Question: []dns.Question{ + { + Name: "google.com.", + Qtype: dns.TypeA, + Qclass: dns.ClassINET, + }, + }, + }, + }, // addr queries { name: "test A 'addr.' query, ipv4 response", @@ -534,6 +721,12 @@ func Test_HandleRequest(t *testing.T) { router, err := NewRouter(cfg) require.NoError(t, err) + // Replace the recursor with a mock and configure + router.recursor = newMockDnsRecursor(t) + if tc.configureRecursor != nil { + tc.configureRecursor(router.recursor) + } + ctx := tc.requestContext if ctx == nil { ctx = &discovery.Context{} diff --git a/agent/dns_service_lookup_test.go b/agent/dns_service_lookup_test.go index a40d00fa12..9d0683401e 100644 --- a/agent/dns_service_lookup_test.go +++ b/agent/dns_service_lookup_test.go @@ -3228,6 +3228,7 @@ func checkDNSService( udp_answer_limit = `+fmt.Sprintf("%d", aRecordLimit)+` } `+experimentsHCL) + defer a.Shutdown() testrpc.WaitForTestAgent(t, a.RPC, "dc1") choices := perfectlyRandomChoices(generateNumNodes, pctNodesWithIPv6) diff --git a/agent/dns_test.go b/agent/dns_test.go index 8f584172cd..bfd45c70ac 100644 --- a/agent/dns_test.go +++ b/agent/dns_test.go @@ -33,7 +33,6 @@ import ( "github.com/hashicorp/consul/agent/config" "github.com/hashicorp/consul/agent/consul" "github.com/hashicorp/consul/agent/structs" - libdns "github.com/hashicorp/consul/internal/dnsutil" "github.com/hashicorp/consul/lib" "github.com/hashicorp/consul/sdk/testutil/retry" "github.com/hashicorp/consul/testrpc" @@ -128,6 +127,7 @@ func getVersionHCL(enableV2 bool) map[string]string { return versions } +// Copied to agent/dns/recursor_test.go func TestRecursorAddr(t *testing.T) { t.Parallel() addr, err := recursorAddr("8.8.8.8") @@ -230,7 +230,7 @@ func TestDNS_EmptyAltDomain(t *testing.T) { } t.Parallel() - for name, experimentsHCL := range getVersionHCL(false) { + for name, experimentsHCL := range getVersionHCL(true) { t.Run(name, func(t *testing.T) { a := NewTestAgent(t, experimentsHCL) defer a.Shutdown() @@ -266,7 +266,7 @@ func TestDNSCycleRecursorCheck(t *testing.T) { }, }) defer server2.Shutdown() - for name, experimentsHCL := range getVersionHCL(false) { + for name, experimentsHCL := range getVersionHCL(true) { t.Run(name, func(t *testing.T) { // Mock the agent startup with the necessary configs agent := NewTestAgent(t, @@ -308,7 +308,7 @@ func TestDNSCycleRecursorCheckAllFail(t *testing.T) { MsgHdr: dns.MsgHdr{Rcode: dns.RcodeRefused}, }) defer server3.Shutdown() - for name, experimentsHCL := range getVersionHCL(false) { + for name, experimentsHCL := range getVersionHCL(true) { t.Run(name, func(t *testing.T) { // Mock the agent startup with the necessary configs agent := NewTestAgent(t, @@ -1491,7 +1491,7 @@ func TestDNS_Recurse(t *testing.T) { }) defer recursor.Shutdown() - for name, experimentsHCL := range getVersionHCL(false) { + for name, experimentsHCL := range getVersionHCL(true) { t.Run(name, func(t *testing.T) { a := NewTestAgent(t, ` recursors = ["`+recursor.Addr+`"] @@ -1531,7 +1531,7 @@ func TestDNS_Recurse_Truncation(t *testing.T) { }) defer recursor.Shutdown() - for name, experimentsHCL := range getVersionHCL(false) { + for name, experimentsHCL := range getVersionHCL(true) { t.Run(name, func(t *testing.T) { a := NewTestAgent(t, ` recursors = ["`+recursor.Addr+`"] @@ -1580,7 +1580,7 @@ func TestDNS_RecursorTimeout(t *testing.T) { } defer resolver.Close() - for name, experimentsHCL := range getVersionHCL(false) { + for name, experimentsHCL := range getVersionHCL(true) { t.Run(name, func(t *testing.T) { a := NewTestAgent(t, ` recursors = ["`+resolver.LocalAddr().String()+`"] // host must cause a connection|read|write timeout @@ -3497,7 +3497,7 @@ func TestDNS_Compression_Recurse(t *testing.T) { }) defer recursor.Shutdown() - for name, experimentsHCL := range getVersionHCL(false) { + for name, experimentsHCL := range getVersionHCL(true) { t.Run(name, func(t *testing.T) { a := NewTestAgent(t, ` @@ -3543,29 +3543,6 @@ func TestDNS_Compression_Recurse(t *testing.T) { } } -func TestDNSInvalidRegex(t *testing.T) { - tests := []struct { - desc string - in string - invalid bool - }{ - {"Valid Hostname", "testnode", false}, - {"Valid Hostname", "test-node", false}, - {"Invalid Hostname with special chars", "test#$$!node", true}, - {"Invalid Hostname with special chars in the end", "testnode%^", true}, - {"Whitespace", " ", true}, - {"Only special chars", "./$", true}, - } - for _, test := range tests { - t.Run(test.desc, func(t *testing.T) { - if got, want := libdns.InvalidNameRe.MatchString(test.in), test.invalid; got != want { - t.Fatalf("Expected %v to return %v", test.in, want) - } - }) - - } -} - func TestDNS_V1ConfigReload(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") diff --git a/internal/dnsutil/dns_test.go b/internal/dnsutil/dns_test.go index b4010ef316..608832a69c 100644 --- a/internal/dnsutil/dns_test.go +++ b/internal/dnsutil/dns_test.go @@ -50,3 +50,26 @@ func TestValidLabel(t *testing.T) { }) } } + +func TestDNSInvalidRegex(t *testing.T) { + tests := []struct { + desc string + in string + invalid bool + }{ + {"Valid Hostname", "testnode", false}, + {"Valid Hostname", "test-node", false}, + {"Invalid Hostname with special chars", "test#$$!node", true}, + {"Invalid Hostname with special chars in the end", "testnode%^", true}, + {"Whitespace", " ", true}, + {"Only special chars", "./$", true}, + } + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + if got, want := InvalidNameRe.MatchString(test.in), test.invalid; got != want { + t.Fatalf("Expected %v to return %v", test.in, want) + } + }) + + } +}