diff --git a/command/agent/catalog_endpoint.go b/command/agent/catalog_endpoint.go index 5aa0b90ecf..f7454e12ae 100644 --- a/command/agent/catalog_endpoint.go +++ b/command/agent/catalog_endpoint.go @@ -144,8 +144,8 @@ func (s *HTTPServer) CatalogNodeServices(resp http.ResponseWriter, req *http.Req } // Make the RPC request - var out structs.NodeServices - if err := s.agent.RPC("Catalog.NodeServices", &args, &out); err != nil { + out := new(structs.NodeServices) + if err := s.agent.RPC("Catalog.NodeServices", &args, out); err != nil { return nil, err } return out, nil diff --git a/command/agent/catalog_endpoint_test.go b/command/agent/catalog_endpoint_test.go index a68c55914d..574253575e 100644 --- a/command/agent/catalog_endpoint_test.go +++ b/command/agent/catalog_endpoint_test.go @@ -232,8 +232,8 @@ func TestCatalogNodeServices(t *testing.T) { t.Fatalf("err: %v", err) } - services := obj.(structs.NodeServices) - if len(services) != 1 { + services := obj.(*structs.NodeServices) + if len(services.Services) != 1 { t.Fatalf("bad: %v", obj) } } diff --git a/command/agent/dns.go b/command/agent/dns.go index 44eac097a5..be0f74feee 100644 --- a/command/agent/dns.go +++ b/command/agent/dns.go @@ -2,23 +2,35 @@ package agent import ( "fmt" + "github.com/hashicorp/consul/consul/structs" "github.com/miekg/dns" "io" "log" + "net" + "strings" "time" ) +const ( + testQuery = "_test.consul." + consulDomain = "consul." +) + // DNSServer is used to wrap an Agent and expose various // service discovery endpoints using a DNS interface. type DNSServer struct { agent *Agent dnsHandler *dns.ServeMux dnsServer *dns.Server + domain string logger *log.Logger } // NewDNSServer starts a new DNS server to provide an agent interface func NewDNSServer(agent *Agent, logOutput io.Writer, domain, bind string) (*DNSServer, error) { + // Make sure domain is FQDN + domain = dns.Fqdn(domain) + // Construct the DNS components mux := dns.NewServeMux() @@ -35,11 +47,15 @@ func NewDNSServer(agent *Agent, logOutput io.Writer, domain, bind string) (*DNSS agent: agent, dnsHandler: mux, dnsServer: server, + domain: domain, logger: log.New(logOutput, "", log.LstdFlags), } - // Register mux handlers - mux.HandleFunc("consul.", srv.handleConsul) + // Register mux handlers, always handle "consul." + mux.HandleFunc(domain, srv.handleQuery) + if domain != consulDomain { + mux.HandleFunc(consulDomain, srv.handleTest) + } // Async start the DNS Server, handle a potential error errCh := make(chan error, 1) @@ -57,7 +73,7 @@ func NewDNSServer(agent *Agent, logOutput io.Writer, domain, bind string) (*DNSS time.Sleep(50 * time.Millisecond) m := new(dns.Msg) - m.SetQuestion("_test.consul.", dns.TypeANY) + m.SetQuestion(testQuery, dns.TypeANY) c := new(dns.Client) in, _, err := c.Exchange(m, bind) @@ -85,12 +101,41 @@ func NewDNSServer(agent *Agent, logOutput io.Writer, domain, bind string) (*DNSS return srv, nil } -// handleConsul is used to handle DNS queries in the ".consul." domain -func (d *DNSServer) handleConsul(resp dns.ResponseWriter, req *dns.Msg) { +// handleQUery is used to handle DNS queries in the configured domain +func (d *DNSServer) handleQuery(resp dns.ResponseWriter, req *dns.Msg) { + q := req.Question[0] + defer func(s time.Time) { + d.logger.Printf("[DEBUG] dns: request for %v (%v)", q, time.Now().Sub(s)) + }(time.Now()) + + // Check if this is potentially a test query + if q.Name == testQuery { + d.handleTest(resp, req) + return + } + + // Setup the message response + m := new(dns.Msg) + m.SetReply(req) + m.Authoritative = true + d.addSOA(d.domain, m) + defer resp.WriteMsg(m) + + // Dispatch the correct handler + d.dispatch(req, m) +} + +// handleTest is used to handle DNS queries in the ".consul." domain +func (d *DNSServer) handleTest(resp dns.ResponseWriter, req *dns.Msg) { q := req.Question[0] - d.logger.Printf("[DEBUG] dns: request for %v", q) + defer func(s time.Time) { + d.logger.Printf("[DEBUG] dns: request for %v (%v)", q, time.Now().Sub(s)) + }(time.Now()) - if q.Qtype != dns.TypeANY && q.Qtype != dns.TypeTXT { + if !(q.Qtype == dns.TypeANY || q.Qtype == dns.TypeTXT) { + return + } + if q.Name != testQuery { return } @@ -101,7 +146,7 @@ func (d *DNSServer) handleConsul(resp dns.ResponseWriter, req *dns.Msg) { header := dns.RR_Header{Name: q.Name, Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: 0} txt := &dns.TXT{header, []string{"ok"}} m.Answer = append(m.Answer, txt) - d.addSOA("consul.", m) + d.addSOA(consulDomain, m) resp.WriteMsg(m) } @@ -124,3 +169,103 @@ func (d *DNSServer) addSOA(domain string, msg *dns.Msg) { } msg.Ns = append(msg.Ns, soa) } + +// dispatch is used to parse a request and invoke the correct handler +func (d *DNSServer) dispatch(req, resp *dns.Msg) { + // By default the query is in the default datacenter + datacenter := d.agent.config.Datacenter + + // Get the QName without the domain suffix + qName := dns.Fqdn(req.Question[0].Name) + qName = strings.TrimSuffix(qName, d.domain) + + // Split into the label parts + labels := dns.SplitDomainName(qName) + + // The last label is either "node", "service" or a datacenter name +PARSE: + if len(labels) == 0 { + goto INVALID + } + switch labels[len(labels)-1] { + case "service": + // Handle lookup with and without tag + switch len(labels) { + case 2: + d.serviceLookup(datacenter, labels[0], "", req, resp) + case 3: + d.serviceLookup(datacenter, labels[1], labels[0], req, resp) + default: + goto INVALID + } + + case "node": + if len(labels) != 2 { + goto INVALID + } + d.nodeLookup(datacenter, labels[0], req, resp) + + default: + // Store the DC, and re-parse + datacenter = labels[len(labels)-1] + labels = labels[:len(labels)-1] + goto PARSE + } + return +INVALID: + d.logger.Printf("[WARN] dns: QName invalid: %s", qName) + resp.SetRcode(req, dns.RcodeNameError) +} + +// nodeLookup is used to handle a node query +func (d *DNSServer) nodeLookup(datacenter, node string, req, resp *dns.Msg) { + // Only handle ANY and A type requests + qType := req.Question[0].Qtype + if qType != dns.TypeANY && qType != dns.TypeA { + return + } + + // Make an RPC request + args := structs.NodeServicesRequest{ + Datacenter: datacenter, + Node: node, + } + var out structs.NodeServices + if err := d.agent.RPC("Catalog.NodeServices", &args, &out); err != nil { + d.logger.Printf("[ERR] dns: rpc error: %v", err) + resp.SetRcode(req, dns.RcodeServerFailure) + return + } + + // If we have no address, return not found! + if out.Address == "" { + resp.SetRcode(req, dns.RcodeNameError) + return + } + + // Parse the IP + ip := net.ParseIP(out.Address) + if ip == nil { + d.logger.Printf("[ERR] dns: failed to parse IP %v for %v", out.Address, node) + resp.SetRcode(req, dns.RcodeServerFailure) + return + } + + // Format A record + aRec := &dns.A{ + Hdr: dns.RR_Header{ + Name: req.Question[0].Name, + Rrtype: dns.TypeA, + Class: dns.ClassINET, + Ttl: 0, + }, + A: ip, + } + + // Add the response + resp.Answer = append(resp.Answer, aRec) +} + +// serviceLookup is used to handle a service query +func (d *DNSServer) serviceLookup(datacenter, service, tag string, req, resp *dns.Msg) { +} diff --git a/command/agent/dns_test.go b/command/agent/dns_test.go index c0892a4728..8cac86da2f 100644 --- a/command/agent/dns_test.go +++ b/command/agent/dns_test.go @@ -1,9 +1,11 @@ package agent import ( + "github.com/hashicorp/consul/consul/structs" "github.com/miekg/dns" "os" "testing" + "time" ) func makeDNSServer(t *testing.T) (string, *DNSServer) { @@ -42,3 +44,66 @@ func TestDNS_IsAlive(t *testing.T) { t.Fatalf("Bad: %#v", in.Answer[0]) } } + +func TestDNS_NodeLookup(t *testing.T) { + dir, srv := makeDNSServer(t) + defer os.RemoveAll(dir) + defer srv.agent.Shutdown() + + // Wait for leader + time.Sleep(100 * time.Millisecond) + + // Register node + args := &structs.RegisterRequest{ + Datacenter: "dc1", + Node: "foo", + Address: "127.0.0.1", + } + var out struct{} + if err := srv.agent.RPC("Catalog.Register", args, &out); err != nil { + t.Fatalf("err: %v", err) + } + + m := new(dns.Msg) + m.SetQuestion("foo.node.consul.", dns.TypeANY) + + c := new(dns.Client) + in, _, err := c.Exchange(m, srv.agent.config.DNSAddr) + if err != nil { + t.Fatalf("err: %v", err) + } + + if len(in.Answer) != 1 { + t.Fatalf("Bad: %#v", in) + } + + aRec, ok := in.Answer[0].(*dns.A) + if !ok { + t.Fatalf("Bad: %#v", in.Answer[0]) + } + if aRec.A.String() != "127.0.0.1" { + t.Fatalf("Bad: %#v", in.Answer[0]) + } + + // Re-do the query, but specify the DC + m = new(dns.Msg) + m.SetQuestion("foo.node.dc1.consul.", dns.TypeANY) + + c = new(dns.Client) + in, _, err = c.Exchange(m, srv.agent.config.DNSAddr) + if err != nil { + t.Fatalf("err: %v", err) + } + + if len(in.Answer) != 1 { + t.Fatalf("Bad: %#v", in) + } + + aRec, ok = in.Answer[0].(*dns.A) + if !ok { + t.Fatalf("Bad: %#v", in.Answer[0]) + } + if aRec.A.String() != "127.0.0.1" { + t.Fatalf("Bad: %#v", in.Answer[0]) + } +} diff --git a/command/agent/http.go b/command/agent/http.go index b380494dad..50f8df7aae 100644 --- a/command/agent/http.go +++ b/command/agent/http.go @@ -72,14 +72,14 @@ func (s *HTTPServer) wrap(handler func(resp http.ResponseWriter, req *http.Reque // Invoke the handler start := time.Now() defer func() { - s.logger.Printf("[DEBUG] HTTP Request %v (%v)", req.URL, time.Now().Sub(start)) + s.logger.Printf("[DEBUG] http: Request %v (%v)", req.URL, time.Now().Sub(start)) }() obj, err := handler(resp, req) // Check for an error HAS_ERR: if err != nil { - s.logger.Printf("[ERR] Request %v, error: %v", req.URL, err) + s.logger.Printf("[ERR] http: Request %v, error: %v", req.URL, err) resp.WriteHeader(500) resp.Write([]byte(err.Error())) return