From 7d59249d96e5e2d24c423024174d27618f7ad478 Mon Sep 17 00:00:00 2001
From: Pierre Souchay
Date: Wed, 7 Mar 2018 23:33:41 +0100
Subject: [PATCH] Avoid issue with compression of DNS messages causing overflow
---
agent/dns.go | 7 ++++++-
agent/dns_test.go | 47 +++++++++++++++++++++++++----------------------
2 files changed, 31 insertions(+), 23 deletions(-)
diff --git a/agent/dns.go b/agent/dns.go
index 21e9e67140..2750ed6b02 100644
--- a/agent/dns.go
+++ b/agent/dns.go
@@ -718,7 +718,10 @@ func syncExtra(index map[string]dns.RR, resp *dns.Msg) {
func (d *DNSServer) trimTCPResponse(req, resp *dns.Msg) (trimmed bool) {
hasExtra := len(resp.Extra) > 0
// There is some overhead, 65535 does not work
- maxSize := 64000
+ maxSize := 65533 // 64k - 2 bytes
+ // In order to compute properly, we have to avoid compress first
+ compressed := resp.Compress
+ resp.Compress = false
// We avoid some function calls and allocations by only handling the
// extra data when necessary.
@@ -745,6 +748,8 @@ func (d *DNSServer) trimTCPResponse(req, resp *dns.Msg) (trimmed bool) {
len(resp.Answer), originalNumRecords, resp.Len(), originalSize)
}
+ // Restore compression if any
+ resp.Compress = compressed
return truncated
}
diff --git a/agent/dns_test.go b/agent/dns_test.go
index 18da89439f..cf9571de05 100644
--- a/agent/dns_test.go
+++ b/agent/dns_test.go
@@ -2800,28 +2800,31 @@ func TestDNS_TCP_and_UDP_Truncate(t *testing.T) {
for _, qType := range []uint16{dns.TypeANY, dns.TypeA, dns.TypeSRV} {
for _, question := range questions {
for _, protocol := range protocols {
- t.Run(fmt.Sprintf("lookup %s %s (qType:=%d)", question, protocol, qType), func(t *testing.T) {
- m := new(dns.Msg)
- m.SetQuestion(question, dns.TypeANY)
- if protocol == "udp" {
- m.SetEdns0(8192, true)
- }
- c := new(dns.Client)
- c.Net = protocol
- in, out, err := c.Exchange(m, a.DNSAddr())
- if err != nil && err != dns.ErrTruncated {
- t.Fatalf("err: %v", err)
- }
-
- // Check for the truncate bit
- shouldBeTruncated := numServices > 4095
-
- if shouldBeTruncated != in.Truncated {
- info := fmt.Sprintf("service %s question:=%s (%s) (%d total records) in %v",
- service, question, protocol, numServices, out)
- t.Fatalf("Should have truncate:=%v for %s", shouldBeTruncated, info)
- }
- })
+ for _, compress := range []bool{true, false} {
+ t.Run(fmt.Sprintf("lookup %s %s (qType:=%d) compressed=%b", question, protocol, qType, compress), func(t *testing.T) {
+ m := new(dns.Msg)
+ m.SetQuestion(question, dns.TypeANY)
+ if protocol == "udp" {
+ m.SetEdns0(8192, true)
+ }
+ c := new(dns.Client)
+ c.Net = protocol
+ m.Compress = compress
+ in, out, err := c.Exchange(m, a.DNSAddr())
+ if err != nil && err != dns.ErrTruncated {
+ t.Fatalf("err: %v", err)
+ }
+
+ // Check for the truncate bit
+ shouldBeTruncated := numServices > 4095
+
+ if shouldBeTruncated != in.Truncated {
+ info := fmt.Sprintf("service %s question:=%s (%s) (%d total records) in %v",
+ service, question, protocol, numServices, out)
+ t.Fatalf("Should have truncate:=%v for %s", shouldBeTruncated, info)
+ }
+ })
+ }
}
}
}