From e309d51a5b1231fa97919831e691a5353b8456f5 Mon Sep 17 00:00:00 2001 From: hc-github-team-consul-core Date: Thu, 8 Feb 2024 00:20:09 -0500 Subject: [PATCH] Backport of DNS v2 Multiple fixes. into release/1.18.x (#20530) * no-op commit due to failed cherry-picking * DNS v2 Multiple fixes. (#20525) * DNS v2 Multiple fixes. * add license header * get rid of DefaultIntentionPolicy change that was not supposed to be there. --------- Co-authored-by: temp Co-authored-by: John Murret --- agent/config/runtime.go | 2 +- agent/discovery/query_fetcher_v1.go | 26 ++- agent/discovery/query_fetcher_v1_ce.go | 2 +- agent/dns/router.go | 161 ++++++++++----- agent/dns/router_ce.go | 5 + agent/dns/router_query.go | 5 +- agent/dns/router_response.go | 259 +++++++++++++++++++++++++ agent/dns/router_test.go | 15 +- agent/dns_node_lookup_test.go | 8 - agent/dns_service_lookup_test.go | 20 +- agent/dns_test.go | 53 ++--- 11 files changed, 438 insertions(+), 118 deletions(-) create mode 100644 agent/dns/router_response.go diff --git a/agent/config/runtime.go b/agent/config/runtime.go index 0ec211d62e..9e4e4ace00 100644 --- a/agent/config/runtime.go +++ b/agent/config/runtime.go @@ -272,7 +272,7 @@ type RuntimeConfig struct { // Records returned in the ANSWER section of a DNS response for UDP // responses without EDNS support (limited to 512 bytes). // This parameter is deprecated, if you want to limit the number of - // records returned by A or AAAA questions, please use DNSARecordLimit + // records returned by A or AAAA questions, please use TestDNS_ServiceLookup_Randomize // instead. // // hcl: dns_config { udp_answer_limit = int } diff --git a/agent/discovery/query_fetcher_v1.go b/agent/discovery/query_fetcher_v1.go index c3146a48ac..f588dc2662 100644 --- a/agent/discovery/query_fetcher_v1.go +++ b/agent/discovery/query_fetcher_v1.go @@ -115,7 +115,7 @@ func (f *V1DataFetcher) FetchNodes(ctx Context, req *QueryPayload) ([]*Result, e // If we have no out.NodeServices.Nodeaddress, return not found! if out.NodeServices == nil { - return nil, errors.New("no nodes found") + return nil, ErrNotFound } results := make([]*Result, 0, 1) @@ -302,7 +302,11 @@ func (f *V1DataFetcher) FetchPreparedQuery(ctx Context, req *QueryPayload) ([]*R out, err := f.executePreparedQuery(cfg, args) if err != nil { - return nil, err + // errors.Is() doesn't work with errors.New() so we need to check the error message. + if err.Error() == structs.ErrQueryNotFound.Error() { + err = ErrNotFound + } + return nil, ECSNotGlobalError{err} } // (v2-dns) TODO: (v2-dns) get TTLS working. They come from the database so not having @@ -337,12 +341,12 @@ func (f *V1DataFetcher) FetchPreparedQuery(ctx Context, req *QueryPayload) ([]*R // If we have no nodes, return not found! if len(out.Nodes) == 0 { - return nil, ErrNoData + return nil, ECSNotGlobalError{ErrNoData} } // Perform a random shuffle out.Nodes.Shuffle() - return f.buildResultsFromServiceNodes(out.Nodes), nil + return f.buildResultsFromServiceNodes(out.Nodes, req), ECSNotGlobalError{} } // executePreparedQuery is used to execute a PreparedQuery against the Consul catalog. @@ -399,10 +403,16 @@ func (f *V1DataFetcher) ValidateRequest(_ Context, req *QueryPayload) error { } // buildResultsFromServiceNodes builds a list of results from a list of nodes. -func (f *V1DataFetcher) buildResultsFromServiceNodes(nodes []structs.CheckServiceNode) []*Result { - results := make([]*Result, 0) - for _, n := range nodes { +func (f *V1DataFetcher) buildResultsFromServiceNodes(nodes []structs.CheckServiceNode, req *QueryPayload) []*Result { + // Convert the service endpoints to results up to the limit + limit := req.Limit + if len(nodes) < limit || limit == 0 { + limit = len(nodes) + } + results := make([]*Result, 0, limit) + for idx := 0; idx < limit; idx++ { + n := nodes[idx] results = append(results, &Result{ Service: &Location{ Name: n.Service.Service, @@ -534,7 +544,7 @@ func (f *V1DataFetcher) fetchServiceBasedOnTenancy(ctx Context, req *QueryPayloa // Perform a random shuffle out.Nodes.Shuffle() - return f.buildResultsFromServiceNodes(out.Nodes), nil + return f.buildResultsFromServiceNodes(out.Nodes, req), nil } // findWeight returns the weight of a service node. diff --git a/agent/discovery/query_fetcher_v1_ce.go b/agent/discovery/query_fetcher_v1_ce.go index 2bb2a774dd..03be837bfa 100644 --- a/agent/discovery/query_fetcher_v1_ce.go +++ b/agent/discovery/query_fetcher_v1_ce.go @@ -18,7 +18,7 @@ func (f *V1DataFetcher) NormalizeRequest(req *QueryPayload) { } func validateEnterpriseTenancy(req QueryTenancy) error { - if req.Namespace != "" || req.Partition != "" { + if req.Namespace != "" || req.Partition != acl.DefaultPartitionName { return ErrNotSupported } return nil diff --git a/agent/dns/router.go b/agent/dns/router.go index 405af9ed2a..0f9de46e29 100644 --- a/agent/dns/router.go +++ b/agent/dns/router.go @@ -7,6 +7,7 @@ import ( "encoding/hex" "errors" "fmt" + "github.com/hashicorp/consul/acl" "net" "regexp" "strings" @@ -22,7 +23,6 @@ import ( "github.com/hashicorp/consul/agent/discovery" "github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/internal/dnsutil" - "github.com/hashicorp/consul/internal/resource" "github.com/hashicorp/consul/logging" ) @@ -42,7 +42,6 @@ var ( errInvalidQuestion = fmt.Errorf("invalid question") errNameNotFound = fmt.Errorf("name not found") errNotImplemented = fmt.Errorf("not implemented") - errQueryNotFound = fmt.Errorf("query not found") errRecursionFailed = fmt.Errorf("recursion failed") trailingSpacesRE = regexp.MustCompile(" +$") @@ -147,6 +146,14 @@ func (r *Router) HandleRequest(req *dns.Msg, reqCtx Context, remoteAddress net.A return r.handleRequestRecursively(req, reqCtx, remoteAddress, maxRecursionLevelDefault) } +// getErrorFromECSNotGlobalError returns the underlying error from an ECSNotGlobalError, if it exists. +func getErrorFromECSNotGlobalError(err error) error { + if errors.Is(err, discovery.ErrECSNotGlobal) { + return err.(discovery.ECSNotGlobalError).Unwrap() + } + return err +} + // handleRequestRecursively is used to process an individual DNS request. It will recurse as needed // a maximum number of times and returns a message in success or fail cases. func (r *Router) handleRequestRecursively(req *dns.Msg, reqCtx Context, @@ -190,35 +197,47 @@ func (r *Router) handleRequestRecursively(req *dns.Msg, reqCtx Context, reqType := parseRequestType(req) results, query, err := r.getQueryResults(req, reqCtx, reqType, qName, remoteAddress) - switch { - case errors.Is(err, errNameNotFound): - r.logger.Error("name not found", "name", qName) - - ecsGlobal := !errors.Is(err, discovery.ErrECSNotGlobal) - return createAuthoritativeResponse(req, configCtx, responseDomain, dns.RcodeNameError, ecsGlobal) - case errors.Is(err, errNotImplemented): - r.logger.Error("query not implemented", "name", qName, "type", dns.Type(req.Question[0].Qtype).String()) - ecsGlobal := !errors.Is(err, discovery.ErrECSNotGlobal) - return createAuthoritativeResponse(req, configCtx, responseDomain, dns.RcodeNotImplemented, ecsGlobal) - case errors.Is(err, discovery.ErrNotSupported): - r.logger.Debug("query name syntax not supported", "name", req.Question[0].Name) + // incase of the wrapped ECSNotGlobalError, extract the error from it. + isECSGlobal := !errors.Is(err, discovery.ErrECSNotGlobal) + err = getErrorFromECSNotGlobalError(err) - ecsGlobal := !errors.Is(err, discovery.ErrECSNotGlobal) - return createAuthoritativeResponse(req, configCtx, responseDomain, dns.RcodeNameError, ecsGlobal) - case errors.Is(err, discovery.ErrNotFound): - r.logger.Debug("query name not found", "name", req.Question[0].Name) - - ecsGlobal := !errors.Is(err, discovery.ErrECSNotGlobal) - return createAuthoritativeResponse(req, configCtx, responseDomain, dns.RcodeNameError, ecsGlobal) - case errors.Is(err, discovery.ErrNoData): - r.logger.Debug("no data available", "name", qName) - - ecsGlobal := !errors.Is(err, discovery.ErrECSNotGlobal) - return createAuthoritativeResponse(req, configCtx, responseDomain, dns.RcodeSuccess, ecsGlobal) - case err != nil: - r.logger.Error("error processing discovery query", "error", err) - return createServerFailureResponse(req, configCtx, canRecurse(configCtx)) + if err != nil { + switch { + case errors.Is(err, errInvalidQuestion): + r.logger.Error("invalid question", "name", qName) + + ecsGlobal := !errors.Is(err, discovery.ErrECSNotGlobal) + return createAuthoritativeResponse(req, configCtx, responseDomain, dns.RcodeNameError, ecsGlobal) + case errors.Is(err, errNameNotFound): + r.logger.Error("name not found", "name", qName) + + ecsGlobal := !errors.Is(err, discovery.ErrECSNotGlobal) + return createAuthoritativeResponse(req, configCtx, responseDomain, dns.RcodeNameError, ecsGlobal) + case errors.Is(err, errNotImplemented): + r.logger.Error("query not implemented", "name", qName, "type", dns.Type(req.Question[0].Qtype).String()) + + ecsGlobal := !errors.Is(err, discovery.ErrECSNotGlobal) + return createAuthoritativeResponse(req, configCtx, responseDomain, dns.RcodeNotImplemented, ecsGlobal) + case errors.Is(err, discovery.ErrNotSupported): + r.logger.Debug("query name syntax not supported", "name", req.Question[0].Name) + + ecsGlobal := !errors.Is(err, discovery.ErrECSNotGlobal) + return createAuthoritativeResponse(req, configCtx, responseDomain, dns.RcodeNameError, ecsGlobal) + case errors.Is(err, discovery.ErrNotFound): + r.logger.Debug("query name not found", "name", req.Question[0].Name) + + ecsGlobal := !errors.Is(err, discovery.ErrECSNotGlobal) + return createAuthoritativeResponse(req, configCtx, responseDomain, dns.RcodeNameError, ecsGlobal) + case errors.Is(err, discovery.ErrNoData): + r.logger.Debug("no data available", "name", qName) + + ecsGlobal := !errors.Is(err, discovery.ErrECSNotGlobal) + return createAuthoritativeResponse(req, configCtx, responseDomain, dns.RcodeSuccess, ecsGlobal) + default: + r.logger.Error("error processing discovery query", "error", err) + return createServerFailureResponse(req, configCtx, canRecurse(configCtx)) + } } // This needs the question information because it affects the serialization format. @@ -228,6 +247,16 @@ func (r *Router) handleRequestRecursively(req *dns.Msg, reqCtx Context, r.logger.Error("error serializing DNS results", "error", err) return createServerFailureResponse(req, configCtx, false) } + + // Switch to TCP if the client is + network := "udp" + if _, ok := remoteAddress.(*net.TCPAddr); ok { + network = "tcp" + } + + trimDNSResponse(configCtx, network, req, resp, r.logger) + + setEDNS(req, resp, isECSGlobal) return resp } @@ -289,7 +318,7 @@ func (r *Router) getQueryResults(req *dns.Msg, reqCtx Context, reqType requestTy // We don't want the query processors default partition to be used. // This is a small hack because for V1 CE, this is not the correct default partition name, but we // need to add something to disambiguate the empty field. - Partition: resource.DefaultPartitionName, + Partition: acl.DefaultPartitionName, //NOTE: note this won't work if we ever have V2 client agents }, Limit: 3, }, @@ -304,18 +333,12 @@ func (r *Router) getQueryResults(req *dns.Msg, reqCtx Context, reqType requestTy return nil, query, err } results, err := r.processor.QueryByName(query, discovery.Context{Token: reqCtx.Token}) - if err != nil { - r.logger.Error("error processing discovery query", "error", err) - switch err.Error() { - case errNameNotFound.Error(): - return nil, query, errNameNotFound - case errQueryNotFound.Error(): - return nil, query, errQueryNotFound - } + if getErrorFromECSNotGlobalError(err) != nil { + r.logger.Error("error processing discovery query", "error", err) return nil, query, err } - return results, query, nil + return results, query, err case requestTypeIP: ip := dnsutil.IPFromARPA(qName) if ip == nil { @@ -332,7 +355,9 @@ func (r *Router) getQueryResults(req *dns.Msg, reqCtx Context, reqType requestTy } return results, nil, nil } - return nil, nil, errors.New("invalid request type") + + r.logger.Error("error parsing discovery query type", "requestType", reqType) + return nil, nil, errInvalidQuestion } // ServeDNS implements the miekg/dns.Handler interface. @@ -452,8 +477,30 @@ func (r *Router) serializeQueryResults(req *dns.Msg, reqCtx Context, resp.Extra = append(resp.Extra, ex...) resp.Ns = append(resp.Ns, ns...) } - case qType == dns.TypeSRV, reqType == requestTypeAddress: + case reqType == requestTypeAddress: + for _, result := range results { + ans, ex, ns := r.getAnswerExtraAndNs(result, req, reqCtx, query, cfg, responseDomain, remoteAddress, maxRecursionLevel) + resp.Answer = append(resp.Answer, ans...) + resp.Extra = append(resp.Extra, ex...) + resp.Ns = append(resp.Ns, ns...) + } + case qType == dns.TypeSRV: + handled := make(map[string]struct{}) for _, result := range results { + // Avoid duplicate entries, possible if a node has + // the same service the same port, etc. + + // The datacenter should be empty during translation if it is a peering lookup. + // This should be fine because we should always prefer the WAN address. + //serviceAddress := d.agent.TranslateServiceAddress(lookup.Datacenter, node.Service.Address, node.Service.TaggedAddresses, TranslateAddressAcceptAny) + //servicePort := d.agent.TranslateServicePort(lookup.Datacenter, node.Service.Port, node.Service.TaggedAddresses) + //tuple := fmt.Sprintf("%s:%s:%d", node.Node.Node, serviceAddress, servicePort) + + tuple := fmt.Sprintf("%s:%s:%d", result.Node.Name, result.Service.Address, result.PortNumber) + if _, ok := handled[tuple]; ok { + continue + } + handled[tuple] = struct{}{} ans, ex, ns := r.getAnswerExtraAndNs(result, req, reqCtx, query, cfg, responseDomain, remoteAddress, maxRecursionLevel) resp.Answer = append(resp.Answer, ans...) resp.Extra = append(resp.Extra, ex...) @@ -695,6 +742,7 @@ func createServerFailureResponse(req *dns.Msg, cfg *RouterDynamicConfig, recursi if edns := req.IsEdns0(); edns != nil { setEDNS(req, m, true) } + return m } @@ -844,7 +892,12 @@ func (r *Router) getAnswerExtraAndNs(result *discovery.Result, req *dns.Msg, req answer = append(answer, ptr) case qType == dns.TypeNS: // TODO (v2-dns): fqdn in V1 has the datacenter included, this would need to be added to discovery.Result - fqdn := canonicalNameForResult(result.Type, result.Node.Name, domain, result.Tenancy, result.PortName) + resultType := result.Type + target := result.Node.Name + if parseRequestType(req) == requestTypeConsul && resultType == discovery.ResultTypeService { + resultType = discovery.ResultTypeNode + } + fqdn := canonicalNameForResult(resultType, target, domain, result.Tenancy, result.PortName) extraRecord := makeIPBasedRecord(fqdn, nodeAddress, ttl) // TODO (v2-dns): this is not sufficient, because recursion and CNAMES are supported answer = append(answer, makeNSRecord(domain, fqdn, ttl)) @@ -871,7 +924,7 @@ func (r *Router) getAnswerExtraAndNs(result *discovery.Result, req *dns.Msg, req extra = append(extra, e...) } - a, e := getAnswerAndExtraTXT(req, cfg, qName, result, ttl, domain) + a, e := getAnswerAndExtraTXT(req, cfg, qName, result, ttl, domain, query) answer = append(answer, a...) extra = append(extra, e...) return @@ -954,7 +1007,10 @@ func (r *Router) getAnswerExtrasForAddressAndTarget(nodeAddress *dnsAddress, ser // getAnswerAndExtraTXT determines whether a TXT needs to be create and then // returns the TXT record in the answer or extra depending on the question type. func getAnswerAndExtraTXT(req *dns.Msg, cfg *RouterDynamicConfig, qName string, - result *discovery.Result, ttl uint32, domain string) (answer []dns.RR, extra []dns.RR) { + result *discovery.Result, ttl uint32, domain string, query *discovery.Query) (answer []dns.RR, extra []dns.RR) { + if !shouldAppendTXTRecord(query, cfg, req) { + return + } recordHeaderName := qName serviceAddress := newDNSAddress("") if result.Service != nil { @@ -989,6 +1045,23 @@ func getAnswerAndExtraTXT(req *dns.Msg, cfg *RouterDynamicConfig, qName string, return answer, extra } +// shouldAppendTXTRecord determines whether a TXT record should be appended to the response. +func shouldAppendTXTRecord(query *discovery.Query, cfg *RouterDynamicConfig, req *dns.Msg) bool { + qType := req.Question[0].Qtype + switch { + // Node records + case query != nil && query.QueryType == discovery.QueryTypeNode && (cfg.NodeMetaTXT || qType == dns.TypeANY || qType == dns.TypeTXT): + return true + // Service records + case query != nil && query.QueryType == discovery.QueryTypeService && cfg.NodeMetaTXT && qType == dns.TypeSRV: + return true + // Prepared query records + case query != nil && query.QueryType == discovery.QueryTypePreparedQuery && cfg.NodeMetaTXT && qType == dns.TypeSRV: + return true + } + return false +} + // getAnswerExtrasForIP creates the dns answer and extra from IP dnsAddress pairs. func getAnswerExtrasForIP(name string, addr *dnsAddress, question dns.Question, reqType requestType, result *discovery.Result, ttl uint32, _ string) (answer []dns.RR, extra []dns.RR) { diff --git a/agent/dns/router_ce.go b/agent/dns/router_ce.go index 67cab00490..9380221e8b 100644 --- a/agent/dns/router_ce.go +++ b/agent/dns/router_ce.go @@ -36,3 +36,8 @@ func canonicalNameForResult(resultType discovery.ResultType, target, domain stri } return "" } + +// getDefaultPartitionName returns the default partition name. +func getDefaultPartitionName() string { + return "" +} diff --git a/agent/dns/router_query.go b/agent/dns/router_query.go index 6576b8724b..420bba1d09 100644 --- a/agent/dns/router_query.go +++ b/agent/dns/router_query.go @@ -26,9 +26,12 @@ func buildQueryFromDNSMessage(req *dns.Msg, reqCtx Context, domain, altDomain st portName := parsePort(queryParts) - if queryType == discovery.QueryTypeWorkload && req.Question[0].Qtype == dns.TypeSRV { + switch { + case queryType == discovery.QueryTypeWorkload && req.Question[0].Qtype == dns.TypeSRV: // Currently we do not support SRV records for workloads return nil, errNotImplemented + case queryType == discovery.QueryTypeInvalid, name == "": + return nil, errInvalidQuestion } return &discovery.Query{ diff --git a/agent/dns/router_response.go b/agent/dns/router_response.go new file mode 100644 index 0000000000..d2000745c8 --- /dev/null +++ b/agent/dns/router_response.go @@ -0,0 +1,259 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 +package dns + +import ( + "fmt" + "github.com/hashicorp/consul/lib" + "github.com/hashicorp/go-hclog" + "github.com/miekg/dns" + "math" + "strings" +) + +const ( + // UDP can fit ~25 A records in a 512B response, and ~14 AAAA + // records. Limit further to prevent unintentional configuration + // abuse that would have a negative effect on application response + // times. + maxUDPAnswerLimit = 8 + + defaultMaxUDPSize = 512 + + // If a consumer sets a buffer size greater than this amount we will default it down + // to this amount to ensure that consul does respond. Previously if consumer had a larger buffer + // size than 65535 - 60 bytes (maximim 60 bytes for IP header. UDP header will be offset in the + // trimUDP call) consul would fail to respond and the consumer timesout + // the request. + maxUDPDatagramSize = math.MaxUint16 - 68 +) + +// trimDNSResponse will trim the response for UDP and TCP +func trimDNSResponse(cfg *RouterDynamicConfig, network string, req, resp *dns.Msg, logger hclog.Logger) { + var trimmed bool + originalSize := resp.Len() + originalNumRecords := len(resp.Answer) + if network != "tcp" { + trimmed = trimUDPResponse(req, resp, cfg.UDPAnswerLimit) + } else { + trimmed = trimTCPResponse(req, resp) + } + // Flag that there are more records to return in the UDP response + if trimmed { + if cfg.EnableTruncate { + resp.Truncated = true + } + logger.Debug("DNS response too large, truncated", + "protocol", network, + "question", req.Question, + "records", fmt.Sprintf("%d/%d", len(resp.Answer), originalNumRecords), + "size", fmt.Sprintf("%d/%d", resp.Len(), originalSize), + ) + } +} + +// trimTCPResponse limit the MaximumSize of messages to 64k as it is the limit +// of DNS responses +func trimTCPResponse(req, resp *dns.Msg) (trimmed bool) { + hasExtra := len(resp.Extra) > 0 + // There is some overhead, 65535 does not work + maxSize := 65523 // 64k - 12 bytes DNS raw overhead + + // We avoid some function calls and allocations by only handling the + // extra data when necessary. + var index map[string]dns.RR + + // It is not possible to return more than 4k records even with compression + // Since we are performing binary search it is not a big deal, but it + // improves a bit performance, even with binary search + truncateAt := 4096 + if req.Question[0].Qtype == dns.TypeSRV { + // More than 1024 SRV records do not fit in 64k + truncateAt = 1024 + } + if len(resp.Answer) > truncateAt { + resp.Answer = resp.Answer[:truncateAt] + } + if hasExtra { + index = make(map[string]dns.RR, len(resp.Extra)) + indexRRs(resp.Extra, index) + } + truncated := false + + // This enforces the given limit on 64k, the max limit for DNS messages + for len(resp.Answer) > 1 && resp.Len() > maxSize { + truncated = true + // first try to remove the NS section may be it will truncate enough + if len(resp.Ns) != 0 { + resp.Ns = []dns.RR{} + } + // More than 100 bytes, find with a binary search + if resp.Len()-maxSize > 100 { + bestIndex := dnsBinaryTruncate(resp, maxSize, index, hasExtra) + resp.Answer = resp.Answer[:bestIndex] + } else { + resp.Answer = resp.Answer[:len(resp.Answer)-1] + } + if hasExtra { + syncExtra(index, resp) + } + } + + return truncated +} + +// trimUDPResponse makes sure a UDP response is not longer than allowed by RFC +// 1035. Enforce an arbitrary limit that can be further ratcheted down by +// config, and then make sure the response doesn't exceed 512 bytes. Any extra +// records will be trimmed along with answers. +func trimUDPResponse(req, resp *dns.Msg, udpAnswerLimit int) (trimmed bool) { + numAnswers := len(resp.Answer) + hasExtra := len(resp.Extra) > 0 + maxSize := defaultMaxUDPSize + + // Update to the maximum edns size + if edns := req.IsEdns0(); edns != nil { + if size := edns.UDPSize(); size > uint16(maxSize) { + maxSize = int(size) + } + } + // Overriding maxSize as the maxSize cannot be larger than the + // maxUDPDatagram size. Reliability guarantees disappear > than this amount. + if maxSize > maxUDPDatagramSize { + maxSize = maxUDPDatagramSize + } + + // We avoid some function calls and allocations by only handling the + // extra data when necessary. + var index map[string]dns.RR + if hasExtra { + index = make(map[string]dns.RR, len(resp.Extra)) + indexRRs(resp.Extra, index) + } + + // This cuts UDP responses to a useful but limited number of responses. + maxAnswers := lib.MinInt(maxUDPAnswerLimit, udpAnswerLimit) + compress := resp.Compress + if maxSize == defaultMaxUDPSize && numAnswers > maxAnswers { + // We disable computation of Len ONLY for non-eDNS request (512 bytes) + resp.Compress = false + resp.Answer = resp.Answer[:maxAnswers] + if hasExtra { + syncExtra(index, resp) + } + } + if maxSize == defaultMaxUDPSize && numAnswers > maxAnswers { + // We disable computation of Len ONLY for non-eDNS request (512 bytes) + resp.Compress = false + resp.Answer = resp.Answer[:maxAnswers] + if hasExtra { + syncExtra(index, resp) + } + } + + // This enforces the given limit on the number bytes. The default is 512 as + // per the RFC, but EDNS0 allows for the user to specify larger sizes. Note + // that we temporarily switch to uncompressed so that we limit to a response + // that will not exceed 512 bytes uncompressed, which is more conservative and + // will allow our responses to be compliant even if some downstream server + // uncompresses them. + // Even when size is too big for one single record, try to send it anyway + // (useful for 512 bytes messages). 8 is removed from maxSize to ensure that we account + // for the udp header (8 bytes). + for len(resp.Answer) > 1 && resp.Len() > maxSize-8 { + // first try to remove the NS section may be it will truncate enough + if len(resp.Ns) != 0 { + resp.Ns = []dns.RR{} + } + // More than 100 bytes, find with a binary search + if resp.Len()-maxSize > 100 { + bestIndex := dnsBinaryTruncate(resp, maxSize, index, hasExtra) + resp.Answer = resp.Answer[:bestIndex] + } else { + resp.Answer = resp.Answer[:len(resp.Answer)-1] + } + if hasExtra { + syncExtra(index, resp) + } + } + // For 512 non-eDNS responses, while we compute size non-compressed, + // we send result compressed + resp.Compress = compress + return len(resp.Answer) < numAnswers +} + +// syncExtra takes a DNS response message and sets the extra data to the most +// minimal set needed to cover the answer data. A pre-made index of RRs is given +// so that can be re-used between calls. This assumes that the extra data is +// only used to provide info for SRV records. If that's not the case, then this +// will wipe out any additional data. +func syncExtra(index map[string]dns.RR, resp *dns.Msg) { + extra := make([]dns.RR, 0, len(resp.Answer)) + resolved := make(map[string]struct{}, len(resp.Answer)) + for _, ansRR := range resp.Answer { + srv, ok := ansRR.(*dns.SRV) + if !ok { + continue + } + + // Note that we always use lower case when using the index so + // that compares are not case-sensitive. We don't alter the actual + // RRs we add into the extra section, however. + target := strings.ToLower(srv.Target) + + RESOLVE: + if _, ok := resolved[target]; ok { + continue + } + resolved[target] = struct{}{} + + extraRR, ok := index[target] + if ok { + extra = append(extra, extraRR) + if cname, ok := extraRR.(*dns.CNAME); ok { + target = strings.ToLower(cname.Target) + goto RESOLVE + } + } + } + resp.Extra = extra +} + +// dnsBinaryTruncate find the optimal number of records using a fast binary search and return +// it in order to return a DNS answer lower than maxSize parameter. +func dnsBinaryTruncate(resp *dns.Msg, maxSize int, index map[string]dns.RR, hasExtra bool) int { + originalAnswser := resp.Answer + startIndex := 0 + endIndex := len(resp.Answer) + 1 + for endIndex-startIndex > 1 { + median := startIndex + (endIndex-startIndex)/2 + + resp.Answer = originalAnswser[:median] + if hasExtra { + syncExtra(index, resp) + } + aLen := resp.Len() + if aLen <= maxSize { + if maxSize-aLen < 10 { + // We are good, increasing will go out of bounds + return median + } + startIndex = median + } else { + endIndex = median + } + } + return startIndex +} + +// indexRRs populates a map which indexes a given list of RRs by name. NOTE that +// the names are all squashed to lower case so we can perform case-insensitive +// lookups; the RRs are not modified. +func indexRRs(rrs []dns.RR, index map[string]dns.RR) { + for _, rr := range rrs { + name := strings.ToLower(rr.Header().Name) + if _, ok := index[name]; !ok { + index[name] = rr + } + } +} diff --git a/agent/dns/router_test.go b/agent/dns/router_test.go index 220ae27f38..a14b0cac5c 100644 --- a/agent/dns/router_test.go +++ b/agent/dns/router_test.go @@ -89,7 +89,8 @@ func Test_HandleRequest(t *testing.T) { }, }, agentConfig: &config.RuntimeConfig{ - DNSRecursors: []string{"8.8.8.8"}, + DNSRecursors: []string{"8.8.8.8"}, + DNSUDPAnswerLimit: maxUDPAnswerLimit, }, configureRecursor: func(recursor dnsRecursor) { resp := &dns.Msg{ @@ -161,7 +162,8 @@ func Test_HandleRequest(t *testing.T) { }, }, agentConfig: &config.RuntimeConfig{ - DNSRecursors: []string{"8.8.8.8"}, + DNSRecursors: []string{"8.8.8.8"}, + DNSUDPAnswerLimit: maxUDPAnswerLimit, }, configureRecursor: func(recursor dnsRecursor) { recursor.(*mockDnsRecursor).On("handle", mock.Anything, mock.Anything, mock.Anything). @@ -200,7 +202,8 @@ func Test_HandleRequest(t *testing.T) { }, }, agentConfig: &config.RuntimeConfig{ - DNSRecursors: []string{"8.8.8.8"}, + DNSRecursors: []string{"8.8.8.8"}, + DNSUDPAnswerLimit: maxUDPAnswerLimit, }, configureRecursor: func(recursor dnsRecursor) { err := errors.New("ahhhhh!!!!") @@ -240,7 +243,8 @@ func Test_HandleRequest(t *testing.T) { }, }, agentConfig: &config.RuntimeConfig{ - DNSRecursors: []string{"8.8.8.8"}, + DNSRecursors: []string{"8.8.8.8"}, + DNSUDPAnswerLimit: maxUDPAnswerLimit, }, configureRecursor: func(recursor dnsRecursor) { // this response is modeled after `dig .` @@ -934,6 +938,7 @@ func Test_HandleRequest(t *testing.T) { Expire: 3, Minttl: 4, }, + DNSUDPAnswerLimit: maxUDPAnswerLimit, }, configureDataFetcher: func(fetcher discovery.CatalogDataFetcher) { fetcher.(*discovery.MockCatalogDataFetcher). @@ -1151,6 +1156,7 @@ func Test_HandleRequest(t *testing.T) { Expire: 3, Minttl: 4, }, + DNSUDPAnswerLimit: maxUDPAnswerLimit, }, configureDataFetcher: func(fetcher discovery.CatalogDataFetcher) { fetcher.(*discovery.MockCatalogDataFetcher). @@ -1881,6 +1887,7 @@ func buildDNSConfig(agentConfig *config.RuntimeConfig, cdf discovery.CatalogData Expire: 3, Minttl: 4, }, + DNSUDPAnswerLimit: maxUDPAnswerLimit, }, EntMeta: acl.EnterpriseMeta{}, Logger: hclog.NewNullLogger(), diff --git a/agent/dns_node_lookup_test.go b/agent/dns_node_lookup_test.go index 1b3ab1ffd6..8bc7ccb965 100644 --- a/agent/dns_node_lookup_test.go +++ b/agent/dns_node_lookup_test.go @@ -258,11 +258,6 @@ func TestDNS_NodeLookup_AAAA(t *testing.T) { } } -// TODO (v2-dns): NET-7631 - Implement external CNAME references -// Failing on answer assertion. some CNAMEs are not getting created -// and the record type on the AAAA record is incorrect. -// External services do not appear to be working properly here -// and in the service lookup tests. func TestDNS_NodeLookup_CNAME(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") @@ -581,9 +576,6 @@ func TestDNS_NodeLookup_A_SuppressTXT(t *testing.T) { } } -// TODO (v2-dns): NET-7631 - Implement external CNAME references -// Failing on "Should have the CNAME record + a few A records" comment -// External services do not appear to be working properly here either. func TestDNS_NodeLookup_TTL(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") diff --git a/agent/dns_service_lookup_test.go b/agent/dns_service_lookup_test.go index a15d1ed03f..cfcbdf8133 100644 --- a/agent/dns_service_lookup_test.go +++ b/agent/dns_service_lookup_test.go @@ -461,14 +461,12 @@ func TestDNS_ServiceLookupMultiAddrNoCNAME(t *testing.T) { } } -// TODO (v2-dns): NET-7640 - NS Record not populate on some invalid service / prepared query lookups. func TestDNS_ServiceLookup(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") } - t.Parallel() - for name, experimentsHCL := range getVersionHCL(false) { + for name, experimentsHCL := range getVersionHCL(true) { t.Run(name, func(t *testing.T) { a := NewTestAgent(t, experimentsHCL) defer a.Shutdown() @@ -2274,14 +2272,12 @@ func TestDNS_ServiceLookup_Dedup(t *testing.T) { } } -// TODO (v2-dns): NET-7641 - Service lookups not properly de-duping SRV records func TestDNS_ServiceLookup_Dedup_SRV(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") } - t.Parallel() - for name, experimentsHCL := range getVersionHCL(false) { + for name, experimentsHCL := range getVersionHCL(true) { t.Run(name, func(t *testing.T) { a := NewTestAgent(t, experimentsHCL) defer a.Shutdown() @@ -2853,14 +2849,12 @@ func TestDNS_ServiceLookup_OnlyPassing(t *testing.T) { } } -// TODO (v2-dns): NET-7635 - Fix dns: overflowing header size in tests func TestDNS_ServiceLookup_Randomize(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") } - t.Parallel() - for name, experimentsHCL := range getVersionHCL(false) { + for name, experimentsHCL := range getVersionHCL(true) { t.Run(name, func(t *testing.T) { a := NewTestAgent(t, experimentsHCL) defer a.Shutdown() @@ -2953,14 +2947,12 @@ func TestDNS_ServiceLookup_Randomize(t *testing.T) { } } -// TODO (v2-dns): NET-7635 - Fix dns: overflowing header size in tests func TestDNS_ServiceLookup_Truncate(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") } - t.Parallel() - for name, experimentsHCL := range getVersionHCL(false) { + for name, experimentsHCL := range getVersionHCL(true) { t.Run(name, func(t *testing.T) { a := NewTestAgent(t, ` dns_config { @@ -3412,14 +3404,12 @@ func TestDNS_ServiceLookup_ARecordLimits(t *testing.T) { } } -// TODO (v2-dns): NET-7633 - implement answer limits. func TestDNS_ServiceLookup_AnswerLimits(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") } - t.Parallel() - for name, experimentsHCL := range getVersionHCL(false) { + for name, experimentsHCL := range getVersionHCL(true) { t.Run(name, func(t *testing.T) { // Build a matrix of config parameters (udpAnswerLimit), and the diff --git a/agent/dns_test.go b/agent/dns_test.go index ffa50760a8..f1ecbafc2f 100644 --- a/agent/dns_test.go +++ b/agent/dns_test.go @@ -24,11 +24,12 @@ import ( "testing" "time" - "github.com/hashicorp/serf/coordinate" "github.com/miekg/dns" "github.com/stretchr/testify/require" "golang.org/x/sync/errgroup" + "github.com/hashicorp/serf/coordinate" + "github.com/hashicorp/consul/acl" "github.com/hashicorp/consul/agent/config" "github.com/hashicorp/consul/agent/consul" @@ -128,7 +129,7 @@ func getVersionHCL(enableV2 bool) map[string]string { } // Copied to agent/dns/recursor_test.go -func TestNDS_RecursorAddr(t *testing.T) { +func TestDNS_RecursorAddr(t *testing.T) { addr, err := recursorAddr("8.8.8.8") if err != nil { t.Fatalf("err: %v", err) @@ -328,14 +329,12 @@ func TestDNS_CycleRecursorCheckAllFail(t *testing.T) { } } -// TODO(v2-dns): NET-7643 - Implement EDNS0 records when queried func TestDNS_EDNS0(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") } - t.Parallel() - for name, experimentsHCL := range getVersionHCL(false) { + for name, experimentsHCL := range getVersionHCL(true) { t.Run(name, func(t *testing.T) { a := NewTestAgent(t, experimentsHCL) defer a.Shutdown() @@ -377,14 +376,12 @@ func TestDNS_EDNS0(t *testing.T) { } } -// TODO(v2-dns): NET-7643 - Implement EDNS0 records when queried func TestDNS_EDNS0_ECS(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") } - t.Parallel() - for name, experimentsHCL := range getVersionHCL(false) { + for name, experimentsHCL := range getVersionHCL(true) { t.Run(name, func(t *testing.T) { a := NewTestAgent(t, experimentsHCL) defer a.Shutdown() @@ -623,7 +620,6 @@ func TestDNS_ReverseLookup_IPV6(t *testing.T) { } } -// TODO(v2-dns): NET-7640 - NS Record not populate on some invalid service / prepared query lookups func TestDNS_SOA_Settings(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") @@ -651,7 +647,7 @@ func TestDNS_SOA_Settings(t *testing.T) { require.Equal(t, uint32(retry), soaRec.Retry) require.Equal(t, uint32(ttl), soaRec.Hdr.Ttl) } - for name, experimentsHCL := range getVersionHCL(false) { + for name, experimentsHCL := range getVersionHCL(true) { t.Run(name, func(t *testing.T) { // Default configuration @@ -770,8 +766,7 @@ func TestDNS_InifiniteRecursion(t *testing.T) { } // This test should not create an infinite recursion - t.Parallel() - for name, experimentsHCL := range getVersionHCL(false) { + for name, experimentsHCL := range getVersionHCL(true) { t.Run(name, func(t *testing.T) { a := NewTestAgent(t, ` domain = "CONSUL." @@ -827,16 +822,12 @@ func TestDNS_InifiniteRecursion(t *testing.T) { } } -// TODO: NET-7640 - NS Record not populate on some invalid service / prepared query lookups -// this is actually an I/O timeout so it might not be the same root cause listed in NET-7640 -// but going to cover investigating it there. func TestDNS_NSRecords(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") } - t.Parallel() - for name, experimentsHCL := range getVersionHCL(false) { + for name, experimentsHCL := range getVersionHCL(true) { t.Run(name, func(t *testing.T) { a := NewTestAgent(t, ` domain = "CONSUL." @@ -873,14 +864,12 @@ func TestDNS_NSRecords(t *testing.T) { } } -// TODO: NET-7640 - NS Record not populate on some invalid service / prepared query lookups func TestDNS_AltDomain_NSRecords(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") } - t.Parallel() - for name, experimentsHCL := range getVersionHCL(false) { + for name, experimentsHCL := range getVersionHCL(true) { t.Run(name, func(t *testing.T) { a := NewTestAgent(t, ` @@ -930,14 +919,12 @@ func TestDNS_AltDomain_NSRecords(t *testing.T) { } } -// TODO: NET-7640 - NS Record not populate on some invalid service / prepared query lookups func TestDNS_NSRecords_IPV6(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") } - t.Parallel() - for name, experimentsHCL := range getVersionHCL(false) { + for name, experimentsHCL := range getVersionHCL(true) { t.Run(name, func(t *testing.T) { a := NewTestAgent(t, ` domain = "CONSUL." @@ -975,14 +962,12 @@ func TestDNS_NSRecords_IPV6(t *testing.T) { } } -// TODO: NET-7640 - NS Record not populate on some invalid service / prepared query lookups func TestDNS_AltDomain_NSRecords_IPV6(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") } - t.Parallel() - for name, experimentsHCL := range getVersionHCL(false) { + for name, experimentsHCL := range getVersionHCL(true) { t.Run(name, func(t *testing.T) { a := NewTestAgent(t, ` domain = "CONSUL." @@ -1630,7 +1615,9 @@ func TestDNS_RecursorTimeout(t *testing.T) { } } -// TODO(v2-dns): NET-7646 - account for this functionality since there is +// TODO(v2-dns): NET-7646 - account for this functionality in v1 since there is +// no way to run a v2 version of this test since it is calling a private function and not +// using a test agent. func TestDNS_BinarySearch(t *testing.T) { msgSrc := new(dns.Msg) msgSrc.Compress = true @@ -1671,14 +1658,12 @@ func TestDNS_BinarySearch(t *testing.T) { } } -// TODO(v2-dns): NET-7635 - Fix dns: overflowing header size or IO timeouts func TestDNS_TCP_and_UDP_Truncate(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") } - t.Parallel() - for name, experimentsHCL := range getVersionHCL(false) { + for name, experimentsHCL := range getVersionHCL(true) { t.Run(name, func(t *testing.T) { a := NewTestAgent(t, ` dns_config { @@ -2336,14 +2321,12 @@ func TestDNS_AltDomains_Service(t *testing.T) { } } -// TODO(v2-dns): NET-7640 - NS or SOA Records not populate on some invalid service / prepared query lookups func TestDNS_AltDomains_SOA(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") } - t.Parallel() - for name, experimentsHCL := range getVersionHCL(false) { + for name, experimentsHCL := range getVersionHCL(true) { t.Run(name, func(t *testing.T) { a := NewTestAgent(t, ` node_name = "test-node" @@ -2539,14 +2522,12 @@ func TestDNS_PreparedQuery_AllowStale(t *testing.T) { } } -// TODO (v2-dns): NET-7640 - NS or SOA Records not populate on some invalid service / prepared query lookups func TestDNS_InvalidQueries(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") } - t.Parallel() - for name, experimentsHCL := range getVersionHCL(false) { + for name, experimentsHCL := range getVersionHCL(true) { t.Run(name, func(t *testing.T) { a := NewTestAgent(t, experimentsHCL) defer a.Shutdown()