|
|
|
@ -28,12 +28,12 @@ type DNSServer struct {
|
|
|
|
|
dnsServer *dns.Server |
|
|
|
|
dnsServerTCP *dns.Server |
|
|
|
|
domain string |
|
|
|
|
recursor string |
|
|
|
|
recursors []string |
|
|
|
|
logger *log.Logger |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// NewDNSServer starts a new DNS server to provide an agent interface
|
|
|
|
|
func NewDNSServer(agent *Agent, config *DNSConfig, logOutput io.Writer, domain, bind, recursor string) (*DNSServer, error) { |
|
|
|
|
func NewDNSServer(agent *Agent, config *DNSConfig, logOutput io.Writer, domain string, bind string, recursors []string) (*DNSServer, error) { |
|
|
|
|
// Make sure domain is FQDN
|
|
|
|
|
domain = dns.Fqdn(domain) |
|
|
|
|
|
|
|
|
@ -61,7 +61,7 @@ func NewDNSServer(agent *Agent, config *DNSConfig, logOutput io.Writer, domain,
|
|
|
|
|
dnsServer: server, |
|
|
|
|
dnsServerTCP: serverTCP, |
|
|
|
|
domain: domain, |
|
|
|
|
recursor: recursor, |
|
|
|
|
recursors: recursors, |
|
|
|
|
logger: log.New(logOutput, "", log.LstdFlags), |
|
|
|
|
} |
|
|
|
|
|
|
|
|
@ -70,12 +70,19 @@ func NewDNSServer(agent *Agent, config *DNSConfig, logOutput io.Writer, domain,
|
|
|
|
|
if domain != consulDomain { |
|
|
|
|
mux.HandleFunc(consulDomain, srv.handleTest) |
|
|
|
|
} |
|
|
|
|
if recursor != "" { |
|
|
|
|
recursor, err := recursorAddr(recursor) |
|
|
|
|
if err != nil { |
|
|
|
|
return nil, fmt.Errorf("Invalid recursor address: %v", err) |
|
|
|
|
if len(recursors) > 0 { |
|
|
|
|
validatedRecursors := []string{} |
|
|
|
|
|
|
|
|
|
for _, recursor := range recursors { |
|
|
|
|
recursor, err := recursorAddr(recursor) |
|
|
|
|
if err != nil { |
|
|
|
|
return nil, fmt.Errorf("Invalid recursor address: %v", err) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
validatedRecursors = append(validatedRecursors, recursor) |
|
|
|
|
} |
|
|
|
|
srv.recursor = recursor |
|
|
|
|
|
|
|
|
|
srv.recursors = validatedRecursors |
|
|
|
|
mux.HandleFunc(".", srv.handleRecurse) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
@ -178,7 +185,7 @@ func (d *DNSServer) handleQuery(resp dns.ResponseWriter, req *dns.Msg) {
|
|
|
|
|
m := new(dns.Msg) |
|
|
|
|
m.SetReply(req) |
|
|
|
|
m.Authoritative = true |
|
|
|
|
m.RecursionAvailable = (d.recursor != "") |
|
|
|
|
m.RecursionAvailable = (len(d.recursors) > 0) |
|
|
|
|
|
|
|
|
|
// Only add the SOA if requested
|
|
|
|
|
if req.Question[0].Qtype == dns.TypeSOA { |
|
|
|
@ -587,30 +594,34 @@ func (d *DNSServer) handleRecurse(resp dns.ResponseWriter, req *dns.Msg) {
|
|
|
|
|
|
|
|
|
|
// Recursively resolve
|
|
|
|
|
c := &dns.Client{Net: network} |
|
|
|
|
r, rtt, err := c.Exchange(req, d.recursor) |
|
|
|
|
for i,recursor := range d.recursors { |
|
|
|
|
r, rtt, err := c.Exchange(req, recursor) |
|
|
|
|
|
|
|
|
|
// On failure, return a SERVFAIL message
|
|
|
|
|
if err != nil { |
|
|
|
|
d.logger.Printf("[ERR] dns: recurse failed: %v", err) |
|
|
|
|
m := &dns.Msg{} |
|
|
|
|
m.SetReply(req) |
|
|
|
|
m.RecursionAvailable = true |
|
|
|
|
m.SetRcode(req, dns.RcodeServerFailure) |
|
|
|
|
resp.WriteMsg(m) |
|
|
|
|
return |
|
|
|
|
} |
|
|
|
|
d.logger.Printf("[DEBUG] dns: recurse RTT for %v (%v)", q, rtt) |
|
|
|
|
if i < len(d.recursors) && err != nil { |
|
|
|
|
continue |
|
|
|
|
} else if err != nil { |
|
|
|
|
// On all of failure, return a SERVFAIL message
|
|
|
|
|
d.logger.Printf("[ERR] dns: recurse failed: %v", err) |
|
|
|
|
m := &dns.Msg{} |
|
|
|
|
m.SetReply(req) |
|
|
|
|
m.RecursionAvailable = true |
|
|
|
|
m.SetRcode(req, dns.RcodeServerFailure) |
|
|
|
|
resp.WriteMsg(m) |
|
|
|
|
return |
|
|
|
|
} |
|
|
|
|
d.logger.Printf("[DEBUG] dns: recurse RTT for %v (%v)", q, rtt) |
|
|
|
|
|
|
|
|
|
// Forward the response
|
|
|
|
|
if err := resp.WriteMsg(r); err != nil { |
|
|
|
|
d.logger.Printf("[WARN] dns: failed to respond: %v", err) |
|
|
|
|
// Forward the response
|
|
|
|
|
if err := resp.WriteMsg(r); err != nil { |
|
|
|
|
d.logger.Printf("[WARN] dns: failed to respond: %v", err) |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// resolveCNAME is used to recursively resolve CNAME records
|
|
|
|
|
func (d *DNSServer) resolveCNAME(name string) []dns.RR { |
|
|
|
|
// Do nothing if we don't have a recursor
|
|
|
|
|
if d.recursor == "" { |
|
|
|
|
if len(d.recursors) > 0 { |
|
|
|
|
return nil |
|
|
|
|
} |
|
|
|
|
|
|
|
|
@ -620,13 +631,20 @@ func (d *DNSServer) resolveCNAME(name string) []dns.RR {
|
|
|
|
|
|
|
|
|
|
// Make a DNS lookup request
|
|
|
|
|
c := &dns.Client{Net: "udp"} |
|
|
|
|
r, rtt, err := c.Exchange(m, d.recursor) |
|
|
|
|
if err != nil { |
|
|
|
|
d.logger.Printf("[ERR] dns: cname recurse failed: %v", err) |
|
|
|
|
return nil |
|
|
|
|
for i,recursor := range d.recursors { |
|
|
|
|
r, rtt, err := c.Exchange(m, recursor) |
|
|
|
|
|
|
|
|
|
if i < len(d.recursors) && err != nil { |
|
|
|
|
continue |
|
|
|
|
} else if err != nil { |
|
|
|
|
d.logger.Printf("[ERR] dns: cname recurse failed: %v", err) |
|
|
|
|
return nil |
|
|
|
|
} |
|
|
|
|
d.logger.Printf("[DEBUG] dns: cname recurse RTT for %v (%v)", name, rtt) |
|
|
|
|
|
|
|
|
|
// Return all the answers
|
|
|
|
|
return r.Answer |
|
|
|
|
} |
|
|
|
|
d.logger.Printf("[DEBUG] dns: cname recurse RTT for %v (%v)", name, rtt) |
|
|
|
|
|
|
|
|
|
// Return all the answers
|
|
|
|
|
return r.Answer |
|
|
|
|
return nil |
|
|
|
|
} |
|
|
|
|