From be39fb20cc51caf0d739fb5514254c979231fc3e Mon Sep 17 00:00:00 2001
From: Pierre Souchay
Date: Wed, 7 Mar 2018 10:01:12 +0100
Subject: [PATCH] [BUGFIX] do not break when TCP DNS answer exceeds 64k
This avoids breaking service discovery when a service has a large number
of instances (works with SRV and A* records).
Fixes https://github.com/hashicorp/consul/issues/3850
---
agent/dns.go | 60 ++++++++++++++++++++++----------
agent/dns_test.go | 88 +++++++++++++++++++++++++++++++++++++++++++++++
2 files changed, 130 insertions(+), 18 deletions(-)
diff --git a/agent/dns.go b/agent/dns.go
index b809a2b3a3..5011fbbbf1 100644
--- a/agent/dns.go
+++ b/agent/dns.go
@@ -713,6 +713,32 @@ func syncExtra(index map[string]dns.RR, resp *dns.Msg) {
resp.Extra = extra
}
+// trimTCPResponse drops trailing answers from resp until resp.Len() is at most
+// 65535 (the 64k DNS message cap) and reports whether any were dropped. req is unused.
+func trimTCPResponse(req, resp *dns.Msg) (trimmed bool) {
+ hasExtra := len(resp.Extra) > 0
+ maxSize := 65535
+
+ // Only index the Extra records when there are some, which avoids needless
+ // function calls and allocations for responses without extra data.
+ var index map[string]dns.RR
+ if hasExtra {
+ index = make(map[string]dns.RR, len(resp.Extra))
+ indexRRs(resp.Extra, index)
+ }
+ truncated := false
+
+ // Drop the last answer, one at a time, until the message fits in maxSize or no answers remain.
+ for len(resp.Answer) > 0 && resp.Len() > maxSize {
+ truncated = true
+ resp.Answer = resp.Answer[:len(resp.Answer)-1]
+ if hasExtra {
+ syncExtra(index, resp) // NOTE(review): presumably prunes Extra to match the remaining answers — confirm
+ }
+ }
+ return truncated
+}
+
// trimUDPResponse makes sure a UDP response is not longer than allowed by RFC
// 1035. Enforce an arbitrary limit that can be further ratcheted down by
// config, and then make sure the response doesn't exceed 512 bytes. Any extra
@@ -765,6 +791,20 @@ func trimUDPResponse(req, resp *dns.Msg, udpAnswerLimit int) (trimmed bool) {
return len(resp.Answer) < numAnswers
}
+// trimDNSResponse trims resp with trimTCPResponse when network is "tcp" and with trimUDPResponse otherwise.
+func (d *DNSServer) trimDNSResponse(network string, req, resp *dns.Msg) (trimmed bool) {
+ if network != "tcp" {
+ trimmed = trimUDPResponse(req, resp, d.config.UDPAnswerLimit)
+ } else {
+ trimmed = trimTCPResponse(req, resp)
+ }
+ // Flag that there are more records to return, for UDP and TCP alike, when truncation is enabled
+ if trimmed && d.config.EnableTruncate {
+ resp.Truncated = true
+ }
+ return trimmed
+}
+
// lookupServiceNodes returns nodes with a given service.
func (d *DNSServer) lookupServiceNodes(datacenter, service, tag string) (structs.IndexedCheckServiceNodes, error) {
args := structs.ServiceSpecificRequest{
@@ -840,15 +880,7 @@ func (d *DNSServer) serviceLookup(network, datacenter, service, tag string, req,
d.serviceNodeRecords(datacenter, out.Nodes, req, resp, ttl)
}
- // If the network is not TCP, restrict the number of responses
- if network != "tcp" {
- wasTrimmed := trimUDPResponse(req, resp, d.config.UDPAnswerLimit)
-
- // Flag that there are more records to return in the UDP response
- if wasTrimmed && d.config.EnableTruncate {
- resp.Truncated = true
- }
- }
+ d.trimDNSResponse(network, req, resp)
// If the answer is empty and the response isn't truncated, return not found
if len(resp.Answer) == 0 && !resp.Truncated {
@@ -950,15 +982,7 @@ RPC:
d.serviceNodeRecords(out.Datacenter, out.Nodes, req, resp, ttl)
}
- // If the network is not TCP, restrict the number of responses.
- if network != "tcp" {
- wasTrimmed := trimUDPResponse(req, resp, d.config.UDPAnswerLimit)
-
- // Flag that there are more records to return in the UDP response
- if wasTrimmed && d.config.EnableTruncate {
- resp.Truncated = true
- }
- }
+ d.trimDNSResponse(network, req, resp)
// If the answer is empty and the response isn't truncated, return not found
if len(resp.Answer) == 0 && !resp.Truncated {
diff --git a/agent/dns_test.go b/agent/dns_test.go
index d42abbeb7a..18da89439f 100644
--- a/agent/dns_test.go
+++ b/agent/dns_test.go
@@ -2740,6 +2740,94 @@ func TestDNS_ServiceLookup_Randomize(t *testing.T) {
}
}
+func TestDNS_TCP_and_UDP_Truncate(t *testing.T) {
+ t.Parallel()
+ a := NewTestAgent(t.Name(), `
+ dns_config {
+ enable_truncate = true
+ }
+ `)
+ defer a.Shutdown()
+
+ services := []string{"normal", "truncated"}
+ for index, service := range services {
+ numServices := (index * 5000) + 2 // registers 1 node for "normal" and 5001 for "truncated"
+ for i := 1; i < numServices; i++ {
+ args := &structs.RegisterRequest{
+ Datacenter: "dc1",
+ Node: fmt.Sprintf("%s-%d.acme.com", service, i),
+ Address: fmt.Sprintf("127.%d.%d.%d", index, (i / 255), i%255),
+ Service: &structs.NodeService{
+ Service: service,
+ Port: 8000,
+ },
+ }
+
+ var out struct{}
+ if err := a.RPC("Catalog.Register", args, &out); err != nil {
+ t.Fatalf("err: %v", err)
+ }
+ }
+
+ // Register an equivalent prepared query.
+ var id string
+ {
+ args := &structs.PreparedQueryRequest{
+ Datacenter: "dc1",
+ Op: structs.PreparedQueryCreate,
+ Query: &structs.PreparedQuery{
+ Name: service,
+ Service: structs.ServiceQuery{
+ Service: service,
+ },
+ },
+ }
+ if err := a.RPC("PreparedQuery.Apply", args, &id); err != nil {
+ t.Fatalf("err: %v", err)
+ }
+ }
+
+ // Look up the service directly and via prepared query, over TCP and UDP,
+ // for each query type. Only the large service must come back truncated.
+ questions := []string{
+ fmt.Sprintf("%s.service.consul.", service),
+ id + ".query.consul.",
+ }
+ protocols := []string{
+ "tcp",
+ "udp",
+ }
+ for _, qType := range []uint16{dns.TypeANY, dns.TypeA, dns.TypeSRV} {
+ for _, question := range questions {
+ for _, protocol := range protocols {
+ t.Run(fmt.Sprintf("lookup %s %s (qType:=%d)", question, protocol, qType), func(t *testing.T) {
+ m := new(dns.Msg)
+ m.SetQuestion(question, qType) // use the type under test, not always dns.TypeANY
+ if protocol == "udp" {
+ m.SetEdns0(8192, true)
+ }
+ c := new(dns.Client)
+ c.Net = protocol
+ in, rtt, err := c.Exchange(m, a.DNSAddr())
+ if err != nil && err != dns.ErrTruncated {
+ t.Fatalf("err: %v", err)
+ }
+
+ // Check for the truncate bit
+ shouldBeTruncated := numServices > 4095 // true only for the 5001-node service
+
+ if shouldBeTruncated != in.Truncated {
+ info := fmt.Sprintf("service %s question:=%s (%s) (%d total records) in %v",
+ service, question, protocol, numServices, rtt)
+ t.Fatalf("Should have truncate:=%v for %s", shouldBeTruncated, info)
+ }
+ })
+ }
+ }
+ }
+ }
+}
+
func TestDNS_ServiceLookup_Truncate(t *testing.T) {
t.Parallel()
a := NewTestAgent(t.Name(), `