From cf7e9e40d5174423116d3c9aa62c01d473c979a0 Mon Sep 17 00:00:00 2001 From: Armon Dadgar Date: Mon, 3 Nov 2014 11:40:55 -0800 Subject: [PATCH] Fixing unit tests --- command/agent/dns.go | 76 +++++++++++++++++++++----------------------- 1 file changed, 37 insertions(+), 39 deletions(-) diff --git a/command/agent/dns.go b/command/agent/dns.go index 0e5f8f195e..63810913c2 100644 --- a/command/agent/dns.go +++ b/command/agent/dns.go @@ -2,14 +2,15 @@ package agent import ( "fmt" - "github.com/hashicorp/consul/consul/structs" - "github.com/miekg/dns" "io" "log" "math/rand" "net" "strings" "time" + + "github.com/hashicorp/consul/consul/structs" + "github.com/miekg/dns" ) const ( @@ -71,15 +72,14 @@ func NewDNSServer(agent *Agent, config *DNSConfig, logOutput io.Writer, domain s mux.HandleFunc(consulDomain, srv.handleTest) } if len(recursors) > 0 { - validatedRecursors := []string{} + validatedRecursors := make([]string, len(recursors)) - for _, recursor := range recursors { + for idx, recursor := range recursors { recursor, err := recursorAddr(recursor) if err != nil { return nil, fmt.Errorf("Invalid recursor address: %v", err) } - - validatedRecursors = append(validatedRecursors, recursor) + validatedRecursors[idx] = recursor } srv.recursors = validatedRecursors @@ -594,34 +594,35 @@ func (d *DNSServer) handleRecurse(resp dns.ResponseWriter, req *dns.Msg) { // Recursively resolve c := &dns.Client{Net: network} - for i,recursor := range d.recursors { - r, rtt, err := c.Exchange(req, recursor) - - if i < len(d.recursors) && err != nil { - continue - } else if err != nil { - // On all of failure, return a SERVFAIL message - d.logger.Printf("[ERR] dns: recurse failed: %v", err) - m := &dns.Msg{} - m.SetReply(req) - m.RecursionAvailable = true - m.SetRcode(req, dns.RcodeServerFailure) - resp.WriteMsg(m) + var r *dns.Msg + var rtt time.Duration + var err error + for _, recursor := range d.recursors { + r, rtt, err = c.Exchange(req, recursor) + if err == nil { + // Forward the response + d.logger.Printf("[DEBUG] dns: recurse RTT for %v (%v)", q, rtt) + if err := resp.WriteMsg(r); err != nil { + d.logger.Printf("[WARN] dns: failed to respond: %v", err) + } return } - d.logger.Printf("[DEBUG] dns: recurse RTT for %v (%v)", q, rtt) - - // Forward the response - if err := resp.WriteMsg(r); err != nil { - d.logger.Printf("[WARN] dns: failed to respond: %v", err) - } + d.logger.Printf("[ERR] dns: recurse failed: %v", err) } + + // If all resolvers fail, return a SERVFAIL message + d.logger.Printf("[ERR] dns: all resolvers failed for %v", q) + m := &dns.Msg{} + m.SetReply(req) + m.RecursionAvailable = true + m.SetRcode(req, dns.RcodeServerFailure) + resp.WriteMsg(m) } // resolveCNAME is used to recursively resolve CNAME records func (d *DNSServer) resolveCNAME(name string) []dns.RR { // Do nothing if we don't have a recursor - if len(d.recursors) > 0 { + if len(d.recursors) == 0 { return nil } @@ -631,20 +632,17 @@ func (d *DNSServer) resolveCNAME(name string) []dns.RR { // Make a DNS lookup request c := &dns.Client{Net: "udp"} - for i,recursor := range d.recursors { - r, rtt, err := c.Exchange(m, recursor) - - if i < len(d.recursors) && err != nil { - continue - } else if err != nil { - d.logger.Printf("[ERR] dns: cname recurse failed: %v", err) - return nil + var r *dns.Msg + var rtt time.Duration + var err error + for _, recursor := range d.recursors { + r, rtt, err = c.Exchange(m, recursor) + if err == nil { + d.logger.Printf("[DEBUG] dns: cname recurse RTT for %v (%v)", name, rtt) + return r.Answer } - d.logger.Printf("[DEBUG] dns: cname recurse RTT for %v (%v)", name, rtt) - - // Return all the answers - return r.Answer + d.logger.Printf("[ERR] dns: cname recurse failed for %v: %v", name, err) } - + d.logger.Printf("[ERR] dns: all resolvers failed for %v", name) return nil }