diff --git a/command/agent/dns.go b/command/agent/dns.go index d252774e53..6ef5b6cc63 100644 --- a/command/agent/dns.go +++ b/command/agent/dns.go @@ -16,6 +16,7 @@ const ( testQuery = "_test.consul." consulDomain = "consul." maxServiceResponses = 3 // For UDP only + maxRecurseRecords = 3 ) // DNSServer is used to wrap an Agent and expose various @@ -175,7 +176,7 @@ func (d *DNSServer) handleQuery(resp dns.ResponseWriter, req *dns.Msg) { m := new(dns.Msg) m.SetReply(req) m.Authoritative = true - m.RecursionAvailable = true + m.RecursionAvailable = (d.recursor != "") // Only add the SOA if requested if req.Question[0].Qtype == dns.TypeSOA { @@ -313,22 +314,14 @@ func (d *DNSServer) nodeLookup(network, datacenter, node string, req, resp *dns. } // Add the node record - record := formatNodeRecord(&out.NodeServices.Node, req.Question[0].Name, qType) - if record != nil { - resp.Answer = append(resp.Answer, record) - - // Try to recursively resolve the CNAME - if cnRec, ok := record.(*dns.CNAME); ok { - aRecs := d.resolveCNAME(cnRec.Target) - if len(aRecs) > 0 { - resp.Extra = append(resp.Extra, aRecs[0]) - } - } + records := d.formatNodeRecord(&out.NodeServices.Node, req.Question[0].Name, qType) + if records != nil { + resp.Answer = append(resp.Answer, records...) } } // formatNodeRecord takes a Node and returns an A, AAAA, or CNAME record -func formatNodeRecord(node *structs.Node, qName string, qType uint16) dns.RR { +func (d *DNSServer) formatNodeRecord(node *structs.Node, qName string, qType uint16) (records []dns.RR) { // Parse the IP ip := net.ParseIP(node.Address) var ipv4 net.IP @@ -337,7 +330,7 @@ func formatNodeRecord(node *structs.Node, qName string, qType uint16) dns.RR { } switch { case ipv4 != nil && (qType == dns.TypeANY || qType == dns.TypeA): - return &dns.A{ + return []dns.RR{&dns.A{ Hdr: dns.RR_Header{ Name: qName, Rrtype: dns.TypeA, @@ -345,10 +338,10 @@ func formatNodeRecord(node *structs.Node, qName string, qType uint16) dns.RR { Ttl: 0, }, A: ip, - } + }} case ip != nil && ipv4 == nil && (qType == dns.TypeANY || qType == dns.TypeAAAA): - return &dns.AAAA{ + return []dns.RR{&dns.AAAA{ Hdr: dns.RR_Header{ Name: qName, Rrtype: dns.TypeAAAA, @@ -356,10 +349,12 @@ func formatNodeRecord(node *structs.Node, qName string, qType uint16) dns.RR { Ttl: 0, }, AAAA: ip, - } + }} - case ip == nil && (qType == dns.TypeANY || qType == dns.TypeCNAME): - return &dns.CNAME{ + case ip == nil && (qType == dns.TypeANY || qType == dns.TypeCNAME || + qType == dns.TypeA || qType == dns.TypeAAAA): + // Get the CNAME + cnRec := &dns.CNAME{ Hdr: dns.RR_Header{ Name: qName, Rrtype: dns.TypeCNAME, @@ -368,9 +363,26 @@ func formatNodeRecord(node *structs.Node, qName string, qType uint16) dns.RR { }, Target: dns.Fqdn(node.Address), } - default: - return nil + records = append(records, cnRec) + + // Recurse + more := d.resolveCNAME(cnRec.Target) + extra := 0 + MORE_REC: + for _, rr := range more { + switch rr.Header().Rrtype { + case dns.TypeA: + fallthrough + case dns.TypeAAAA: + records = append(records, rr) + extra++ + if extra == maxRecurseRecords { + break MORE_REC + } + } + } } + return records } // serviceLookup is used to handle a service query @@ -410,12 +422,9 @@ func (d *DNSServer) serviceLookup(network, datacenter, service, tag string, req, qType := req.Question[0].Qtype d.serviceNodeRecords(out.Nodes, req, resp) - if qType == dns.TypeANY || qType == dns.TypeSRV { + if qType == dns.TypeSRV { d.serviceSRVRecords(datacenter, out.Nodes, req, resp) } - - // Cleanup duplicate extra entries - resp.Extra = removeDuplicates(resp.Extra) } // filterServiceNodes is used to filter out nodes that are failing @@ -460,17 +469,9 @@ func (d *DNSServer) serviceNodeRecords(nodes structs.CheckServiceNodes, req, res handled[addr] = struct{}{} // Add the node record - record := formatNodeRecord(&node.Node, qName, qType) - if record != nil { - resp.Answer = append(resp.Answer, record) - - // Try to recursively resolve the CNAME - if cnRec, ok := record.(*dns.CNAME); ok { - aRecs := d.resolveCNAME(cnRec.Target) - if len(aRecs) > 0 { - resp.Extra = append(resp.Extra, aRecs[0]) - } - } + records := d.formatNodeRecord(&node.Node, qName, qType) + if records != nil { + resp.Answer = append(resp.Answer, records...) } } } @@ -502,26 +503,10 @@ func (d *DNSServer) serviceSRVRecords(dc string, nodes structs.CheckServiceNodes } resp.Answer = append(resp.Answer, srvRec) - // Avoid duplicate extra records, possible if a node has - // the same service on multiple ports, etc. - addr := node.Node.Address - if _, ok := handled[addr]; ok { - continue - } - handled[addr] = struct{}{} - // Add the extra record - record := formatNodeRecord(&node.Node, srvRec.Target, dns.TypeANY) - if record != nil { - resp.Extra = append(resp.Extra, record) - - // Try to recursively resolve the CNAME - if cnRec, ok := record.(*dns.CNAME); ok { - aRecs := d.resolveCNAME(cnRec.Target) - if len(aRecs) > 0 { - resp.Extra = append(resp.Extra, aRecs[0]) - } - } + records := d.formatNodeRecord(&node.Node, srvRec.Target, dns.TypeANY) + if records != nil { + resp.Extra = append(resp.Extra, records...) } } } @@ -584,23 +569,3 @@ func (d *DNSServer) resolveCNAME(name string) []dns.RR { // Return all the answers return r.Answer } - -// removeDuplicates is used to remove the duplicate entries. -// This only deduplicates on the QName and QType -func removeDuplicates(rr []dns.RR) []dns.RR { - handled := make(map[string]struct{}) - n := len(rr) - for i := 0; i < n; i++ { - rec := rr[i] - hdr := rec.Header() - key := fmt.Sprintf("%s:%d", hdr.Name, hdr.Rrtype) - if _, ok := handled[key]; ok { - // Remove duplicate - rr[i], rr[n-1] = rr[n-1], nil - n-- - i-- - } - handled[key] = struct{}{} - } - return rr[:n] -} diff --git a/command/agent/dns_test.go b/command/agent/dns_test.go index bc1b400687..68592923d1 100644 --- a/command/agent/dns_test.go +++ b/command/agent/dns_test.go @@ -190,7 +190,8 @@ func TestDNS_NodeLookup_CNAME(t *testing.T) { t.Fatalf("err: %v", err) } - if len(in.Answer) != 1 { + // Should have the CNAME record + a few A records + if len(in.Answer) < 2 { t.Fatalf("Bad: %#v", in) } @@ -228,7 +229,7 @@ func TestDNS_ServiceLookup(t *testing.T) { } m := new(dns.Msg) - m.SetQuestion("db.service.consul.", dns.TypeANY) + m.SetQuestion("db.service.consul.", dns.TypeSRV) c := new(dns.Client) in, _, err := c.Exchange(m, srv.agent.config.DNSAddr) @@ -236,22 +237,14 @@ func TestDNS_ServiceLookup(t *testing.T) { t.Fatalf("err: %v", err) } - if len(in.Answer) != 2 { + if len(in.Answer) != 1 { t.Fatalf("Bad: %#v", in) } - aRec, ok := in.Answer[0].(*dns.A) + srvRec, ok := in.Answer[0].(*dns.SRV) if !ok { t.Fatalf("Bad: %#v", in.Answer[0]) } - if aRec.A.String() != "127.0.0.1" { - t.Fatalf("Bad: %#v", in.Answer[0]) - } - - srvRec, ok := in.Answer[1].(*dns.SRV) - if !ok { - t.Fatalf("Bad: %#v", in.Answer[1]) - } if srvRec.Port != 12345 { t.Fatalf("Bad: %#v", srvRec) } @@ -259,7 +252,7 @@ func TestDNS_ServiceLookup(t *testing.T) { t.Fatalf("Bad: %#v", srvRec) } - aRec, ok = in.Extra[0].(*dns.A) + aRec, ok := in.Extra[0].(*dns.A) if !ok { t.Fatalf("Bad: %#v", in.Extra[0]) } @@ -334,7 +327,7 @@ func TestDNS_ServiceLookup_Dedup(t *testing.T) { t.Fatalf("err: %v", err) } - if len(in.Answer) != 3 { + if len(in.Answer) != 1 { t.Fatalf("Bad: %#v", in) } @@ -345,10 +338,78 @@ func TestDNS_ServiceLookup_Dedup(t *testing.T) { if aRec.A.String() != "127.0.0.1" { t.Fatalf("Bad: %#v", in.Answer[0]) } +} + +func TestDNS_ServiceLookup_Dedup_SRV(t *testing.T) { + dir, srv := makeDNSServer(t) + defer os.RemoveAll(dir) + defer srv.agent.Shutdown() + + // Wait for leader + time.Sleep(100 * time.Millisecond) + + // Register node + args := &structs.RegisterRequest{ + Datacenter: "dc1", + Node: "foo", + Address: "127.0.0.1", + Service: &structs.NodeService{ + Service: "db", + Tag: "master", + Port: 12345, + }, + } + var out struct{} + if err := srv.agent.RPC("Catalog.Register", args, &out); err != nil { + t.Fatalf("err: %v", err) + } + + args = &structs.RegisterRequest{ + Datacenter: "dc1", + Node: "foo", + Address: "127.0.0.1", + Service: &structs.NodeService{ + ID: "db2", + Service: "db", + Tag: "slave", + Port: 12345, + }, + } + if err := srv.agent.RPC("Catalog.Register", args, &out); err != nil { + t.Fatalf("err: %v", err) + } + + args = &structs.RegisterRequest{ + Datacenter: "dc1", + Node: "foo", + Address: "127.0.0.1", + Service: &structs.NodeService{ + ID: "db3", + Service: "db", + Tag: "slave", + Port: 12346, + }, + } + if err := srv.agent.RPC("Catalog.Register", args, &out); err != nil { + t.Fatalf("err: %v", err) + } + + m := new(dns.Msg) + m.SetQuestion("db.service.consul.", dns.TypeSRV) + + c := new(dns.Client) + in, _, err := c.Exchange(m, srv.agent.config.DNSAddr) + if err != nil { + t.Fatalf("err: %v", err) + } + + if len(in.Answer) != 2 { + t.Fatalf("Bad: %#v", in) + } - srvRec, ok := in.Answer[1].(*dns.SRV) + srvRec, ok := in.Answer[0].(*dns.SRV) if !ok { - t.Fatalf("Bad: %#v", in.Answer[1]) + t.Fatalf("Bad: %#v", in.Answer[0]) } if srvRec.Port != 12345 && srvRec.Port != 12346 { t.Fatalf("Bad: %#v", srvRec) @@ -357,21 +418,21 @@ func TestDNS_ServiceLookup_Dedup(t *testing.T) { t.Fatalf("Bad: %#v", srvRec) } - srvRec, ok = in.Answer[2].(*dns.SRV) + srvRec, ok = in.Answer[1].(*dns.SRV) if !ok { t.Fatalf("Bad: %#v", in.Answer[1]) } if srvRec.Port != 12346 && srvRec.Port != 12345 { t.Fatalf("Bad: %#v", srvRec) } - if srvRec.Port == in.Answer[1].(*dns.SRV).Port { + if srvRec.Port == in.Answer[0].(*dns.SRV).Port { t.Fatalf("should be a different port") } if srvRec.Target != "foo.node.dc1.consul." { t.Fatalf("Bad: %#v", srvRec) } - aRec, ok = in.Extra[0].(*dns.A) + aRec, ok := in.Extra[0].(*dns.A) if !ok { t.Fatalf("Bad: %#v", in.Extra[0]) } @@ -507,8 +568,8 @@ func TestDNS_ServiceLookup_Randomize(t *testing.T) { } // Response length should be truncated - // We should get an SRV + A record for each response (hence 2x) - if len(in.Answer) != 2*maxServiceResponses { + // We should get an A record for each response + if len(in.Answer) != maxServiceResponses { t.Fatalf("Bad: %#v", len(in.Answer)) } @@ -564,10 +625,11 @@ func TestDNS_ServiceLookup_CNAME(t *testing.T) { t.Fatalf("err: %v", err) } - if len(in.Answer) != 2 { + if len(in.Answer) != 4 { t.Fatalf("Bad: %#v", in) } + // Should have google CNAME cnRec, ok := in.Answer[0].(*dns.CNAME) if !ok { t.Fatalf("Bad: %#v", in.Answer[0]) @@ -576,33 +638,10 @@ func TestDNS_ServiceLookup_CNAME(t *testing.T) { t.Fatalf("Bad: %#v", in.Answer[0]) } - srvRec, ok := in.Answer[1].(*dns.SRV) - if !ok { - t.Fatalf("Bad: %#v", in.Answer[1]) - } - if srvRec.Port != 80 { - t.Fatalf("Bad: %#v", srvRec) - } - if srvRec.Target != "google.node.dc1.consul." { - t.Fatalf("Bad: %#v", srvRec) - } - - aRec, ok := in.Extra[0].(*dns.A) - if !ok { - t.Fatalf("Bad: %#v", in.Extra[0]) - } - if aRec.Hdr.Name != "www.google.com." { - t.Fatalf("Bad: %#v", in.Extra[0]) - } - - cnRec, ok = in.Extra[1].(*dns.CNAME) - if !ok { - t.Fatalf("Bad: %#v", in.Extra[1]) - } - if cnRec.Hdr.Name != "google.node.dc1.consul." { - t.Fatalf("Bad: %#v", in.Extra[1]) - } - if cnRec.Target != "www.google.com." { - t.Fatalf("Bad: %#v", in.Extra[1]) + // Check we recursively resolve + for i := 1; i < 4; i++ { + if _, ok := in.Answer[i].(*dns.A); !ok { + t.Fatalf("Bad: %#v", in.Answer[i]) + } } }