From 222e689ac6900f21ebf2093fa99e2ac78eae3d6c Mon Sep 17 00:00:00 2001 From: Armon Dadgar Date: Tue, 25 Feb 2014 12:07:20 -0800 Subject: [PATCH] agent: DNS layer properly handles AAAA and CNAME records --- command/agent/dns.go | 113 ++++++++++++++++------------- command/agent/dns_test.go | 149 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 211 insertions(+), 51 deletions(-) diff --git a/command/agent/dns.go b/command/agent/dns.go index 70e76648cf..78b3bf26a3 100644 --- a/command/agent/dns.go +++ b/command/agent/dns.go @@ -306,27 +306,57 @@ func (d *DNSServer) nodeLookup(network, datacenter, node string, req, resp *dns. return } + // Add the node record + record := formatNodeRecord(&out.NodeServices.Node, req.Question[0].Name, qType) + if record != nil { + resp.Answer = append(resp.Answer, record) + } +} + +// formatNodeRecord takes a Node and returns an A, AAAA, or CNAME record +func formatNodeRecord(node *structs.Node, qName string, qType uint16) dns.RR { // Parse the IP - ip := net.ParseIP(out.NodeServices.Node.Address) - if ip == nil { - d.logger.Printf("[ERR] dns: failed to parse IP %v", out.NodeServices.Node) - resp.SetRcode(req, dns.RcodeServerFailure) - return + ip := net.ParseIP(node.Address) + var ipv4 net.IP + if ip != nil { + ipv4 = ip.To4() } + switch { + case ipv4 != nil && (qType == dns.TypeANY || qType == dns.TypeA): + return &dns.A{ + Hdr: dns.RR_Header{ + Name: qName, + Rrtype: dns.TypeA, + Class: dns.ClassINET, + Ttl: 0, + }, + A: ip, + } - // Format A record - aRec := &dns.A{ - Hdr: dns.RR_Header{ - Name: req.Question[0].Name, - Rrtype: dns.TypeA, - Class: dns.ClassINET, - Ttl: 0, - }, - A: ip, + case ip != nil && ipv4 == nil && (qType == dns.TypeANY || qType == dns.TypeAAAA): + return &dns.AAAA{ + Hdr: dns.RR_Header{ + Name: qName, + Rrtype: dns.TypeAAAA, + Class: dns.ClassINET, + Ttl: 0, + }, + AAAA: ip, + } + + case ip == nil && (qType == dns.TypeANY || qType == dns.TypeCNAME): + return &dns.CNAME{ + Hdr: dns.RR_Header{ + Name: qName, + Rrtype: dns.TypeCNAME, + Class: dns.ClassINET, + Ttl: 0, + }, + Target: dns.Fqdn(node.Address), + } + default: + return nil } - - // Add the response - resp.Answer = append(resp.Answer, aRec) } // serviceLookup is used to handle a service query @@ -364,9 +394,8 @@ func (d *DNSServer) serviceLookup(network, datacenter, service, tag string, req, // Add various responses depending on the request qType := req.Question[0].Qtype - if qType == dns.TypeANY || qType == dns.TypeA { - d.serviceARecords(out.Nodes, req, resp) - } + d.serviceNodeRecords(out.Nodes, req, resp) + if qType == dns.TypeANY || qType == dns.TypeSRV { d.serviceSRVRecords(datacenter, out.Nodes, req, resp) } @@ -399,8 +428,10 @@ func shuffleServiceNodes(nodes structs.CheckServiceNodes) { } } -// serviceARecords is used to add the A records for a service lookup -func (d *DNSServer) serviceARecords(nodes structs.CheckServiceNodes, req, resp *dns.Msg) { +// serviceNodeRecords is used to add the node records for a service lookup +func (d *DNSServer) serviceNodeRecords(nodes structs.CheckServiceNodes, req, resp *dns.Msg) { + qName := req.Question[0].Name + qType := req.Question[0].Qtype handled := make(map[string]struct{}) for _, node := range nodes { // Avoid duplicate entries, possible if a node has @@ -411,21 +442,11 @@ func (d *DNSServer) serviceARecords(nodes structs.CheckServiceNodes, req, resp * } handled[addr] = struct{}{} - ip := net.ParseIP(addr) - if ip == nil { - d.logger.Printf("[ERR] dns: failed to parse IP %v for %v", addr, node.Node) - continue + // Add the node record + record := formatNodeRecord(&node.Node, qName, qType) + if record != nil { + resp.Answer = append(resp.Answer, record) } - aRec := &dns.A{ - Hdr: dns.RR_Header{ - Name: req.Question[0].Name, - Rrtype: dns.TypeA, - Class: dns.ClassINET, - Ttl: 0, - }, - A: ip, - } - resp.Answer = append(resp.Answer, aRec) } } @@ -456,7 +477,7 @@ func (d *DNSServer) serviceSRVRecords(dc string, nodes structs.CheckServiceNodes } resp.Answer = append(resp.Answer, srvRec) - // Avoid duplicate A records, possible if a node has + // Avoid duplicate extra records, possible if a node has // the same service on multiple ports, etc. addr := node.Node.Address if _, ok := handled[addr]; ok { @@ -464,21 +485,11 @@ func (d *DNSServer) serviceSRVRecords(dc string, nodes structs.CheckServiceNodes } handled[addr] = struct{}{} - ip := net.ParseIP(addr) - if ip == nil { - d.logger.Printf("[ERR] dns: failed to parse IP %v for %v", addr, node.Node) - continue + // Add the extra record + record := formatNodeRecord(&node.Node, srvRec.Target, dns.TypeANY) + if record != nil { + resp.Extra = append(resp.Extra, record) } - aRec := &dns.A{ - Hdr: dns.RR_Header{ - Name: srvRec.Target, - Rrtype: dns.TypeA, - Class: dns.ClassINET, - Ttl: 0, - }, - A: ip, - } - resp.Extra = append(resp.Extra, aRec) } } diff --git a/command/agent/dns_test.go b/command/agent/dns_test.go index c2aa000ea1..f1b3a4a9fb 100644 --- a/command/agent/dns_test.go +++ b/command/agent/dns_test.go @@ -121,6 +121,88 @@ func TestDNS_NodeLookup(t *testing.T) { } } +func TestDNS_NodeLookup_AAAA(t *testing.T) { + dir, srv := makeDNSServer(t) + defer os.RemoveAll(dir) + defer srv.agent.Shutdown() + + // Wait for leader + time.Sleep(100 * time.Millisecond) + + // Register node + args := &structs.RegisterRequest{ + Datacenter: "dc1", + Node: "bar", + Address: "::4242:4242", + } + var out struct{} + if err := srv.agent.RPC("Catalog.Register", args, &out); err != nil { + t.Fatalf("err: %v", err) + } + + m := new(dns.Msg) + m.SetQuestion("bar.node.consul.", dns.TypeANY) + + c := new(dns.Client) + in, _, err := c.Exchange(m, srv.agent.config.DNSAddr) + if err != nil { + t.Fatalf("err: %v", err) + } + + if len(in.Answer) != 1 { + t.Fatalf("Bad: %#v", in) + } + + aRec, ok := in.Answer[0].(*dns.AAAA) + if !ok { + t.Fatalf("Bad: %#v", in.Answer[0]) + } + if aRec.AAAA.String() != "::4242:4242" { + t.Fatalf("Bad: %#v", in.Answer[0]) + } +} + +func TestDNS_NodeLookup_CNAME(t *testing.T) { + dir, srv := makeDNSServer(t) + defer os.RemoveAll(dir) + defer srv.agent.Shutdown() + + // Wait for leader + time.Sleep(100 * time.Millisecond) + + // Register node + args := &structs.RegisterRequest{ + Datacenter: "dc1", + Node: "google", + Address: "www.google.com", + } + var out struct{} + if err := srv.agent.RPC("Catalog.Register", args, &out); err != nil { + t.Fatalf("err: %v", err) + } + + m := new(dns.Msg) + m.SetQuestion("google.node.consul.", dns.TypeANY) + + c := new(dns.Client) + in, _, err := c.Exchange(m, srv.agent.config.DNSAddr) + if err != nil { + t.Fatalf("err: %v", err) + } + + if len(in.Answer) != 1 { + t.Fatalf("Bad: %#v", in) + } + + cnRec, ok := in.Answer[0].(*dns.CNAME) + if !ok { + t.Fatalf("Bad: %#v", in.Answer[0]) + } + if cnRec.Target != "www.google.com." { + t.Fatalf("Bad: %#v", in.Answer[0]) + } +} + func TestDNS_ServiceLookup(t *testing.T) { dir, srv := makeDNSServer(t) defer os.RemoveAll(dir) @@ -449,3 +531,70 @@ func TestDNS_ServiceLookup_Randomize(t *testing.T) { uniques[nameS] = struct{}{} } } + +func TestDNS_ServiceLookup_CNAME(t *testing.T) { + dir, srv := makeDNSServer(t) + defer os.RemoveAll(dir) + defer srv.agent.Shutdown() + + // Wait for leader + time.Sleep(100 * time.Millisecond) + + // Register node + args := &structs.RegisterRequest{ + Datacenter: "dc1", + Node: "google", + Address: "www.google.com", + Service: &structs.NodeService{ + Service: "search", + Port: 80, + }, + } + var out struct{} + if err := srv.agent.RPC("Catalog.Register", args, &out); err != nil { + t.Fatalf("err: %v", err) + } + + m := new(dns.Msg) + m.SetQuestion("search.service.consul.", dns.TypeANY) + + c := new(dns.Client) + in, _, err := c.Exchange(m, srv.agent.config.DNSAddr) + if err != nil { + t.Fatalf("err: %v", err) + } + + if len(in.Answer) != 2 { + t.Fatalf("Bad: %#v", in) + } + + cnRec, ok := in.Answer[0].(*dns.CNAME) + if !ok { + t.Fatalf("Bad: %#v", in.Answer[0]) + } + if cnRec.Target != "www.google.com." { + t.Fatalf("Bad: %#v", in.Answer[0]) + } + + srvRec, ok := in.Answer[1].(*dns.SRV) + if !ok { + t.Fatalf("Bad: %#v", in.Answer[1]) + } + if srvRec.Port != 80 { + t.Fatalf("Bad: %#v", srvRec) + } + if srvRec.Target != "google.node.dc1.consul." { + t.Fatalf("Bad: %#v", srvRec) + } + + cnRec, ok = in.Extra[0].(*dns.CNAME) + if !ok { + t.Fatalf("Bad: %#v", in.Extra[0]) + } + if cnRec.Hdr.Name != "google.node.dc1.consul." { + t.Fatalf("Bad: %#v", in.Extra[0]) + } + if cnRec.Target != "www.google.com." { + t.Fatalf("Bad: %#v", in.Extra[0]) + } +}