agent: Fixing DNS CNAME recursion

pull/19/head
Armon Dadgar 11 years ago
parent 60b7fccf02
commit 78e28a84a1

@ -16,6 +16,7 @@ const (
testQuery = "_test.consul." testQuery = "_test.consul."
consulDomain = "consul." consulDomain = "consul."
maxServiceResponses = 3 // For UDP only maxServiceResponses = 3 // For UDP only
maxRecurseRecords = 3
) )
// DNSServer is used to wrap an Agent and expose various // DNSServer is used to wrap an Agent and expose various
@ -175,7 +176,7 @@ func (d *DNSServer) handleQuery(resp dns.ResponseWriter, req *dns.Msg) {
m := new(dns.Msg) m := new(dns.Msg)
m.SetReply(req) m.SetReply(req)
m.Authoritative = true m.Authoritative = true
m.RecursionAvailable = true m.RecursionAvailable = (d.recursor != "")
// Only add the SOA if requested // Only add the SOA if requested
if req.Question[0].Qtype == dns.TypeSOA { if req.Question[0].Qtype == dns.TypeSOA {
@ -313,22 +314,14 @@ func (d *DNSServer) nodeLookup(network, datacenter, node string, req, resp *dns.
} }
// Add the node record // Add the node record
record := formatNodeRecord(&out.NodeServices.Node, req.Question[0].Name, qType) records := d.formatNodeRecord(&out.NodeServices.Node, req.Question[0].Name, qType)
if record != nil { if records != nil {
resp.Answer = append(resp.Answer, record) resp.Answer = append(resp.Answer, records...)
// Try to recursively resolve the CNAME
if cnRec, ok := record.(*dns.CNAME); ok {
aRecs := d.resolveCNAME(cnRec.Target)
if len(aRecs) > 0 {
resp.Extra = append(resp.Extra, aRecs[0])
}
}
} }
} }
// formatNodeRecord takes a Node and returns an A, AAAA, or CNAME record // formatNodeRecord takes a Node and returns an A, AAAA, or CNAME record
func formatNodeRecord(node *structs.Node, qName string, qType uint16) dns.RR { func (d *DNSServer) formatNodeRecord(node *structs.Node, qName string, qType uint16) (records []dns.RR) {
// Parse the IP // Parse the IP
ip := net.ParseIP(node.Address) ip := net.ParseIP(node.Address)
var ipv4 net.IP var ipv4 net.IP
@ -337,7 +330,7 @@ func formatNodeRecord(node *structs.Node, qName string, qType uint16) dns.RR {
} }
switch { switch {
case ipv4 != nil && (qType == dns.TypeANY || qType == dns.TypeA): case ipv4 != nil && (qType == dns.TypeANY || qType == dns.TypeA):
return &dns.A{ return []dns.RR{&dns.A{
Hdr: dns.RR_Header{ Hdr: dns.RR_Header{
Name: qName, Name: qName,
Rrtype: dns.TypeA, Rrtype: dns.TypeA,
@ -345,10 +338,10 @@ func formatNodeRecord(node *structs.Node, qName string, qType uint16) dns.RR {
Ttl: 0, Ttl: 0,
}, },
A: ip, A: ip,
} }}
case ip != nil && ipv4 == nil && (qType == dns.TypeANY || qType == dns.TypeAAAA): case ip != nil && ipv4 == nil && (qType == dns.TypeANY || qType == dns.TypeAAAA):
return &dns.AAAA{ return []dns.RR{&dns.AAAA{
Hdr: dns.RR_Header{ Hdr: dns.RR_Header{
Name: qName, Name: qName,
Rrtype: dns.TypeAAAA, Rrtype: dns.TypeAAAA,
@ -356,10 +349,12 @@ func formatNodeRecord(node *structs.Node, qName string, qType uint16) dns.RR {
Ttl: 0, Ttl: 0,
}, },
AAAA: ip, AAAA: ip,
} }}
case ip == nil && (qType == dns.TypeANY || qType == dns.TypeCNAME): case ip == nil && (qType == dns.TypeANY || qType == dns.TypeCNAME ||
return &dns.CNAME{ qType == dns.TypeA || qType == dns.TypeAAAA):
// Get the CNAME
cnRec := &dns.CNAME{
Hdr: dns.RR_Header{ Hdr: dns.RR_Header{
Name: qName, Name: qName,
Rrtype: dns.TypeCNAME, Rrtype: dns.TypeCNAME,
@ -368,9 +363,26 @@ func formatNodeRecord(node *structs.Node, qName string, qType uint16) dns.RR {
}, },
Target: dns.Fqdn(node.Address), Target: dns.Fqdn(node.Address),
} }
default: records = append(records, cnRec)
return nil
// Recurse
more := d.resolveCNAME(cnRec.Target)
extra := 0
MORE_REC:
for _, rr := range more {
switch rr.Header().Rrtype {
case dns.TypeA:
fallthrough
case dns.TypeAAAA:
records = append(records, rr)
extra++
if extra == maxRecurseRecords {
break MORE_REC
}
}
}
} }
return records
} }
// serviceLookup is used to handle a service query // serviceLookup is used to handle a service query
@ -410,12 +422,9 @@ func (d *DNSServer) serviceLookup(network, datacenter, service, tag string, req,
qType := req.Question[0].Qtype qType := req.Question[0].Qtype
d.serviceNodeRecords(out.Nodes, req, resp) d.serviceNodeRecords(out.Nodes, req, resp)
if qType == dns.TypeANY || qType == dns.TypeSRV { if qType == dns.TypeSRV {
d.serviceSRVRecords(datacenter, out.Nodes, req, resp) d.serviceSRVRecords(datacenter, out.Nodes, req, resp)
} }
// Cleanup duplicate extra entries
resp.Extra = removeDuplicates(resp.Extra)
} }
// filterServiceNodes is used to filter out nodes that are failing // filterServiceNodes is used to filter out nodes that are failing
@ -460,17 +469,9 @@ func (d *DNSServer) serviceNodeRecords(nodes structs.CheckServiceNodes, req, res
handled[addr] = struct{}{} handled[addr] = struct{}{}
// Add the node record // Add the node record
record := formatNodeRecord(&node.Node, qName, qType) records := d.formatNodeRecord(&node.Node, qName, qType)
if record != nil { if records != nil {
resp.Answer = append(resp.Answer, record) resp.Answer = append(resp.Answer, records...)
// Try to recursively resolve the CNAME
if cnRec, ok := record.(*dns.CNAME); ok {
aRecs := d.resolveCNAME(cnRec.Target)
if len(aRecs) > 0 {
resp.Extra = append(resp.Extra, aRecs[0])
}
}
} }
} }
} }
@ -502,26 +503,10 @@ func (d *DNSServer) serviceSRVRecords(dc string, nodes structs.CheckServiceNodes
} }
resp.Answer = append(resp.Answer, srvRec) resp.Answer = append(resp.Answer, srvRec)
// Avoid duplicate extra records, possible if a node has
// the same service on multiple ports, etc.
addr := node.Node.Address
if _, ok := handled[addr]; ok {
continue
}
handled[addr] = struct{}{}
// Add the extra record // Add the extra record
record := formatNodeRecord(&node.Node, srvRec.Target, dns.TypeANY) records := d.formatNodeRecord(&node.Node, srvRec.Target, dns.TypeANY)
if record != nil { if records != nil {
resp.Extra = append(resp.Extra, record) resp.Extra = append(resp.Extra, records...)
// Try to recursively resolve the CNAME
if cnRec, ok := record.(*dns.CNAME); ok {
aRecs := d.resolveCNAME(cnRec.Target)
if len(aRecs) > 0 {
resp.Extra = append(resp.Extra, aRecs[0])
}
}
} }
} }
} }
@ -584,23 +569,3 @@ func (d *DNSServer) resolveCNAME(name string) []dns.RR {
// Return all the answers // Return all the answers
return r.Answer return r.Answer
} }
// removeDuplicates is used to remove the duplicate entries.
// This only deduplicates on the QName and QType
func removeDuplicates(rr []dns.RR) []dns.RR {
handled := make(map[string]struct{})
n := len(rr)
for i := 0; i < n; i++ {
rec := rr[i]
hdr := rec.Header()
key := fmt.Sprintf("%s:%d", hdr.Name, hdr.Rrtype)
if _, ok := handled[key]; ok {
// Remove duplicate
rr[i], rr[n-1] = rr[n-1], nil
n--
i--
}
handled[key] = struct{}{}
}
return rr[:n]
}

@ -190,7 +190,8 @@ func TestDNS_NodeLookup_CNAME(t *testing.T) {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
if len(in.Answer) != 1 { // Should have the CNAME record + a few A records
if len(in.Answer) < 2 {
t.Fatalf("Bad: %#v", in) t.Fatalf("Bad: %#v", in)
} }
@ -228,7 +229,7 @@ func TestDNS_ServiceLookup(t *testing.T) {
} }
m := new(dns.Msg) m := new(dns.Msg)
m.SetQuestion("db.service.consul.", dns.TypeANY) m.SetQuestion("db.service.consul.", dns.TypeSRV)
c := new(dns.Client) c := new(dns.Client)
in, _, err := c.Exchange(m, srv.agent.config.DNSAddr) in, _, err := c.Exchange(m, srv.agent.config.DNSAddr)
@ -236,22 +237,14 @@ func TestDNS_ServiceLookup(t *testing.T) {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
if len(in.Answer) != 2 { if len(in.Answer) != 1 {
t.Fatalf("Bad: %#v", in) t.Fatalf("Bad: %#v", in)
} }
aRec, ok := in.Answer[0].(*dns.A) srvRec, ok := in.Answer[0].(*dns.SRV)
if !ok { if !ok {
t.Fatalf("Bad: %#v", in.Answer[0]) t.Fatalf("Bad: %#v", in.Answer[0])
} }
if aRec.A.String() != "127.0.0.1" {
t.Fatalf("Bad: %#v", in.Answer[0])
}
srvRec, ok := in.Answer[1].(*dns.SRV)
if !ok {
t.Fatalf("Bad: %#v", in.Answer[1])
}
if srvRec.Port != 12345 { if srvRec.Port != 12345 {
t.Fatalf("Bad: %#v", srvRec) t.Fatalf("Bad: %#v", srvRec)
} }
@ -259,7 +252,7 @@ func TestDNS_ServiceLookup(t *testing.T) {
t.Fatalf("Bad: %#v", srvRec) t.Fatalf("Bad: %#v", srvRec)
} }
aRec, ok = in.Extra[0].(*dns.A) aRec, ok := in.Extra[0].(*dns.A)
if !ok { if !ok {
t.Fatalf("Bad: %#v", in.Extra[0]) t.Fatalf("Bad: %#v", in.Extra[0])
} }
@ -334,7 +327,7 @@ func TestDNS_ServiceLookup_Dedup(t *testing.T) {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
if len(in.Answer) != 3 { if len(in.Answer) != 1 {
t.Fatalf("Bad: %#v", in) t.Fatalf("Bad: %#v", in)
} }
@ -345,10 +338,78 @@ func TestDNS_ServiceLookup_Dedup(t *testing.T) {
if aRec.A.String() != "127.0.0.1" { if aRec.A.String() != "127.0.0.1" {
t.Fatalf("Bad: %#v", in.Answer[0]) t.Fatalf("Bad: %#v", in.Answer[0])
} }
}
func TestDNS_ServiceLookup_Dedup_SRV(t *testing.T) {
dir, srv := makeDNSServer(t)
defer os.RemoveAll(dir)
defer srv.agent.Shutdown()
// Wait for leader
time.Sleep(100 * time.Millisecond)
// Register node
args := &structs.RegisterRequest{
Datacenter: "dc1",
Node: "foo",
Address: "127.0.0.1",
Service: &structs.NodeService{
Service: "db",
Tag: "master",
Port: 12345,
},
}
var out struct{}
if err := srv.agent.RPC("Catalog.Register", args, &out); err != nil {
t.Fatalf("err: %v", err)
}
args = &structs.RegisterRequest{
Datacenter: "dc1",
Node: "foo",
Address: "127.0.0.1",
Service: &structs.NodeService{
ID: "db2",
Service: "db",
Tag: "slave",
Port: 12345,
},
}
if err := srv.agent.RPC("Catalog.Register", args, &out); err != nil {
t.Fatalf("err: %v", err)
}
args = &structs.RegisterRequest{
Datacenter: "dc1",
Node: "foo",
Address: "127.0.0.1",
Service: &structs.NodeService{
ID: "db3",
Service: "db",
Tag: "slave",
Port: 12346,
},
}
if err := srv.agent.RPC("Catalog.Register", args, &out); err != nil {
t.Fatalf("err: %v", err)
}
m := new(dns.Msg)
m.SetQuestion("db.service.consul.", dns.TypeSRV)
c := new(dns.Client)
in, _, err := c.Exchange(m, srv.agent.config.DNSAddr)
if err != nil {
t.Fatalf("err: %v", err)
}
if len(in.Answer) != 2 {
t.Fatalf("Bad: %#v", in)
}
srvRec, ok := in.Answer[1].(*dns.SRV) srvRec, ok := in.Answer[0].(*dns.SRV)
if !ok { if !ok {
t.Fatalf("Bad: %#v", in.Answer[1]) t.Fatalf("Bad: %#v", in.Answer[0])
} }
if srvRec.Port != 12345 && srvRec.Port != 12346 { if srvRec.Port != 12345 && srvRec.Port != 12346 {
t.Fatalf("Bad: %#v", srvRec) t.Fatalf("Bad: %#v", srvRec)
@ -357,21 +418,21 @@ func TestDNS_ServiceLookup_Dedup(t *testing.T) {
t.Fatalf("Bad: %#v", srvRec) t.Fatalf("Bad: %#v", srvRec)
} }
srvRec, ok = in.Answer[2].(*dns.SRV) srvRec, ok = in.Answer[1].(*dns.SRV)
if !ok { if !ok {
t.Fatalf("Bad: %#v", in.Answer[1]) t.Fatalf("Bad: %#v", in.Answer[1])
} }
if srvRec.Port != 12346 && srvRec.Port != 12345 { if srvRec.Port != 12346 && srvRec.Port != 12345 {
t.Fatalf("Bad: %#v", srvRec) t.Fatalf("Bad: %#v", srvRec)
} }
if srvRec.Port == in.Answer[1].(*dns.SRV).Port { if srvRec.Port == in.Answer[0].(*dns.SRV).Port {
t.Fatalf("should be a different port") t.Fatalf("should be a different port")
} }
if srvRec.Target != "foo.node.dc1.consul." { if srvRec.Target != "foo.node.dc1.consul." {
t.Fatalf("Bad: %#v", srvRec) t.Fatalf("Bad: %#v", srvRec)
} }
aRec, ok = in.Extra[0].(*dns.A) aRec, ok := in.Extra[0].(*dns.A)
if !ok { if !ok {
t.Fatalf("Bad: %#v", in.Extra[0]) t.Fatalf("Bad: %#v", in.Extra[0])
} }
@ -507,8 +568,8 @@ func TestDNS_ServiceLookup_Randomize(t *testing.T) {
} }
// Response length should be truncated // Response length should be truncated
// We should get an SRV + A record for each response (hence 2x) // We should get an A record for each response
if len(in.Answer) != 2*maxServiceResponses { if len(in.Answer) != maxServiceResponses {
t.Fatalf("Bad: %#v", len(in.Answer)) t.Fatalf("Bad: %#v", len(in.Answer))
} }
@ -564,10 +625,11 @@ func TestDNS_ServiceLookup_CNAME(t *testing.T) {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
if len(in.Answer) != 2 { if len(in.Answer) != 4 {
t.Fatalf("Bad: %#v", in) t.Fatalf("Bad: %#v", in)
} }
// Should have google CNAME
cnRec, ok := in.Answer[0].(*dns.CNAME) cnRec, ok := in.Answer[0].(*dns.CNAME)
if !ok { if !ok {
t.Fatalf("Bad: %#v", in.Answer[0]) t.Fatalf("Bad: %#v", in.Answer[0])
@ -576,33 +638,10 @@ func TestDNS_ServiceLookup_CNAME(t *testing.T) {
t.Fatalf("Bad: %#v", in.Answer[0]) t.Fatalf("Bad: %#v", in.Answer[0])
} }
srvRec, ok := in.Answer[1].(*dns.SRV) // Check we recursively resolve
if !ok { for i := 1; i < 4; i++ {
t.Fatalf("Bad: %#v", in.Answer[1]) if _, ok := in.Answer[i].(*dns.A); !ok {
} t.Fatalf("Bad: %#v", in.Answer[i])
if srvRec.Port != 80 { }
t.Fatalf("Bad: %#v", srvRec)
}
if srvRec.Target != "google.node.dc1.consul." {
t.Fatalf("Bad: %#v", srvRec)
}
aRec, ok := in.Extra[0].(*dns.A)
if !ok {
t.Fatalf("Bad: %#v", in.Extra[0])
}
if aRec.Hdr.Name != "www.google.com." {
t.Fatalf("Bad: %#v", in.Extra[0])
}
cnRec, ok = in.Extra[1].(*dns.CNAME)
if !ok {
t.Fatalf("Bad: %#v", in.Extra[1])
}
if cnRec.Hdr.Name != "google.node.dc1.consul." {
t.Fatalf("Bad: %#v", in.Extra[1])
}
if cnRec.Target != "www.google.com." {
t.Fatalf("Bad: %#v", in.Extra[1])
} }
} }

Loading…
Cancel
Save