diff --git a/agent/dns.go b/agent/dns.go index db0e2f6603..47357186d3 100644 --- a/agent/dns.go +++ b/agent/dns.go @@ -416,14 +416,14 @@ func (d *DNSServer) handleQuery(resp dns.ResponseWriter, req *dns.Msg) { switch req.Question[0].Qtype { case dns.TypeSOA: - ns, glue := d.nameservers(cfg, req.IsEdns0() != nil, maxRecursionLevelDefault) + ns, glue := d.nameservers(cfg, req.IsEdns0() != nil, maxRecursionLevelDefault, req) m.Answer = append(m.Answer, d.soa(cfg)) m.Ns = append(m.Ns, ns...) m.Extra = append(m.Extra, glue...) m.SetRcode(req, dns.RcodeSuccess) case dns.TypeNS: - ns, glue := d.nameservers(cfg, req.IsEdns0() != nil, maxRecursionLevelDefault) + ns, glue := d.nameservers(cfg, req.IsEdns0() != nil, maxRecursionLevelDefault, req) m.Answer = ns m.Extra = glue m.SetRcode(req, dns.RcodeSuccess) @@ -469,7 +469,7 @@ func (d *DNSServer) addSOA(cfg *dnsConfig, msg *dns.Msg) { // nameservers returns the names and ip addresses of up to three random servers // in the current cluster which serve as authoritative name servers for zone. -func (d *DNSServer) nameservers(cfg *dnsConfig, edns bool, maxRecursionLevel int) (ns []dns.RR, extra []dns.RR) { +func (d *DNSServer) nameservers(cfg *dnsConfig, edns bool, maxRecursionLevel int, req *dns.Msg) (ns []dns.RR, extra []dns.RR) { out, err := d.lookupServiceNodes(cfg, d.agent.config.Datacenter, structs.ConsulServiceName, "", false, maxRecursionLevel) if err != nil { d.logger.Printf("[WARN] dns: Unable to get list of servers: %s", err) @@ -485,7 +485,7 @@ func (d *DNSServer) nameservers(cfg *dnsConfig, edns bool, maxRecursionLevel int out.Nodes.Shuffle() for _, o := range out.Nodes { - name, addr, dc := o.Node.Node, o.Node.Address, o.Node.Datacenter + name, dc := o.Node.Node, o.Node.Datacenter if InvalidDnsRe.MatchString(name) { d.logger.Printf("[WARN] dns: Skipping invalid node %q for NS records", name) @@ -507,11 +507,7 @@ func (d *DNSServer) nameservers(cfg *dnsConfig, edns bool, maxRecursionLevel int } ns = append(ns, nsrr) - glue, meta := d.formatNodeRecord(cfg, nil, addr, fqdn, dns.TypeANY, cfg.NodeTTL, edns, maxRecursionLevel, cfg.NodeMetaTXT) - extra = append(extra, glue...) - if meta != nil && cfg.NodeMetaTXT { - extra = append(extra, meta...) - } + extra = append(extra, d.makeRecordFromNode(dc, o.Node, dns.TypeANY, fqdn, cfg.NodeTTL, maxRecursionLevel)...) // don't provide more than 3 servers if len(ns) >= 3 { @@ -728,34 +724,31 @@ func (d *DNSServer) nodeLookup(cfg *dnsConfig, network, datacenter, node string, return } - // If we have no address, return not found! + // If we have no out.NodeServices.Nodeaddress, return not found! if out.NodeServices == nil { d.addSOA(cfg, resp) resp.SetRcode(req, dns.RcodeNameError) return } - generateMeta := false - metaInAnswer := false - if qType == dns.TypeANY || qType == dns.TypeTXT { - generateMeta = true - metaInAnswer = true - } else if cfg.NodeMetaTXT { - generateMeta = true - } - // Add the node record n := out.NodeServices.Node - edns := req.IsEdns0() != nil - addr := d.agent.TranslateAddress(datacenter, n.Address, n.TaggedAddresses) - records, meta := d.formatNodeRecord(cfg, out.NodeServices.Node, addr, req.Question[0].Name, qType, cfg.NodeTTL, edns, maxRecursionLevel, generateMeta) - if records != nil { + + metaTarget := &resp.Extra + if qType == dns.TypeTXT || qType == dns.TypeANY { + metaTarget = &resp.Answer + } + + q := req.Question[0] + // Only compute A and CNAME record if query is not TXT type + if qType != dns.TypeTXT { + records := d.makeRecordFromNode(n.Datacenter, n, q.Qtype, q.Name, cfg.NodeTTL, maxRecursionLevel) resp.Answer = append(resp.Answer, records...) } - if meta != nil && metaInAnswer && generateMeta { - resp.Answer = append(resp.Answer, meta...) - } else if meta != nil && cfg.NodeMetaTXT { - resp.Extra = append(resp.Extra, meta...) + + if cfg.NodeMetaTXT || qType == dns.TypeTXT || qType == dns.TypeANY { + metas := d.generateMeta(n.Datacenter, q.Name, n, cfg.NodeTTL) + *metaTarget = append(*metaTarget, metas...) } } @@ -817,94 +810,6 @@ func encodeKVasRFC1464(key, value string) (txt string) { return key + "=" + value } -// formatNodeRecord takes a Node and returns the RRs associated with that node -// -// The return value is two slices. The first slice is the main answer slice (containing the A, AAAA, CNAME) RRs for the node -// and the second slice contains any TXT RRs created from the node metadata. It is up to the caller to determine where the -// generated RRs should go and if they should be used at all. -func (d *DNSServer) formatNodeRecord(cfg *dnsConfig, node *structs.Node, addr, qName string, qType uint16, ttl time.Duration, edns bool, maxRecursionLevel int, generateMeta bool) (records, meta []dns.RR) { - // Parse the IP - ip := net.ParseIP(addr) - var ipv4 net.IP - if ip != nil { - ipv4 = ip.To4() - } - - switch { - case ipv4 != nil && (qType == dns.TypeANY || qType == dns.TypeA): - records = append(records, &dns.A{ - Hdr: dns.RR_Header{ - Name: qName, - Rrtype: dns.TypeA, - Class: dns.ClassINET, - Ttl: uint32(ttl / time.Second), - }, - A: ip, - }) - - case ip != nil && ipv4 == nil && (qType == dns.TypeANY || qType == dns.TypeAAAA): - records = append(records, &dns.AAAA{ - Hdr: dns.RR_Header{ - Name: qName, - Rrtype: dns.TypeAAAA, - Class: dns.ClassINET, - Ttl: uint32(ttl / time.Second), - }, - AAAA: ip, - }) - - case ip == nil && (qType == dns.TypeANY || qType == dns.TypeCNAME || - qType == dns.TypeA || qType == dns.TypeAAAA || qType == dns.TypeTXT): - // Get the CNAME - cnRec := &dns.CNAME{ - Hdr: dns.RR_Header{ - Name: qName, - Rrtype: dns.TypeCNAME, - Class: dns.ClassINET, - Ttl: uint32(ttl / time.Second), - }, - Target: dns.Fqdn(addr), - } - records = append(records, cnRec) - - // Recurse - more := d.resolveCNAME(cfg, cnRec.Target, maxRecursionLevel) - extra := 0 - MORE_REC: - for _, rr := range more { - switch rr.Header().Rrtype { - case dns.TypeCNAME, dns.TypeA, dns.TypeAAAA, dns.TypeTXT: - records = append(records, rr) - extra++ - if extra == maxRecurseRecords && !edns { - break MORE_REC - } - } - } - } - - if node != nil && generateMeta { - for key, value := range node.Meta { - txt := value - if !strings.HasPrefix(strings.ToLower(key), "rfc1035-") { - txt = encodeKVasRFC1464(key, value) - } - - meta = append(meta, &dns.TXT{ - Hdr: dns.RR_Header{ - Name: qName, - Rrtype: dns.TypeTXT, - Class: dns.ClassINET, - Ttl: uint32(ttl / time.Second), - }, - Txt: []string{txt}, - }) - } - } - - return records, meta -} - // indexRRs populates a map which indexes a given list of RRs by name. NOTE that // the names are all squashed to lower case so we can perform case-insensitive // lookups; the RRs are not modified. @@ -1364,25 +1269,12 @@ RPC: // serviceNodeRecords is used to add the node records for a service lookup func (d *DNSServer) serviceNodeRecords(cfg *dnsConfig, dc string, nodes structs.CheckServiceNodes, req, resp *dns.Msg, ttl time.Duration, maxRecursionLevel int) { qName := req.Question[0].Name - qType := req.Question[0].Qtype handled := make(map[string]struct{}) - edns := req.IsEdns0() != nil var answerCNAME []dns.RR = nil count := 0 for _, node := range nodes { - // Start with the translated address but use the service address, - // if specified. - addr := d.agent.TranslateAddress(dc, node.Node.Address, node.Node.TaggedAddresses) - if svcAddr := d.agent.TranslateServiceAddress(dc, node.Service.Address, node.Service.TaggedAddresses); svcAddr != "" { - addr = svcAddr - } - - // If the service address is a CNAME for the service we are looking - // for then use the node address. - if qName == strings.TrimSuffix(addr, ".")+"." { - addr = node.Node.Address - } + addr := d.serviceNodeAddr(node, dc, qName) // Avoid duplicate entries, possible if a node has // the same service on multiple ports, etc. @@ -1391,18 +1283,9 @@ func (d *DNSServer) serviceNodeRecords(cfg *dnsConfig, dc string, nodes structs. } handled[addr] = struct{}{} - generateMeta := false - metaInAnswer := false - if qType == dns.TypeANY || qType == dns.TypeTXT { - generateMeta = true - metaInAnswer = true - } else if cfg.NodeMetaTXT { - generateMeta = true - } - // Add the node record had_answer := false - records, meta := d.formatNodeRecord(cfg, node.Node, addr, qName, qType, ttl, edns, maxRecursionLevel, generateMeta) + records, _ := d.nodeServiceRecords(dc, node, req, ttl, cfg, maxRecursionLevel) if records != nil { switch records[0].(type) { case *dns.CNAME: @@ -1417,13 +1300,6 @@ func (d *DNSServer) serviceNodeRecords(cfg *dnsConfig, dc string, nodes structs. } } - if meta != nil && generateMeta && metaInAnswer { - resp.Answer = append(resp.Answer, meta...) - had_answer = true - } else if meta != nil && generateMeta { - resp.Extra = append(resp.Extra, meta...) - } - if had_answer { count++ if count == cfg.ARecordLimit { @@ -1483,78 +1359,307 @@ func findWeight(node structs.CheckServiceNode) int { } } +// serviceNodeAddr is used to identify target service address +func (d *DNSServer) serviceNodeAddr(serviceNode structs.CheckServiceNode, dc string, dnsQuery string) string { + nodeAddress := d.agent.TranslateAddress(dc, serviceNode.Node.Address, serviceNode.Node.TaggedAddresses) + serviceAddress := d.agent.TranslateServiceAddress(dc, serviceNode.Service.Address, serviceNode.Service.TaggedAddresses) + addr := nodeAddress + + if serviceAddress != "" { + addr = serviceAddress + } + + // If the service address is a CNAME for the service we are looking + // for then use the node address. + if dnsQuery == strings.TrimSuffix(addr, ".")+"." { + addr = nodeAddress + } + + return addr +} + +func (d *DNSServer) encodeIPAsFqdn(dc string, ip net.IP) string { + ipv4 := ip.To4() + if ipv4 != nil { + ipStr := hex.EncodeToString(ip) + return fmt.Sprintf("%s.addr.%s.%s", ipStr[len(ipStr)-(net.IPv4len*2):], dc, d.domain) + } else { + return fmt.Sprintf("%s.addr.%s.%s", hex.EncodeToString(ip), dc, d.domain) + } +} + +func makeARecord(qType uint16, ip net.IP, ttl time.Duration) dns.RR { + + var ipRecord dns.RR + ipv4 := ip.To4() + if ipv4 != nil { + if qType == dns.TypeSRV || qType == dns.TypeA || qType == dns.TypeANY || qType == dns.TypeNS || qType == dns.TypeTXT { + ipRecord = &dns.A{ + Hdr: dns.RR_Header{ + Rrtype: dns.TypeA, + Class: dns.ClassINET, + Ttl: uint32(ttl / time.Second), + }, + A: ipv4, + } + } + } else if qType == dns.TypeSRV || qType == dns.TypeAAAA || qType == dns.TypeANY || qType == dns.TypeNS || qType == dns.TypeTXT { + ipRecord = &dns.AAAA{ + Hdr: dns.RR_Header{ + Rrtype: dns.TypeAAAA, + Class: dns.ClassINET, + Ttl: uint32(ttl / time.Second), + }, + AAAA: ip, + } + } + return ipRecord +} + +// Craft dns records for a node +// In case of an SRV query the answer will be a IN SRV and additional data will store an IN A to the node IP +// Otherwise it will return a IN A record +func (d *DNSServer) makeRecordFromNode(dc string, node *structs.Node, qType uint16, qName string, ttl time.Duration, maxRecursionLevel int) []dns.RR { + addr := d.agent.TranslateAddress(node.Datacenter, node.Address, node.TaggedAddresses) + ip := net.ParseIP(addr) + + var res []dns.RR + + if ip == nil { + res = append(res, &dns.CNAME{ + Hdr: dns.RR_Header{ + Name: qName, + Rrtype: dns.TypeCNAME, + Class: dns.ClassINET, + Ttl: uint32(ttl / time.Second), + }, + Target: dns.Fqdn(node.Address), + }) + + res = append(res, + d.resolveCNAME(d.config.Load().(*dnsConfig), dns.Fqdn(node.Address), maxRecursionLevel)..., + ) + + return res + } + + ipRecord := makeARecord(qType, ip, ttl) + if ipRecord == nil { + return nil + } + + ipRecord.Header().Name = qName + return []dns.RR{ipRecord} +} + +// Craft dns records for a service +// In case of an SRV query the answer will be a IN SRV and additional data will store an IN A to the node IP +// Otherwise it will return a IN A record +func (d *DNSServer) makeRecordFromServiceNode(dc string, serviceNode structs.CheckServiceNode, addr net.IP, req *dns.Msg, ttl time.Duration) ([]dns.RR, []dns.RR) { + q := req.Question[0] + ipRecord := makeARecord(q.Qtype, addr, ttl) + if ipRecord == nil { + return nil, nil + } + + if q.Qtype == dns.TypeSRV { + nodeFQDN := fmt.Sprintf("%s.node.%s.%s", serviceNode.Node.Node, dc, d.domain) + answers := []dns.RR{ + &dns.SRV{ + Hdr: dns.RR_Header{ + Name: q.Name, + Rrtype: dns.TypeSRV, + Class: dns.ClassINET, + Ttl: uint32(ttl / time.Second), + }, + Priority: 1, + Weight: uint16(findWeight(serviceNode)), + Port: uint16(d.agent.TranslateServicePort(dc, serviceNode.Service.Port, serviceNode.Service.TaggedAddresses)), + Target: nodeFQDN, + }, + } + + ipRecord.Header().Name = nodeFQDN + return answers, []dns.RR{ipRecord} + } + + ipRecord.Header().Name = q.Name + return []dns.RR{ipRecord}, nil +} + +// Craft dns records for an IP +// In case of an SRV query the answer will be a IN SRV and additional data will store an IN A to the IP +// Otherwise it will return a IN A record +func (d *DNSServer) makeRecordFromIP(dc string, addr net.IP, serviceNode structs.CheckServiceNode, req *dns.Msg, ttl time.Duration) ([]dns.RR, []dns.RR) { + q := req.Question[0] + ipRecord := makeARecord(q.Qtype, addr, ttl) + if ipRecord == nil { + return nil, nil + } + + if q.Qtype == dns.TypeSRV { + ipFQDN := d.encodeIPAsFqdn(dc, addr) + answers := []dns.RR{ + &dns.SRV{ + Hdr: dns.RR_Header{ + Name: q.Name, + Rrtype: dns.TypeSRV, + Class: dns.ClassINET, + Ttl: uint32(ttl / time.Second), + }, + Priority: 1, + Weight: uint16(findWeight(serviceNode)), + Port: uint16(d.agent.TranslateServicePort(dc, serviceNode.Service.Port, serviceNode.Service.TaggedAddresses)), + Target: ipFQDN, + }, + } + + ipRecord.Header().Name = ipFQDN + return answers, []dns.RR{ipRecord} + } + + ipRecord.Header().Name = q.Name + return []dns.RR{ipRecord}, nil +} + +// Craft dns records for an FQDN +// In case of an SRV query the answer will be a IN SRV and additional data will store an IN A to the IP +// Otherwise it will return a CNAME and a IN A record +func (d *DNSServer) makeRecordFromFQDN(dc string, fqdn string, serviceNode structs.CheckServiceNode, req *dns.Msg, ttl time.Duration, cfg *dnsConfig, maxRecursionLevel int) ([]dns.RR, []dns.RR) { + edns := req.IsEdns0() != nil + q := req.Question[0] + + more := d.resolveCNAME(cfg, dns.Fqdn(fqdn), maxRecursionLevel) + var additional []dns.RR + extra := 0 +MORE_REC: + for _, rr := range more { + switch rr.Header().Rrtype { + case dns.TypeCNAME, dns.TypeA, dns.TypeAAAA: + // set the TTL manually + rr.Header().Ttl = uint32(ttl / time.Second) + additional = append(additional, rr) + + extra++ + if extra == maxRecurseRecords && !edns { + break MORE_REC + } + } + } + + if q.Qtype == dns.TypeSRV { + answers := []dns.RR{ + &dns.SRV{ + Hdr: dns.RR_Header{ + Name: q.Name, + Rrtype: dns.TypeSRV, + Class: dns.ClassINET, + Ttl: uint32(ttl / time.Second), + }, + Priority: 1, + Weight: uint16(findWeight(serviceNode)), + Port: uint16(d.agent.TranslateServicePort(dc, serviceNode.Service.Port, serviceNode.Service.TaggedAddresses)), + Target: dns.Fqdn(fqdn), + }, + } + return answers, additional + } + + answers := []dns.RR{ + &dns.CNAME{ + Hdr: dns.RR_Header{ + Name: q.Name, + Rrtype: dns.TypeCNAME, + Class: dns.ClassINET, + Ttl: uint32(ttl / time.Second), + }, + Target: dns.Fqdn(fqdn), + }} + answers = append(answers, additional...) + + return answers, nil +} + +func (d *DNSServer) nodeServiceRecords(dc string, node structs.CheckServiceNode, req *dns.Msg, ttl time.Duration, cfg *dnsConfig, maxRecursionLevel int) ([]dns.RR, []dns.RR) { + serviceAddr := d.agent.TranslateServiceAddress(dc, node.Service.Address, node.Service.TaggedAddresses) + nodeAddr := d.agent.TranslateAddress(node.Node.Datacenter, node.Node.Address, node.Node.TaggedAddresses) + + nodeIPAddr := net.ParseIP(nodeAddr) + serviceIPAddr := net.ParseIP(serviceAddr) + + // There is no service address and the node address is an IP + if serviceAddr == "" && nodeIPAddr != nil { + if node.Node.Address != nodeAddr { + // Do not CNAME node address in case of WAN address + return d.makeRecordFromIP(dc, nodeIPAddr, node, req, ttl) + } + + return d.makeRecordFromServiceNode(dc, node, nodeIPAddr, req, ttl) + } + + // There is no service address and the node address is a FQDN (external service) + if serviceAddr == "" { + return d.makeRecordFromFQDN(dc, nodeAddr, node, req, ttl, cfg, maxRecursionLevel) + } + + // The service address is an IP + if serviceIPAddr != nil { + return d.makeRecordFromIP(dc, serviceIPAddr, node, req, ttl) + } + + // If the service address is a CNAME for the service we are looking + // for then use the node address. + if dns.Fqdn(serviceAddr) == req.Question[0].Name && nodeIPAddr != nil { + return d.makeRecordFromServiceNode(dc, node, nodeIPAddr, req, ttl) + } + + // The service address is a FQDN (external service) + return d.makeRecordFromFQDN(dc, serviceAddr, node, req, ttl, cfg, maxRecursionLevel) +} + +func (d *DNSServer) generateMeta(dc string, qName string, node *structs.Node, ttl time.Duration) []dns.RR { + var extra []dns.RR + for key, value := range node.Meta { + txt := value + if !strings.HasPrefix(strings.ToLower(key), "rfc1035-") { + txt = encodeKVasRFC1464(key, value) + } + + extra = append(extra, &dns.TXT{ + Hdr: dns.RR_Header{ + Name: qName, + Rrtype: dns.TypeTXT, + Class: dns.ClassINET, + Ttl: uint32(ttl / time.Second), + }, + Txt: []string{txt}, + }) + } + return extra +} + // serviceARecords is used to add the SRV records for a service lookup func (d *DNSServer) serviceSRVRecords(cfg *dnsConfig, dc string, nodes structs.CheckServiceNodes, req, resp *dns.Msg, ttl time.Duration, maxRecursionLevel int) { handled := make(map[string]struct{}) - edns := req.IsEdns0() != nil for _, node := range nodes { // Avoid duplicate entries, possible if a node has // the same service the same port, etc. - tuple := fmt.Sprintf("%s:%s:%d", node.Node.Node, node.Service.Address, node.Service.Port) + serviceAddress := d.agent.TranslateServiceAddress(dc, node.Service.Address, node.Service.TaggedAddresses) + servicePort := d.agent.TranslateServicePort(dc, node.Service.Port, node.Service.TaggedAddresses) + tuple := fmt.Sprintf("%s:%s:%d", node.Node.Node, serviceAddress, servicePort) if _, ok := handled[tuple]; ok { continue } handled[tuple] = struct{}{} - weight := findWeight(node) - // Add the SRV record - srvRec := &dns.SRV{ - Hdr: dns.RR_Header{ - Name: req.Question[0].Name, - Rrtype: dns.TypeSRV, - Class: dns.ClassINET, - Ttl: uint32(ttl / time.Second), - }, - Priority: 1, - Weight: uint16(weight), - Port: uint16(d.agent.TranslateServicePort(dc, node.Service.Port, node.Service.TaggedAddresses)), - Target: fmt.Sprintf("%s.node.%s.%s", node.Node.Node, dc, d.domain), - } - resp.Answer = append(resp.Answer, srvRec) + answers, extra := d.nodeServiceRecords(dc, node, req, ttl, cfg, maxRecursionLevel) - // Start with the translated address but use the service address, - // if specified. - addr := d.agent.TranslateAddress(dc, node.Node.Address, node.Node.TaggedAddresses) - if svcAddr := d.agent.TranslateServiceAddress(dc, node.Service.Address, node.Service.TaggedAddresses); svcAddr != "" { - addr = svcAddr - } + resp.Answer = append(resp.Answer, answers...) + resp.Extra = append(resp.Extra, extra...) - // Add the extra record - records, meta := d.formatNodeRecord(cfg, node.Node, addr, srvRec.Target, dns.TypeANY, ttl, edns, maxRecursionLevel, cfg.NodeMetaTXT) - if len(records) > 0 { - // Use the node address if it doesn't differ from the service address - if addr == node.Node.Address { - resp.Extra = append(resp.Extra, records...) - } else { - // If it differs from the service address, give a special response in the - // 'addr.consul' domain with the service IP encoded in it. We have to do - // this because we can't put an IP in the target field of an SRV record. - switch record := records[0].(type) { - // IPv4 - case *dns.A: - addr := hex.EncodeToString(record.A) - - // Take the last 8 chars (4 bytes) of the encoded address to avoid junk bytes - srvRec.Target = fmt.Sprintf("%s.addr.%s.%s", addr[len(addr)-(net.IPv4len*2):], dc, d.domain) - record.Hdr.Name = srvRec.Target - resp.Extra = append(resp.Extra, record) - - // IPv6 - case *dns.AAAA: - srvRec.Target = fmt.Sprintf("%s.addr.%s.%s", hex.EncodeToString(record.AAAA), dc, d.domain) - record.Hdr.Name = srvRec.Target - resp.Extra = append(resp.Extra, record) - - // Something else (probably a CNAME; just add the records). - default: - resp.Extra = append(resp.Extra, records...) - } - } - - if meta != nil && cfg.NodeMetaTXT { - resp.Extra = append(resp.Extra, meta...) - } + if cfg.NodeMetaTXT { + resp.Extra = append(resp.Extra, d.generateMeta(dc, fmt.Sprintf("%s.node.%s.%s", node.Node.Node, dc, d.domain), node.Node, ttl)...) } } } diff --git a/agent/dns_test.go b/agent/dns_test.go index 685be283f1..30a6e5e4b7 100644 --- a/agent/dns_test.go +++ b/agent/dns_test.go @@ -1554,12 +1554,8 @@ func TestDNS_ServiceLookupWithInternalServiceAddress(t *testing.T) { } verify.Values(t, "answer", in.Answer, wantAnswer) wantExtra := []dns.RR{ - &dns.CNAME{ - Hdr: dns.RR_Header{Name: "foo.node.dc1.consul.", Rrtype: 0x5, Class: 0x1, Rdlength: 0x2}, - Target: "db.service.consul.", - }, &dns.A{ - Hdr: dns.RR_Header{Name: "db.service.consul.", Rrtype: 0x1, Class: 0x1, Rdlength: 0x4}, + Hdr: dns.RR_Header{Name: "foo.node.dc1.consul.", Rrtype: 0x1, Class: 0x1, Rdlength: 0x4}, A: []byte{0x7f, 0x0, 0x0, 0x1}, // 127.0.0.1 }, } @@ -1661,26 +1657,12 @@ func TestDNS_ExternalServiceLookup(t *testing.T) { if srvRec.Port != 12345 { t.Fatalf("Bad: %#v", srvRec) } - if srvRec.Target != "foo.node.dc1.consul." { + if srvRec.Target != "www.google.com." { t.Fatalf("Bad: %#v", srvRec) } if srvRec.Hdr.Ttl != 0 { t.Fatalf("Bad: %#v", in.Answer[0]) } - - cnameRec, ok := in.Extra[0].(*dns.CNAME) - if !ok { - t.Fatalf("Bad: %#v", in.Extra[0]) - } - if cnameRec.Hdr.Name != "foo.node.dc1.consul." { - t.Fatalf("Bad: %#v", in.Extra[0]) - } - if cnameRec.Target != "www.google.com." { - t.Fatalf("Bad: %#v", in.Extra[0]) - } - if cnameRec.Hdr.Ttl != 0 { - t.Fatalf("Bad: %#v", in.Extra[0]) - } } } @@ -1810,43 +1792,29 @@ func TestDNS_ExternalServiceToConsulCNAMELookup(t *testing.T) { if srvRec.Port != 12345 { t.Fatalf("Bad: %#v", srvRec) } - if srvRec.Target != "alias.node.dc1.consul." { + if srvRec.Target != "web.service.consul." { t.Fatalf("Bad: %#v", srvRec) } if srvRec.Hdr.Ttl != 0 { t.Fatalf("Bad: %#v", in.Answer[0]) } - if len(in.Extra) != 2 { + if len(in.Extra) != 1 { t.Fatalf("Bad: %#v", in) } - cnameRec, ok := in.Extra[0].(*dns.CNAME) + aRec, ok := in.Extra[0].(*dns.A) if !ok { t.Fatalf("Bad: %#v", in.Extra[0]) } - if cnameRec.Hdr.Name != "alias.node.dc1.consul." { - t.Fatalf("Bad: %#v", in.Extra[0]) - } - if cnameRec.Target != "web.service.consul." { - t.Fatalf("Bad: %#v", in.Extra[0]) - } - if cnameRec.Hdr.Ttl != 0 { - t.Fatalf("Bad: %#v", in.Extra[0]) - } - - aRec, ok := in.Extra[1].(*dns.A) - if !ok { - t.Fatalf("Bad: %#v", in.Extra[1]) - } if aRec.Hdr.Name != "web.service.consul." { - t.Fatalf("Bad: %#v", in.Extra[1]) + t.Fatalf("Bad: %#v", in.Extra[0]) } if aRec.A.String() != "127.0.0.1" { - t.Fatalf("Bad: %#v", in.Extra[1]) + t.Fatalf("Bad: %#v", in.Extra[0]) } if aRec.Hdr.Ttl != 0 { - t.Fatalf("Bad: %#v", in.Extra[1]) + t.Fatalf("Bad: %#v", in.Extra[0]) } } @@ -2011,14 +1979,13 @@ func TestDNS_ExternalServiceToConsulCNAMENestedLookup(t *testing.T) { if srvRec.Port != 12345 { t.Fatalf("Bad: %#v", srvRec) } - if srvRec.Target != "alias2.node.dc1.consul." { + if srvRec.Target != "alias.service.consul." { t.Fatalf("Bad: %#v", srvRec) } if srvRec.Hdr.Ttl != 0 { t.Fatalf("Bad: %#v", in.Answer[0]) } - - if len(in.Extra) != 3 { + if len(in.Extra) != 2 { t.Fatalf("Bad: %#v", in) } @@ -2026,42 +1993,28 @@ func TestDNS_ExternalServiceToConsulCNAMENestedLookup(t *testing.T) { if !ok { t.Fatalf("Bad: %#v", in.Extra[0]) } - if cnameRec.Hdr.Name != "alias2.node.dc1.consul." { - t.Fatalf("Bad: %#v", in.Extra[0]) - } - if cnameRec.Target != "alias.service.consul." { - t.Fatalf("Bad: %#v", in.Extra[0]) - } - if cnameRec.Hdr.Ttl != 0 { - t.Fatalf("Bad: %#v", in.Extra[0]) - } - - cnameRec, ok = in.Extra[1].(*dns.CNAME) - if !ok { - t.Fatalf("Bad: %#v", in.Extra[1]) - } if cnameRec.Hdr.Name != "alias.service.consul." { - t.Fatalf("Bad: %#v", in.Extra[1]) + t.Fatalf("Bad: %#v", in.Extra[0]) } if cnameRec.Target != "web.service.consul." { - t.Fatalf("Bad: %#v", in.Extra[1]) + t.Fatalf("Bad: %#v", in.Extra[0]) } if cnameRec.Hdr.Ttl != 0 { - t.Fatalf("Bad: %#v", in.Extra[1]) + t.Fatalf("Bad: %#v", in.Extra[0]) } - aRec, ok := in.Extra[2].(*dns.A) + aRec, ok := in.Extra[1].(*dns.A) if !ok { - t.Fatalf("Bad: %#v", in.Extra[2]) + t.Fatalf("Bad: %#v", in.Extra[1]) } if aRec.Hdr.Name != "web.service.consul." { t.Fatalf("Bad: %#v", in.Extra[1]) } if aRec.A.String() != "127.0.0.1" { - t.Fatalf("Bad: %#v", in.Extra[2]) + t.Fatalf("Bad: %#v", in.Extra[1]) } if aRec.Hdr.Ttl != 0 { - t.Fatalf("Bad: %#v", in.Extra[2]) + t.Fatalf("Bad: %#v", in.Extra[1]) } } } @@ -2159,9 +2112,19 @@ func TestDNS_ServiceLookup_ServiceAddress_A(t *testing.T) { } } -func TestDNS_ServiceLookup_ServiceAddress_CNAME(t *testing.T) { +func TestDNS_ServiceLookup_ServiceAddress_SRV(t *testing.T) { t.Parallel() - a := NewTestAgent(t, t.Name(), "") + recursor := makeRecursor(t, dns.Msg{ + Answer: []dns.RR{ + dnsCNAME("www.google.com", "google.com"), + dnsA("google.com", "1.2.3.4"), + }, + }) + defer recursor.Shutdown() + + a := NewTestAgent(t, t.Name(), ` + recursors = ["`+recursor.Addr+`"] + `) defer a.Shutdown() testrpc.WaitForLeader(t, a.RPC, "dc1") @@ -2229,25 +2192,29 @@ func TestDNS_ServiceLookup_ServiceAddress_CNAME(t *testing.T) { if srvRec.Port != 12345 { t.Fatalf("Bad: %#v", srvRec) } - if srvRec.Target != "foo.node.dc1.consul." { + if srvRec.Target != "www.google.com." { t.Fatalf("Bad: %#v", srvRec) } if srvRec.Hdr.Ttl != 0 { t.Fatalf("Bad: %#v", in.Answer[0]) } - cnameRec, ok := in.Extra[0].(*dns.CNAME) + // Should have google CNAME + cnRec, ok := in.Extra[0].(*dns.CNAME) if !ok { t.Fatalf("Bad: %#v", in.Extra[0]) } - if cnameRec.Hdr.Name != "foo.node.dc1.consul." { + if cnRec.Target != "google.com." { t.Fatalf("Bad: %#v", in.Extra[0]) } - if cnameRec.Target != "www.google.com." { - t.Fatalf("Bad: %#v", in.Extra[0]) + + // Check we recursively resolve + aRec, ok := in.Extra[1].(*dns.A) + if !ok { + t.Fatalf("Bad: %#v", in.Extra[1]) } - if cnameRec.Hdr.Ttl != 0 { - t.Fatalf("Bad: %#v", in.Extra[0]) + if aRec.A.String() != "1.2.3.4" { + t.Fatalf("Bad: %s", aRec.A.String()) } } } @@ -4591,6 +4558,104 @@ func TestDNS_ServiceLookup_CNAME(t *testing.T) { } } +func TestDNS_ServiceLookup_ServiceAddress_CNAME(t *testing.T) { + t.Parallel() + recursor := makeRecursor(t, dns.Msg{ + Answer: []dns.RR{ + dnsCNAME("www.google.com", "google.com"), + dnsA("google.com", "1.2.3.4"), + }, + }) + defer recursor.Shutdown() + + a := NewTestAgent(t, t.Name(), ` + recursors = ["`+recursor.Addr+`"] + `) + defer a.Shutdown() + testrpc.WaitForLeader(t, a.RPC, "dc1") + + // Register a node with a name for an address. + { + args := &structs.RegisterRequest{ + Datacenter: "dc1", + Node: "google", + Address: "1.2.3.4", + Service: &structs.NodeService{ + Service: "search", + Port: 80, + Address: "www.google.com", + }, + } + + var out struct{} + if err := a.RPC("Catalog.Register", args, &out); err != nil { + t.Fatalf("err: %v", err) + } + } + + // Register an equivalent prepared query. + var id string + { + args := &structs.PreparedQueryRequest{ + Datacenter: "dc1", + Op: structs.PreparedQueryCreate, + Query: &structs.PreparedQuery{ + Name: "test", + Service: structs.ServiceQuery{ + Service: "search", + }, + }, + } + if err := a.RPC("PreparedQuery.Apply", args, &id); err != nil { + t.Fatalf("err: %v", err) + } + } + + // Look up the service directly and via prepared query. + questions := []string{ + "search.service.consul.", + id + ".query.consul.", + } + for _, question := range questions { + m := new(dns.Msg) + m.SetQuestion(question, dns.TypeANY) + + c := new(dns.Client) + in, _, err := c.Exchange(m, a.DNSAddr()) + if err != nil { + t.Fatalf("err: %v", err) + } + + // Service CNAME, google CNAME, google A record + if len(in.Answer) != 3 { + t.Fatalf("Bad: %#v", in) + } + + // Should have service CNAME + cnRec, ok := in.Answer[0].(*dns.CNAME) + if !ok { + t.Fatalf("Bad: %#v", in.Answer[0]) + } + if cnRec.Target != "www.google.com." { + t.Fatalf("Bad: %#v", in.Answer[0]) + } + + // Should have google CNAME + cnRec, ok = in.Answer[1].(*dns.CNAME) + if !ok { + t.Fatalf("Bad: %#v", in.Answer[1]) + } + if cnRec.Target != "google.com." { + t.Fatalf("Bad: %#v", in.Answer[1]) + } + + // Check we recursively resolve + if _, ok := in.Answer[2].(*dns.A); !ok { + t.Fatalf("Bad: %#v", in.Answer[2]) + } + } +} + func TestDNS_NodeLookup_TTL(t *testing.T) { t.Parallel() recursor := makeRecursor(t, dns.Msg{ @@ -6527,25 +6592,6 @@ func TestDNSInvalidRegex(t *testing.T) { } } -func TestDNS_formatNodeRecord(t *testing.T) { - s := &DNSServer{} - - node := &structs.Node{ - Meta: map[string]string{ - "key": "value", - "key2": "value2", - }, - } - - records, meta := s.formatNodeRecord(&dnsConfig{}, node, "198.18.0.1", "test.node.consul", dns.TypeA, 5*time.Minute, false, 3, false) - require.Len(t, records, 1) - require.Len(t, meta, 0) - - records, meta = s.formatNodeRecord(&dnsConfig{}, node, "198.18.0.1", "test.node.consul", dns.TypeA, 5*time.Minute, false, 3, true) - require.Len(t, records, 1) - require.Len(t, meta, 2) -} - func TestDNS_ConfigReload(t *testing.T) { t.Parallel()