diff --git a/agent/dns.go b/agent/dns.go index 90020741e6..ba65ce75dc 100644 --- a/agent/dns.go +++ b/agent/dns.go @@ -236,9 +236,9 @@ func (d *DNSServer) handleQuery(resp dns.ResponseWriter, req *dns.Msg) { m.SetRcode(req, dns.RcodeSuccess) case dns.TypeNS: - ns, glue := d.nameservers() + ns, _ := d.nameservers() m.Answer = ns - m.Extra = glue + // no need to send A records with the IP address, since the ns record is a node name that resolves correctly m.SetRcode(req, dns.RcodeSuccess) default: @@ -295,7 +295,12 @@ func (d *DNSServer) nameservers() (ns []dns.RR, extra []dns.RR) { // name is "name.dc" and domain is "consul." // we want "name.node.dc.consul." lastdot := strings.LastIndexByte(name, '.') - fqdn := name[:lastdot] + ".node" + name[lastdot:] + "." + d.domain + nodeName := name[:lastdot] + if InvalidDnsRe.MatchString(nodeName) { + d.logger.Printf("[WARN] dns: Node name %q is not a valid dns host name, will not be added to NS record", nodeName) + continue + } + fqdn := nodeName + ".node" + name[lastdot:] + "." + d.domain // create a consistent, unique and sanitized name for the server fqdn = dns.Fqdn(strings.ToLower(fqdn)) diff --git a/agent/dns_test.go b/agent/dns_test.go index 1b75631844..4bf5600553 100644 --- a/agent/dns_test.go +++ b/agent/dns_test.go @@ -168,7 +168,7 @@ func TestDNS_NodeLookup(t *testing.T) { t.Fatalf("err: %v", err) } - if len(in.Ns) != 2 { + if len(in.Ns) != 1 { t.Fatalf("Bad: %#v %#v", in, len(in.Answer)) } @@ -180,13 +180,6 @@ func TestDNS_NodeLookup(t *testing.T) { t.Fatalf("Bad: %#v", in.Ns[0]) } - nsRec, ok := in.Ns[1].(*dns.NS) - if !ok { - t.Fatalf("Bad: %#v", in.Ns[1]) - } - if nsRec.Hdr.Ttl != 0 { - t.Fatalf("Bad: %#v", in.Ns[1]) - } } func TestDNS_CaseInsensitiveNodeLookup(t *testing.T) { @@ -627,7 +620,7 @@ func TestDNS_ServiceLookup(t *testing.T) { t.Fatalf("err: %v", err) } - if len(in.Ns) != 2 { + if len(in.Ns) != 1 { t.Fatalf("Bad: %#v", in) } @@ -639,13 +632,6 @@ func TestDNS_ServiceLookup(t *testing.T) { t.Fatalf("Bad: %#v", in.Ns[0]) } - nsRec, ok := in.Ns[1].(*dns.NS) - if !ok { - t.Fatalf("Bad: %#v", in.Ns[1]) - } - if nsRec.Hdr.Ttl != 0 { - t.Fatalf("Bad: %#v", in.Ns[1]) - } } } @@ -706,10 +692,6 @@ func TestDNS_ServiceLookupWithInternalServiceAddress(t *testing.T) { Hdr: dns.RR_Header{Name: "db.service.consul.", Rrtype: 0x1, Class: 0x1, Rdlength: 0x4}, A: []byte{0x7f, 0x0, 0x0, 0x1}, // 127.0.0.1 }, - &dns.A{ - Hdr: dns.RR_Header{Name: "server-my-test-node-dc1.consul.", Rrtype: 0x1, Class: 0x1, Rdlength: 0x4}, - A: []byte{0x7f, 0x0, 0x0, 0x1}, // 127.0.0.1 - }, } verify.Values(t, "extra", in.Extra, wantExtra) } @@ -864,7 +846,7 @@ func TestDNS_ExternalServiceToConsulCNAMELookup(t *testing.T) { t.Fatalf("Bad: %#v", in.Answer[0]) } - if len(in.Extra) != 3 { + if len(in.Extra) != 2 { t.Fatalf("Bad: %#v", in) } @@ -896,22 +878,59 @@ func TestDNS_ExternalServiceToConsulCNAMELookup(t *testing.T) { t.Fatalf("Bad: %#v", in.Extra[1]) } - aRec2, ok := in.Extra[2].(*dns.A) - if !ok { - t.Fatalf("Bad: %#v", in.Extra[2]) - } - if aRec2.Hdr.Name != "server-test-node-dc1.consul." { - t.Fatalf("Bad: %#v", in.Extra[2]) - } - if aRec2.A.String() != "127.0.0.1" { - t.Fatalf("Bad: %#v", in.Extra[2]) - } - if aRec2.Hdr.Ttl != 0 { - t.Fatalf("Bad: %#v", in.Extra[2]) - } } } +func TestDNS_NSRecords(t *testing.T) { + t.Parallel() + cfg := TestConfig() + cfg.Domain = "CONSUL." + cfg.NodeName = "foo" + a := NewTestAgent(t.Name(), cfg) + defer a.Shutdown() + + // Register node + args := &structs.RegisterRequest{ + Datacenter: "dc1", + Node: "foo", + Address: "127.0.0.1", + TaggedAddresses: map[string]string{ + "wan": "127.0.0.2", + }, + } + + var out struct{} + if err := a.RPC("Catalog.Register", args, &out); err != nil { + t.Fatalf("err: %v", err) + } + + m := new(dns.Msg) + m.SetQuestion("something.node.consul.", dns.TypeNS) + + c := new(dns.Client) + addr, _ := a.Config.ClientListener("", a.Config.Ports.DNS) + in, _, err := c.Exchange(m, addr.String()) + if err != nil { + t.Fatalf("err: %v", err) + } + + if len(in.Answer) != 1 { + t.Fatalf("Bad: %#v", in) + } + + nsRec, ok := in.Answer[0].(*dns.NS) + if !ok { + t.Fatalf("Bad: %#v", in.Answer[0]) + } + if nsRec.Ns != "foo.node.dc1.consul." { + t.Fatalf("Bad: %#v", in.Answer[0]) + } + if nsRec.Hdr.Ttl != 0 { + t.Fatalf("Bad: %#v", in.Answer[0]) + } + +} + func TestDNS_ExternalServiceToConsulCNAMENestedLookup(t *testing.T) { t.Parallel() cfg := TestConfig() @@ -1006,7 +1025,7 @@ func TestDNS_ExternalServiceToConsulCNAMENestedLookup(t *testing.T) { t.Fatalf("Bad: %#v", in.Answer[0]) } - if len(in.Extra) != 4 { + if len(in.Extra) != 3 { t.Fatalf("Bad: %#v", in) } @@ -1051,20 +1070,6 @@ func TestDNS_ExternalServiceToConsulCNAMENestedLookup(t *testing.T) { if aRec.Hdr.Ttl != 0 { t.Fatalf("Bad: %#v", in.Extra[2]) } - - aRec2, ok := in.Extra[3].(*dns.A) - if !ok { - t.Fatalf("Bad: %#v", in.Extra[3]) - } - if aRec2.Hdr.Name != "server-test-node-dc1.consul." { - t.Fatalf("Bad: %#v", in.Extra[3]) - } - if aRec2.A.String() != "127.0.0.1" { - t.Fatalf("Bad: %#v", in.Extra[3]) - } - if aRec2.Hdr.Ttl != 0 { - t.Fatalf("Bad: %#v", in.Extra[3]) - } } } @@ -3811,7 +3816,7 @@ func TestDNS_NonExistingLookup(t *testing.T) { t.Fatalf("err: %v", err) } - if len(in.Ns) != 2 { + if len(in.Ns) != 1 { t.Fatalf("Bad: %#v %#v", in, len(in.Answer)) } @@ -3822,14 +3827,6 @@ func TestDNS_NonExistingLookup(t *testing.T) { if soaRec.Hdr.Ttl != 0 { t.Fatalf("Bad: %#v", in.Ns[0]) } - - nsRec, ok := in.Ns[1].(*dns.NS) - if !ok { - t.Fatalf("Bad: %#v", in.Ns[1]) - } - if nsRec.Hdr.Ttl != 0 { - t.Fatalf("Bad: %#v", in.Ns[1]) - } } func TestDNS_NonExistingLookupEmptyAorAAAA(t *testing.T) { @@ -3920,28 +3917,21 @@ func TestDNS_NonExistingLookupEmptyAorAAAA(t *testing.T) { t.Fatalf("err: %v", err) } - if len(in.Ns) != 2 { + if len(in.Ns) != 1 { t.Fatalf("Bad: %#v", in) } - soaRec, ok := in.Ns[1].(*dns.SOA) + soaRec, ok := in.Ns[0].(*dns.SOA) if !ok { - t.Fatalf("Bad: %#v", in.Ns[1]) + t.Fatalf("Bad: %#v", in.Ns[0]) } if soaRec.Hdr.Ttl != 0 { - t.Fatalf("Bad: %#v", in.Ns[1]) + t.Fatalf("Bad: %#v", in.Ns[0]) } if in.Rcode != dns.RcodeSuccess { t.Fatalf("Bad: %#v", in) } - nsRec, ok := in.Ns[0].(*dns.NS) - if !ok { - t.Fatalf("Bad: %#v", in.Ns[0]) - } - if nsRec.Hdr.Ttl != 0 { - t.Fatalf("Bad: %#v", in.Ns[0]) - } } // Check for ipv4 records on ipv6-only service directly and via the @@ -3961,24 +3951,16 @@ func TestDNS_NonExistingLookupEmptyAorAAAA(t *testing.T) { t.Fatalf("err: %v", err) } - if len(in.Ns) != 2 { + if len(in.Ns) != 1 { t.Fatalf("Bad: %#v", in) } - nsRec, ok := in.Ns[0].(*dns.NS) + soaRec, ok := in.Ns[0].(*dns.SOA) if !ok { t.Fatalf("Bad: %#v", in.Ns[0]) } - if nsRec.Hdr.Ttl != 0 { - t.Fatalf("Bad: %#v", in.Ns[0]) - } - - soaRec, ok := in.Ns[1].(*dns.SOA) - if !ok { - t.Fatalf("Bad: %#v", in.Ns[1]) - } if soaRec.Hdr.Ttl != 0 { - t.Fatalf("Bad: %#v", in.Ns[1]) + t.Fatalf("Bad: %#v", in.Ns[0]) } if in.Rcode != dns.RcodeSuccess { @@ -4020,7 +4002,7 @@ func TestDNS_PreparedQuery_AllowStale(t *testing.T) { t.Fatalf("err: %v", err) } - if len(in.Ns) != 2 { + if len(in.Ns) != 1 { t.Fatalf("Bad: %#v", in) } @@ -4032,14 +4014,6 @@ func TestDNS_PreparedQuery_AllowStale(t *testing.T) { t.Fatalf("Bad: %#v", in.Ns[0]) } - nsRec, ok := in.Ns[1].(*dns.NS) - if !ok { - t.Fatalf("Bad: %#v", in.Ns[1]) - } - if nsRec.Hdr.Ttl != 0 { - t.Fatalf("Bad: %#v", in.Ns[1]) - } - } } @@ -4067,7 +4041,7 @@ func TestDNS_InvalidQueries(t *testing.T) { t.Fatalf("err: %v", err) } - if len(in.Ns) != 2 { + if len(in.Ns) != 1 { t.Fatalf("Bad: %#v", in) } @@ -4079,13 +4053,6 @@ func TestDNS_InvalidQueries(t *testing.T) { t.Fatalf("Bad: %#v", in.Ns[0]) } - nsRec, ok := in.Ns[1].(*dns.NS) - if !ok { - t.Fatalf("Bad: %#v", in.Ns[1]) - } - if nsRec.Hdr.Ttl != 0 { - t.Fatalf("Bad: %#v", in.Ns[1]) - } } } @@ -4781,3 +4748,24 @@ func TestDNS_Compression_Recurse(t *testing.T) { t.Fatalf("doesn't look compressed: %d vs. %d", compressed, unc) } } + +func TestDNSInvalidRegex(t *testing.T) { + tests := []struct { + desc string + in string + invalid bool + }{ + {"Valid Hostname", "testnode", false}, + {"Valid Hostname", "test-node", false}, + {"Invalid Hostname with special chars", "test#$$!node", true}, + {"Invalid Hostname with special chars in the end", "test-node%^", true}, + } + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + if got, want := InvalidDnsRe.MatchString(test.in), test.invalid; got != want { + t.Fatalf("Expected %v to return %v", test.in, want) + } + }) + + } +}