Fixing unit tests

pull/453/head
Armon Dadgar 2014-11-03 11:40:55 -08:00
parent cd936793ad
commit cf7e9e40d5
1 changed files with 37 additions and 39 deletions

View File

@ -2,14 +2,15 @@ package agent
import ( import (
"fmt" "fmt"
"github.com/hashicorp/consul/consul/structs"
"github.com/miekg/dns"
"io" "io"
"log" "log"
"math/rand" "math/rand"
"net" "net"
"strings" "strings"
"time" "time"
"github.com/hashicorp/consul/consul/structs"
"github.com/miekg/dns"
) )
const ( const (
@ -71,15 +72,14 @@ func NewDNSServer(agent *Agent, config *DNSConfig, logOutput io.Writer, domain s
mux.HandleFunc(consulDomain, srv.handleTest) mux.HandleFunc(consulDomain, srv.handleTest)
} }
if len(recursors) > 0 { if len(recursors) > 0 {
validatedRecursors := []string{} validatedRecursors := make([]string, len(recursors))
for _, recursor := range recursors { for idx, recursor := range recursors {
recursor, err := recursorAddr(recursor) recursor, err := recursorAddr(recursor)
if err != nil { if err != nil {
return nil, fmt.Errorf("Invalid recursor address: %v", err) return nil, fmt.Errorf("Invalid recursor address: %v", err)
} }
validatedRecursors[idx] = recursor
validatedRecursors = append(validatedRecursors, recursor)
} }
srv.recursors = validatedRecursors srv.recursors = validatedRecursors
@ -594,34 +594,35 @@ func (d *DNSServer) handleRecurse(resp dns.ResponseWriter, req *dns.Msg) {
// Recursively resolve // Recursively resolve
c := &dns.Client{Net: network} c := &dns.Client{Net: network}
for i,recursor := range d.recursors { var r *dns.Msg
r, rtt, err := c.Exchange(req, recursor) var rtt time.Duration
var err error
if i < len(d.recursors) && err != nil { for _, recursor := range d.recursors {
continue r, rtt, err = c.Exchange(req, recursor)
} else if err != nil { if err == nil {
// On all of failure, return a SERVFAIL message // Forward the response
d.logger.Printf("[DEBUG] dns: recurse RTT for %v (%v)", q, rtt)
if err := resp.WriteMsg(r); err != nil {
d.logger.Printf("[WARN] dns: failed to respond: %v", err)
}
return
}
d.logger.Printf("[ERR] dns: recurse failed: %v", err) d.logger.Printf("[ERR] dns: recurse failed: %v", err)
}
// If all resolvers fail, return a SERVFAIL message
d.logger.Printf("[ERR] dns: all resolvers failed for %v", q)
m := &dns.Msg{} m := &dns.Msg{}
m.SetReply(req) m.SetReply(req)
m.RecursionAvailable = true m.RecursionAvailable = true
m.SetRcode(req, dns.RcodeServerFailure) m.SetRcode(req, dns.RcodeServerFailure)
resp.WriteMsg(m) resp.WriteMsg(m)
return
}
d.logger.Printf("[DEBUG] dns: recurse RTT for %v (%v)", q, rtt)
// Forward the response
if err := resp.WriteMsg(r); err != nil {
d.logger.Printf("[WARN] dns: failed to respond: %v", err)
}
}
} }
// resolveCNAME is used to recursively resolve CNAME records // resolveCNAME is used to recursively resolve CNAME records
func (d *DNSServer) resolveCNAME(name string) []dns.RR { func (d *DNSServer) resolveCNAME(name string) []dns.RR {
// Do nothing if we don't have a recursor // Do nothing if we don't have a recursor
if len(d.recursors) > 0 { if len(d.recursors) == 0 {
return nil return nil
} }
@ -631,20 +632,17 @@ func (d *DNSServer) resolveCNAME(name string) []dns.RR {
// Make a DNS lookup request // Make a DNS lookup request
c := &dns.Client{Net: "udp"} c := &dns.Client{Net: "udp"}
for i,recursor := range d.recursors { var r *dns.Msg
r, rtt, err := c.Exchange(m, recursor) var rtt time.Duration
var err error
if i < len(d.recursors) && err != nil { for _, recursor := range d.recursors {
continue r, rtt, err = c.Exchange(m, recursor)
} else if err != nil { if err == nil {
d.logger.Printf("[ERR] dns: cname recurse failed: %v", err)
return nil
}
d.logger.Printf("[DEBUG] dns: cname recurse RTT for %v (%v)", name, rtt) d.logger.Printf("[DEBUG] dns: cname recurse RTT for %v (%v)", name, rtt)
// Return all the answers
return r.Answer return r.Answer
} }
d.logger.Printf("[ERR] dns: cname recurse failed for %v: %v", name, err)
}
d.logger.Printf("[ERR] dns: all resolvers failed for %v", name)
return nil return nil
} }