Add support for DNS config hot-reload (#4875)

The DNS config parameters `recursors` and `dns_config.*` are now hot
reloaded on SIGHUP or `consul reload` and do not need an agent restart
to be modified.
Config is stored in an atomic.Value and loaded at the beginning of each
request. Reloading only affects requests that start _after_ the
reload. Ongoing requests are not affected. To match the current
behavior the recursor handler is loaded and unloaded as needed on config
reload.
pull/5723/head
Aestek 2019-04-24 20:11:54 +02:00 committed by Matt Keeler
parent b186c3020c
commit f669bb7b0f
4 changed files with 402 additions and 169 deletions

View File

@ -3579,6 +3579,12 @@ func (a *Agent) ReloadConfig(newCfg *config.RuntimeConfig) error {
a.loadLimits(newCfg) a.loadLimits(newCfg)
for _, s := range a.dnsServers {
if err := s.ReloadConfig(newCfg); err != nil {
return fmt.Errorf("Failed reloading dns config : %v", err)
}
}
// create the config for the rpc server/client // create the config for the rpc server/client
consulCfg, err := a.consulConfig() consulCfg, err := a.consulConfig()
if err != nil { if err != nil {

View File

@ -25,7 +25,7 @@ import (
const ( const (
// UDP can fit ~25 A records in a 512B response, and ~14 AAAA // UDP can fit ~25 A records in a 512B response, and ~14 AAAA
// records. Limit further to prevent unintentional configuration // records. Limit further to prevent unintentional configuration
// abuse that would have a negative effect on application response // abuse that would have a negative effect on application response
// times. // times.
maxUDPAnswerLimit = 8 maxUDPAnswerLimit = 8
@ -46,7 +46,7 @@ type dnsSOAConfig struct {
Refresh uint32 // 3600 by default Refresh uint32 // 3600 by default
Retry uint32 // 600 Retry uint32 // 600
Expire uint32 // 86400 Expire uint32 // 86400
Minttl uint32 // 0, Minttl uint32 // 0
} }
type dnsConfig struct { type dnsConfig struct {
@ -60,128 +60,134 @@ type dnsConfig struct {
NodeTTL time.Duration NodeTTL time.Duration
OnlyPassing bool OnlyPassing bool
RecursorTimeout time.Duration RecursorTimeout time.Duration
Recursors []string
SegmentName string SegmentName string
ServiceTTL map[string]time.Duration
UDPAnswerLimit int UDPAnswerLimit int
ARecordLimit int ARecordLimit int
NodeMetaTXT bool NodeMetaTXT bool
dnsSOAConfig dnsSOAConfig SOAConfig dnsSOAConfig
// TTLRadix sets service TTLs by prefix, eg: "database-*"
TTLRadix *radix.Tree
// TTLStict sets TTLs to service by full name match. It Has higher priority than TTLRadix
TTLStrict map[string]time.Duration
DisableCompression bool
} }
// DNSServer is used to wrap an Agent and expose various // DNSServer is used to wrap an Agent and expose various
// service discovery endpoints using a DNS interface. // service discovery endpoints using a DNS interface.
type DNSServer struct { type DNSServer struct {
*dns.Server *dns.Server
agent *Agent agent *Agent
config *dnsConfig mux *dns.ServeMux
domain string domain string
recursors []string logger *log.Logger
logger *log.Logger
// Those are handling prefix lookups
ttlRadix *radix.Tree
ttlStrict map[string]time.Duration
// disableCompression is the config.DisableCompression flag that can // config stores the config as an atomic value (for hot-reloading). It is always of type *dnsConfig
// be safely changed at runtime. It always contains a bool and is config atomic.Value
// initialized with the value from config.DisableCompression.
disableCompression atomic.Value // recursorEnabled stores whever the recursor handler is enabled as an atomic flag.
// the recursor handler is only enabled if recursors are configured. This flag is used during config hot-reloading
recursorEnabled uint32
} }
func NewDNSServer(a *Agent) (*DNSServer, error) { func NewDNSServer(a *Agent) (*DNSServer, error) {
var recursors []string
for _, r := range a.config.DNSRecursors {
ra, err := recursorAddr(r)
if err != nil {
return nil, fmt.Errorf("Invalid recursor address: %v", err)
}
recursors = append(recursors, ra)
}
// Make sure domain is FQDN, make it case insensitive for ServeMux // Make sure domain is FQDN, make it case insensitive for ServeMux
domain := dns.Fqdn(strings.ToLower(a.config.DNSDomain)) domain := dns.Fqdn(strings.ToLower(a.config.DNSDomain))
dnscfg := GetDNSConfig(a.config)
srv := &DNSServer{ srv := &DNSServer{
agent: a, agent: a,
config: dnscfg, domain: domain,
domain: domain, logger: a.logger,
logger: a.logger,
recursors: recursors,
ttlRadix: radix.New(),
ttlStrict: make(map[string]time.Duration),
} }
if dnscfg.ServiceTTL != nil { cfg, err := GetDNSConfig(a.config)
for key, ttl := range dnscfg.ServiceTTL { if err != nil {
// All suffix with '*' are put in radix return nil, err
// This include '*' that will match anything
if strings.HasSuffix(key, "*") {
srv.ttlRadix.Insert(key[:len(key)-1], ttl)
} else {
srv.ttlStrict[key] = ttl
}
}
} }
srv.config.Store(cfg)
srv.disableCompression.Store(a.config.DNSDisableCompression)
return srv, nil return srv, nil
} }
// GetDNSConfig takes global config and creates the config used by DNS server // GetDNSConfig takes global config and creates the config used by DNS server
func GetDNSConfig(conf *config.RuntimeConfig) *dnsConfig { func GetDNSConfig(conf *config.RuntimeConfig) (*dnsConfig, error) {
return &dnsConfig{ cfg := &dnsConfig{
AllowStale: conf.DNSAllowStale, AllowStale: conf.DNSAllowStale,
ARecordLimit: conf.DNSARecordLimit, ARecordLimit: conf.DNSARecordLimit,
Datacenter: conf.Datacenter, Datacenter: conf.Datacenter,
EnableTruncate: conf.DNSEnableTruncate, EnableTruncate: conf.DNSEnableTruncate,
MaxStale: conf.DNSMaxStale, MaxStale: conf.DNSMaxStale,
NodeName: conf.NodeName, NodeName: conf.NodeName,
NodeTTL: conf.DNSNodeTTL, NodeTTL: conf.DNSNodeTTL,
OnlyPassing: conf.DNSOnlyPassing, OnlyPassing: conf.DNSOnlyPassing,
RecursorTimeout: conf.DNSRecursorTimeout, RecursorTimeout: conf.DNSRecursorTimeout,
SegmentName: conf.SegmentName, SegmentName: conf.SegmentName,
ServiceTTL: conf.DNSServiceTTL, UDPAnswerLimit: conf.DNSUDPAnswerLimit,
UDPAnswerLimit: conf.DNSUDPAnswerLimit, NodeMetaTXT: conf.DNSNodeMetaTXT,
NodeMetaTXT: conf.DNSNodeMetaTXT, DisableCompression: conf.DNSDisableCompression,
UseCache: conf.DNSUseCache, UseCache: conf.DNSUseCache,
CacheMaxAge: conf.DNSCacheMaxAge, CacheMaxAge: conf.DNSCacheMaxAge,
dnsSOAConfig: dnsSOAConfig{ SOAConfig: dnsSOAConfig{
Expire: conf.DNSSOA.Expire, Expire: conf.DNSSOA.Expire,
Minttl: conf.DNSSOA.Minttl, Minttl: conf.DNSSOA.Minttl,
Refresh: conf.DNSSOA.Refresh, Refresh: conf.DNSSOA.Refresh,
Retry: conf.DNSSOA.Retry, Retry: conf.DNSSOA.Retry,
}, },
} }
if conf.DNSServiceTTL != nil {
cfg.TTLRadix = radix.New()
cfg.TTLStrict = make(map[string]time.Duration)
for key, ttl := range conf.DNSServiceTTL {
// All suffix with '*' are put in radix
// This include '*' that will match anything
if strings.HasSuffix(key, "*") {
cfg.TTLRadix.Insert(key[:len(key)-1], ttl)
} else {
cfg.TTLStrict[key] = ttl
}
}
}
for _, r := range conf.DNSRecursors {
ra, err := recursorAddr(r)
if err != nil {
return nil, fmt.Errorf("Invalid recursor address: %v", err)
}
cfg.Recursors = append(cfg.Recursors, ra)
}
return cfg, nil
} }
// GetTTLForService Find the TTL for a given service. // GetTTLForService Find the TTL for a given service.
// return ttl, true if found, 0, false otherwise // return ttl, true if found, 0, false otherwise
func (d *DNSServer) GetTTLForService(service string) (time.Duration, bool) { func (cfg *dnsConfig) GetTTLForService(service string) (time.Duration, bool) {
if d.config.ServiceTTL != nil { if cfg.TTLStrict != nil {
ttl, ok := d.ttlStrict[service] ttl, ok := cfg.TTLStrict[service]
if ok { if ok {
return ttl, true return ttl, true
} }
_, ttlRaw, ok := d.ttlRadix.LongestPrefix(service) }
if cfg.TTLRadix != nil {
_, ttlRaw, ok := cfg.TTLRadix.LongestPrefix(service)
if ok { if ok {
return ttlRaw.(time.Duration), true return ttlRaw.(time.Duration), true
} }
} }
return time.Duration(0), false return 0, false
} }
func (d *DNSServer) ListenAndServe(network, addr string, notif func()) error { func (d *DNSServer) ListenAndServe(network, addr string, notif func()) error {
mux := dns.NewServeMux() cfg := d.config.Load().(*dnsConfig)
mux.HandleFunc("arpa.", d.handlePtr)
mux.HandleFunc(d.domain, d.handleQuery) d.mux = dns.NewServeMux()
if len(d.recursors) > 0 { d.mux.HandleFunc("arpa.", d.handlePtr)
mux.HandleFunc(".", d.handleRecurse) d.mux.HandleFunc(d.domain, d.handleQuery)
} d.toggleRecursorHandlerFromConfig(cfg)
d.Server = &dns.Server{ d.Server = &dns.Server{
Addr: addr, Addr: addr,
Net: network, Net: network,
Handler: mux, Handler: d.mux,
NotifyStartedFunc: notif, NotifyStartedFunc: notif,
} }
if network == "udp" { if network == "udp" {
@ -190,6 +196,34 @@ func (d *DNSServer) ListenAndServe(network, addr string, notif func()) error {
return d.Server.ListenAndServe() return d.Server.ListenAndServe()
} }
// toggleRecursorHandlerFromConfig enables or disables the recursor handler based on config idempotently
func (d *DNSServer) toggleRecursorHandlerFromConfig(cfg *dnsConfig) {
shouldEnable := len(cfg.Recursors) > 0
if shouldEnable && atomic.CompareAndSwapUint32(&d.recursorEnabled, 0, 1) {
d.mux.HandleFunc(".", d.handleRecurse)
d.logger.Println("[DEBUG] dns: recursor enabled")
return
}
if !shouldEnable && atomic.CompareAndSwapUint32(&d.recursorEnabled, 1, 0) {
d.mux.HandleRemove(".")
d.logger.Println("[DEBUG] dns: recursor disabled")
return
}
}
// ReloadConfig hot-reloads the server config with new parameters under config.RuntimeConfig.DNS*
func (d *DNSServer) ReloadConfig(newCfg *config.RuntimeConfig) error {
cfg, err := GetDNSConfig(newCfg)
if err != nil {
return err
}
d.config.Store(cfg)
d.toggleRecursorHandlerFromConfig(cfg)
return nil
}
// setEDNS is used to set the responses EDNS size headers and // setEDNS is used to set the responses EDNS size headers and
// possibly the ECS headers as well if they were present in the // possibly the ECS headers as well if they were present in the
// original request // original request
@ -258,16 +292,18 @@ func (d *DNSServer) handlePtr(resp dns.ResponseWriter, req *dns.Msg) {
resp.RemoteAddr().Network()) resp.RemoteAddr().Network())
}(time.Now()) }(time.Now())
cfg := d.config.Load().(*dnsConfig)
// Setup the message response // Setup the message response
m := new(dns.Msg) m := new(dns.Msg)
m.SetReply(req) m.SetReply(req)
m.Compress = !d.disableCompression.Load().(bool) m.Compress = !cfg.DisableCompression
m.Authoritative = true m.Authoritative = true
m.RecursionAvailable = (len(d.recursors) > 0) m.RecursionAvailable = (len(cfg.Recursors) > 0)
// Only add the SOA if requested // Only add the SOA if requested
if req.Question[0].Qtype == dns.TypeSOA { if req.Question[0].Qtype == dns.TypeSOA {
d.addSOA(m) d.addSOA(cfg, m)
} }
datacenter := d.agent.config.Datacenter datacenter := d.agent.config.Datacenter
@ -279,7 +315,7 @@ func (d *DNSServer) handlePtr(resp dns.ResponseWriter, req *dns.Msg) {
Datacenter: datacenter, Datacenter: datacenter,
QueryOptions: structs.QueryOptions{ QueryOptions: structs.QueryOptions{
Token: d.agent.tokens.UserToken(), Token: d.agent.tokens.UserToken(),
AllowStale: d.config.AllowStale, AllowStale: cfg.AllowStale,
}, },
} }
var out structs.IndexedNodes var out structs.IndexedNodes
@ -308,7 +344,7 @@ func (d *DNSServer) handlePtr(resp dns.ResponseWriter, req *dns.Msg) {
Datacenter: datacenter, Datacenter: datacenter,
QueryOptions: structs.QueryOptions{ QueryOptions: structs.QueryOptions{
Token: d.agent.tokens.UserToken(), Token: d.agent.tokens.UserToken(),
AllowStale: d.config.AllowStale, AllowStale: cfg.AllowStale,
}, },
ServiceAddress: serviceAddress, ServiceAddress: serviceAddress,
} }
@ -360,25 +396,27 @@ func (d *DNSServer) handleQuery(resp dns.ResponseWriter, req *dns.Msg) {
network = "tcp" network = "tcp"
} }
cfg := d.config.Load().(*dnsConfig)
// Setup the message response // Setup the message response
m := new(dns.Msg) m := new(dns.Msg)
m.SetReply(req) m.SetReply(req)
m.Compress = !d.disableCompression.Load().(bool) m.Compress = !cfg.DisableCompression
m.Authoritative = true m.Authoritative = true
m.RecursionAvailable = (len(d.recursors) > 0) m.RecursionAvailable = (len(cfg.Recursors) > 0)
ecsGlobal := true ecsGlobal := true
switch req.Question[0].Qtype { switch req.Question[0].Qtype {
case dns.TypeSOA: case dns.TypeSOA:
ns, glue := d.nameservers(req.IsEdns0() != nil, maxRecursionLevelDefault) ns, glue := d.nameservers(cfg, req.IsEdns0() != nil, maxRecursionLevelDefault)
m.Answer = append(m.Answer, d.soa()) m.Answer = append(m.Answer, d.soa(cfg))
m.Ns = append(m.Ns, ns...) m.Ns = append(m.Ns, ns...)
m.Extra = append(m.Extra, glue...) m.Extra = append(m.Extra, glue...)
m.SetRcode(req, dns.RcodeSuccess) m.SetRcode(req, dns.RcodeSuccess)
case dns.TypeNS: case dns.TypeNS:
ns, glue := d.nameservers(req.IsEdns0() != nil, maxRecursionLevelDefault) ns, glue := d.nameservers(cfg, req.IsEdns0() != nil, maxRecursionLevelDefault)
m.Answer = ns m.Answer = ns
m.Extra = glue m.Extra = glue
m.SetRcode(req, dns.RcodeSuccess) m.SetRcode(req, dns.RcodeSuccess)
@ -398,34 +436,34 @@ func (d *DNSServer) handleQuery(resp dns.ResponseWriter, req *dns.Msg) {
} }
} }
func (d *DNSServer) soa() *dns.SOA { func (d *DNSServer) soa(cfg *dnsConfig) *dns.SOA {
return &dns.SOA{ return &dns.SOA{
Hdr: dns.RR_Header{ Hdr: dns.RR_Header{
Name: d.domain, Name: d.domain,
Rrtype: dns.TypeSOA, Rrtype: dns.TypeSOA,
Class: dns.ClassINET, Class: dns.ClassINET,
// Has to be consistent with MinTTL to avoid invalidation // Has to be consistent with MinTTL to avoid invalidation
Ttl: d.config.dnsSOAConfig.Minttl, Ttl: cfg.SOAConfig.Minttl,
}, },
Ns: "ns." + d.domain, Ns: "ns." + d.domain,
Serial: uint32(time.Now().Unix()), Serial: uint32(time.Now().Unix()),
Mbox: "hostmaster." + d.domain, Mbox: "hostmaster." + d.domain,
Refresh: d.config.dnsSOAConfig.Refresh, Refresh: cfg.SOAConfig.Refresh,
Retry: d.config.dnsSOAConfig.Retry, Retry: cfg.SOAConfig.Retry,
Expire: d.config.dnsSOAConfig.Expire, Expire: cfg.SOAConfig.Expire,
Minttl: d.config.dnsSOAConfig.Minttl, Minttl: cfg.SOAConfig.Minttl,
} }
} }
// addSOA is used to add an SOA record to a message for the given domain // addSOA is used to add an SOA record to a message for the given domain
func (d *DNSServer) addSOA(msg *dns.Msg) { func (d *DNSServer) addSOA(cfg *dnsConfig, msg *dns.Msg) {
msg.Ns = append(msg.Ns, d.soa()) msg.Ns = append(msg.Ns, d.soa(cfg))
} }
// nameservers returns the names and ip addresses of up to three random servers // nameservers returns the names and ip addresses of up to three random servers
// in the current cluster which serve as authoritative name servers for zone. // in the current cluster which serve as authoritative name servers for zone.
func (d *DNSServer) nameservers(edns bool, maxRecursionLevel int) (ns []dns.RR, extra []dns.RR) { func (d *DNSServer) nameservers(cfg *dnsConfig, edns bool, maxRecursionLevel int) (ns []dns.RR, extra []dns.RR) {
out, err := d.lookupServiceNodes(d.agent.config.Datacenter, structs.ConsulServiceName, "", false, maxRecursionLevel) out, err := d.lookupServiceNodes(cfg, d.agent.config.Datacenter, structs.ConsulServiceName, "", false, maxRecursionLevel)
if err != nil { if err != nil {
d.logger.Printf("[WARN] dns: Unable to get list of servers: %s", err) d.logger.Printf("[WARN] dns: Unable to get list of servers: %s", err)
return nil, nil return nil, nil
@ -456,15 +494,15 @@ func (d *DNSServer) nameservers(edns bool, maxRecursionLevel int) (ns []dns.RR,
Name: d.domain, Name: d.domain,
Rrtype: dns.TypeNS, Rrtype: dns.TypeNS,
Class: dns.ClassINET, Class: dns.ClassINET,
Ttl: uint32(d.config.NodeTTL / time.Second), Ttl: uint32(cfg.NodeTTL / time.Second),
}, },
Ns: fqdn, Ns: fqdn,
} }
ns = append(ns, nsrr) ns = append(ns, nsrr)
glue, meta := d.formatNodeRecord(nil, addr, fqdn, dns.TypeANY, d.config.NodeTTL, edns, maxRecursionLevel, d.config.NodeMetaTXT) glue, meta := d.formatNodeRecord(cfg, nil, addr, fqdn, dns.TypeANY, cfg.NodeTTL, edns, maxRecursionLevel, cfg.NodeMetaTXT)
extra = append(extra, glue...) extra = append(extra, glue...)
if meta != nil && d.config.NodeMetaTXT { if meta != nil && cfg.NodeMetaTXT {
extra = append(extra, meta...) extra = append(extra, meta...)
} }
@ -499,6 +537,8 @@ func (d *DNSServer) doDispatch(network string, remoteAddr net.Addr, req, resp *d
// Provide a flag for remembering whether the datacenter name was parsed already. // Provide a flag for remembering whether the datacenter name was parsed already.
var dcParsed bool var dcParsed bool
cfg := d.config.Load().(*dnsConfig)
// The last label is either "node", "service", "query", "_<protocol>", or a datacenter name // The last label is either "node", "service", "query", "_<protocol>", or a datacenter name
PARSE: PARSE:
n := len(labels) n := len(labels)
@ -531,7 +571,7 @@ PARSE:
} }
// _name._tag.service.consul // _name._tag.service.consul
d.serviceLookup(network, datacenter, labels[n-3][1:], tag, false, req, resp, maxRecursionLevel) d.serviceLookup(cfg, network, datacenter, labels[n-3][1:], tag, false, req, resp, maxRecursionLevel)
// Consul 0.3 and prior format for SRV queries // Consul 0.3 and prior format for SRV queries
} else { } else {
@ -543,7 +583,7 @@ PARSE:
} }
// tag[.tag].name.service.consul // tag[.tag].name.service.consul
d.serviceLookup(network, datacenter, labels[n-2], tag, false, req, resp, maxRecursionLevel) d.serviceLookup(cfg, network, datacenter, labels[n-2], tag, false, req, resp, maxRecursionLevel)
} }
case "connect": case "connect":
@ -552,7 +592,7 @@ PARSE:
} }
// name.connect.consul // name.connect.consul
d.serviceLookup(network, datacenter, labels[n-2], "", true, req, resp, maxRecursionLevel) d.serviceLookup(cfg, network, datacenter, labels[n-2], "", true, req, resp, maxRecursionLevel)
case "node": case "node":
if n == 1 { if n == 1 {
@ -561,7 +601,7 @@ PARSE:
// Allow a "." in the node name, just join all the parts // Allow a "." in the node name, just join all the parts
node := strings.Join(labels[:n-1], ".") node := strings.Join(labels[:n-1], ".")
d.nodeLookup(network, datacenter, node, req, resp, maxRecursionLevel) d.nodeLookup(cfg, network, datacenter, node, req, resp, maxRecursionLevel)
case "query": case "query":
if n == 1 { if n == 1 {
@ -571,7 +611,7 @@ PARSE:
// Allow a "." in the query name, just join all the parts. // Allow a "." in the query name, just join all the parts.
query := strings.Join(labels[:n-1], ".") query := strings.Join(labels[:n-1], ".")
ecsGlobal = false ecsGlobal = false
d.preparedQueryLookup(network, datacenter, query, remoteAddr, req, resp, maxRecursionLevel) d.preparedQueryLookup(cfg, network, datacenter, query, remoteAddr, req, resp, maxRecursionLevel)
case "addr": case "addr":
if n != 2 { if n != 2 {
@ -591,7 +631,7 @@ PARSE:
Name: qName + d.domain, Name: qName + d.domain,
Rrtype: dns.TypeA, Rrtype: dns.TypeA,
Class: dns.ClassINET, Class: dns.ClassINET,
Ttl: uint32(d.config.NodeTTL / time.Second), Ttl: uint32(cfg.NodeTTL / time.Second),
}, },
A: ip, A: ip,
}) })
@ -607,7 +647,7 @@ PARSE:
Name: qName + d.domain, Name: qName + d.domain,
Rrtype: dns.TypeAAAA, Rrtype: dns.TypeAAAA,
Class: dns.ClassINET, Class: dns.ClassINET,
Ttl: uint32(d.config.NodeTTL / time.Second), Ttl: uint32(cfg.NodeTTL / time.Second),
}, },
AAAA: ip, AAAA: ip,
}) })
@ -638,13 +678,13 @@ PARSE:
return return
INVALID: INVALID:
d.logger.Printf("[WARN] dns: QName invalid: %s", qName) d.logger.Printf("[WARN] dns: QName invalid: %s", qName)
d.addSOA(resp) d.addSOA(cfg, resp)
resp.SetRcode(req, dns.RcodeNameError) resp.SetRcode(req, dns.RcodeNameError)
return return
} }
// nodeLookup is used to handle a node query // nodeLookup is used to handle a node query
func (d *DNSServer) nodeLookup(network, datacenter, node string, req, resp *dns.Msg, maxRecursionLevel int) { func (d *DNSServer) nodeLookup(cfg *dnsConfig, network, datacenter, node string, req, resp *dns.Msg, maxRecursionLevel int) {
// Only handle ANY, A, AAAA, and TXT type requests // Only handle ANY, A, AAAA, and TXT type requests
qType := req.Question[0].Qtype qType := req.Question[0].Qtype
if qType != dns.TypeANY && qType != dns.TypeA && qType != dns.TypeAAAA && qType != dns.TypeTXT { if qType != dns.TypeANY && qType != dns.TypeA && qType != dns.TypeAAAA && qType != dns.TypeTXT {
@ -657,10 +697,10 @@ func (d *DNSServer) nodeLookup(network, datacenter, node string, req, resp *dns.
Node: node, Node: node,
QueryOptions: structs.QueryOptions{ QueryOptions: structs.QueryOptions{
Token: d.agent.tokens.UserToken(), Token: d.agent.tokens.UserToken(),
AllowStale: d.config.AllowStale, AllowStale: cfg.AllowStale,
}, },
} }
out, err := d.lookupNode(args) out, err := d.lookupNode(cfg, args)
if err != nil { if err != nil {
d.logger.Printf("[ERR] dns: rpc error: %v", err) d.logger.Printf("[ERR] dns: rpc error: %v", err)
resp.SetRcode(req, dns.RcodeServerFailure) resp.SetRcode(req, dns.RcodeServerFailure)
@ -669,7 +709,7 @@ func (d *DNSServer) nodeLookup(network, datacenter, node string, req, resp *dns.
// If we have no address, return not found! // If we have no address, return not found!
if out.NodeServices == nil { if out.NodeServices == nil {
d.addSOA(resp) d.addSOA(cfg, resp)
resp.SetRcode(req, dns.RcodeNameError) resp.SetRcode(req, dns.RcodeNameError)
return return
} }
@ -679,7 +719,7 @@ func (d *DNSServer) nodeLookup(network, datacenter, node string, req, resp *dns.
if qType == dns.TypeANY || qType == dns.TypeTXT { if qType == dns.TypeANY || qType == dns.TypeTXT {
generateMeta = true generateMeta = true
metaInAnswer = true metaInAnswer = true
} else if d.config.NodeMetaTXT { } else if cfg.NodeMetaTXT {
generateMeta = true generateMeta = true
} }
@ -687,21 +727,21 @@ func (d *DNSServer) nodeLookup(network, datacenter, node string, req, resp *dns.
n := out.NodeServices.Node n := out.NodeServices.Node
edns := req.IsEdns0() != nil edns := req.IsEdns0() != nil
addr := d.agent.TranslateAddress(datacenter, n.Address, n.TaggedAddresses) addr := d.agent.TranslateAddress(datacenter, n.Address, n.TaggedAddresses)
records, meta := d.formatNodeRecord(out.NodeServices.Node, addr, req.Question[0].Name, qType, d.config.NodeTTL, edns, maxRecursionLevel, generateMeta) records, meta := d.formatNodeRecord(cfg, out.NodeServices.Node, addr, req.Question[0].Name, qType, cfg.NodeTTL, edns, maxRecursionLevel, generateMeta)
if records != nil { if records != nil {
resp.Answer = append(resp.Answer, records...) resp.Answer = append(resp.Answer, records...)
} }
if meta != nil && metaInAnswer && generateMeta { if meta != nil && metaInAnswer && generateMeta {
resp.Answer = append(resp.Answer, meta...) resp.Answer = append(resp.Answer, meta...)
} else if meta != nil && generateMeta { } else if meta != nil && cfg.NodeMetaTXT {
resp.Extra = append(resp.Extra, meta...) resp.Extra = append(resp.Extra, meta...)
} }
} }
func (d *DNSServer) lookupNode(args *structs.NodeSpecificRequest) (*structs.IndexedNodeServices, error) { func (d *DNSServer) lookupNode(cfg *dnsConfig, args *structs.NodeSpecificRequest) (*structs.IndexedNodeServices, error) {
var out structs.IndexedNodeServices var out structs.IndexedNodeServices
useCache := d.config.UseCache useCache := cfg.UseCache
RPC: RPC:
if useCache { if useCache {
raw, _, err := d.agent.cache.Get(cachetype.NodeServicesName, args) raw, _, err := d.agent.cache.Get(cachetype.NodeServicesName, args)
@ -722,7 +762,7 @@ RPC:
// Verify that request is not too stale, redo the request // Verify that request is not too stale, redo the request
if args.AllowStale { if args.AllowStale {
if out.LastContact > d.config.MaxStale { if out.LastContact > cfg.MaxStale {
args.AllowStale = false args.AllowStale = false
useCache = false useCache = false
d.logger.Printf("[WARN] dns: Query results too stale, re-requesting") d.logger.Printf("[WARN] dns: Query results too stale, re-requesting")
@ -761,7 +801,7 @@ func encodeKVasRFC1464(key, value string) (txt string) {
// The return value is two slices. The first slice is the main answer slice (containing the A, AAAA, CNAME) RRs for the node // The return value is two slices. The first slice is the main answer slice (containing the A, AAAA, CNAME) RRs for the node
// and the second slice contains any TXT RRs created from the node metadata. It is up to the caller to determine where the // and the second slice contains any TXT RRs created from the node metadata. It is up to the caller to determine where the
// generated RRs should go and if they should be used at all. // generated RRs should go and if they should be used at all.
func (d *DNSServer) formatNodeRecord(node *structs.Node, addr, qName string, qType uint16, ttl time.Duration, edns bool, maxRecursionLevel int, generateMeta bool) (records, meta []dns.RR) { func (d *DNSServer) formatNodeRecord(cfg *dnsConfig, node *structs.Node, addr, qName string, qType uint16, ttl time.Duration, edns bool, maxRecursionLevel int, generateMeta bool) (records, meta []dns.RR) {
// Parse the IP // Parse the IP
ip := net.ParseIP(addr) ip := net.ParseIP(addr)
var ipv4 net.IP var ipv4 net.IP
@ -807,7 +847,7 @@ func (d *DNSServer) formatNodeRecord(node *structs.Node, addr, qName string, qTy
records = append(records, cnRec) records = append(records, cnRec)
// Recurse // Recurse
more := d.resolveCNAME(cnRec.Target, maxRecursionLevel) more := d.resolveCNAME(cfg, cnRec.Target, maxRecursionLevel)
extra := 0 extra := 0
MORE_REC: MORE_REC:
for _, rr := range more { for _, rr := range more {
@ -1036,21 +1076,21 @@ func trimUDPResponse(req, resp *dns.Msg, udpAnswerLimit int) (trimmed bool) {
} }
// trimDNSResponse will trim the response for UDP and TCP // trimDNSResponse will trim the response for UDP and TCP
func (d *DNSServer) trimDNSResponse(network string, req, resp *dns.Msg) (trimmed bool) { func (d *DNSServer) trimDNSResponse(cfg *dnsConfig, network string, req, resp *dns.Msg) (trimmed bool) {
if network != "tcp" { if network != "tcp" {
trimmed = trimUDPResponse(req, resp, d.config.UDPAnswerLimit) trimmed = trimUDPResponse(req, resp, cfg.UDPAnswerLimit)
} else { } else {
trimmed = d.trimTCPResponse(req, resp) trimmed = d.trimTCPResponse(req, resp)
} }
// Flag that there are more records to return in the UDP response // Flag that there are more records to return in the UDP response
if trimmed && d.config.EnableTruncate { if trimmed && cfg.EnableTruncate {
resp.Truncated = true resp.Truncated = true
} }
return trimmed return trimmed
} }
// lookupServiceNodes returns nodes with a given service. // lookupServiceNodes returns nodes with a given service.
func (d *DNSServer) lookupServiceNodes(datacenter, service, tag string, connect bool, maxRecursionLevel int) (structs.IndexedCheckServiceNodes, error) { func (d *DNSServer) lookupServiceNodes(cfg *dnsConfig, datacenter, service, tag string, connect bool, maxRecursionLevel int) (structs.IndexedCheckServiceNodes, error) {
args := structs.ServiceSpecificRequest{ args := structs.ServiceSpecificRequest{
Connect: connect, Connect: connect,
Datacenter: datacenter, Datacenter: datacenter,
@ -1059,14 +1099,14 @@ func (d *DNSServer) lookupServiceNodes(datacenter, service, tag string, connect
TagFilter: tag != "", TagFilter: tag != "",
QueryOptions: structs.QueryOptions{ QueryOptions: structs.QueryOptions{
Token: d.agent.tokens.UserToken(), Token: d.agent.tokens.UserToken(),
AllowStale: d.config.AllowStale, AllowStale: cfg.AllowStale,
MaxAge: d.config.CacheMaxAge, MaxAge: cfg.CacheMaxAge,
}, },
} }
var out structs.IndexedCheckServiceNodes var out structs.IndexedCheckServiceNodes
if d.config.UseCache { if cfg.UseCache {
raw, m, err := d.agent.cache.Get(cachetype.HealthServicesName, &args) raw, m, err := d.agent.cache.Get(cachetype.HealthServicesName, &args)
if err != nil { if err != nil {
return out, err return out, err
@ -1090,7 +1130,7 @@ func (d *DNSServer) lookupServiceNodes(datacenter, service, tag string, connect
} }
// redo the request the response was too stale // redo the request the response was too stale
if args.AllowStale && out.LastContact > d.config.MaxStale { if args.AllowStale && out.LastContact > cfg.MaxStale {
args.AllowStale = false args.AllowStale = false
d.logger.Printf("[WARN] dns: Query results too stale, re-requesting") d.logger.Printf("[WARN] dns: Query results too stale, re-requesting")
@ -1103,13 +1143,13 @@ func (d *DNSServer) lookupServiceNodes(datacenter, service, tag string, connect
// We copy the slice to avoid modifying the result if it comes from the cache // We copy the slice to avoid modifying the result if it comes from the cache
nodes := make(structs.CheckServiceNodes, len(out.Nodes)) nodes := make(structs.CheckServiceNodes, len(out.Nodes))
copy(nodes, out.Nodes) copy(nodes, out.Nodes)
out.Nodes = nodes.Filter(d.config.OnlyPassing) out.Nodes = nodes.Filter(cfg.OnlyPassing)
return out, nil return out, nil
} }
// serviceLookup is used to handle a service query // serviceLookup is used to handle a service query
func (d *DNSServer) serviceLookup(network, datacenter, service, tag string, connect bool, req, resp *dns.Msg, maxRecursionLevel int) { func (d *DNSServer) serviceLookup(cfg *dnsConfig, network, datacenter, service, tag string, connect bool, req, resp *dns.Msg, maxRecursionLevel int) {
out, err := d.lookupServiceNodes(datacenter, service, tag, connect, maxRecursionLevel) out, err := d.lookupServiceNodes(cfg, datacenter, service, tag, connect, maxRecursionLevel)
if err != nil { if err != nil {
d.logger.Printf("[ERR] dns: rpc error: %v", err) d.logger.Printf("[ERR] dns: rpc error: %v", err)
resp.SetRcode(req, dns.RcodeServerFailure) resp.SetRcode(req, dns.RcodeServerFailure)
@ -1118,7 +1158,7 @@ func (d *DNSServer) serviceLookup(network, datacenter, service, tag string, conn
// If we have no nodes, return not found! // If we have no nodes, return not found!
if len(out.Nodes) == 0 { if len(out.Nodes) == 0 {
d.addSOA(resp) d.addSOA(cfg, resp)
resp.SetRcode(req, dns.RcodeNameError) resp.SetRcode(req, dns.RcodeNameError)
return return
} }
@ -1127,21 +1167,21 @@ func (d *DNSServer) serviceLookup(network, datacenter, service, tag string, conn
out.Nodes.Shuffle() out.Nodes.Shuffle()
// Determine the TTL // Determine the TTL
ttl, _ := d.GetTTLForService(service) ttl, _ := cfg.GetTTLForService(service)
// Add various responses depending on the request // Add various responses depending on the request
qType := req.Question[0].Qtype qType := req.Question[0].Qtype
if qType == dns.TypeSRV { if qType == dns.TypeSRV {
d.serviceSRVRecords(datacenter, out.Nodes, req, resp, ttl, maxRecursionLevel) d.serviceSRVRecords(cfg, datacenter, out.Nodes, req, resp, ttl, maxRecursionLevel)
} else { } else {
d.serviceNodeRecords(datacenter, out.Nodes, req, resp, ttl, maxRecursionLevel) d.serviceNodeRecords(cfg, datacenter, out.Nodes, req, resp, ttl, maxRecursionLevel)
} }
d.trimDNSResponse(network, req, resp) d.trimDNSResponse(cfg, network, req, resp)
// If the answer is empty and the response isn't truncated, return not found // If the answer is empty and the response isn't truncated, return not found
if len(resp.Answer) == 0 && !resp.Truncated { if len(resp.Answer) == 0 && !resp.Truncated {
d.addSOA(resp) d.addSOA(cfg, resp)
return return
} }
} }
@ -1164,15 +1204,15 @@ func ednsSubnetForRequest(req *dns.Msg) *dns.EDNS0_SUBNET {
} }
// preparedQueryLookup is used to handle a prepared query. // preparedQueryLookup is used to handle a prepared query.
func (d *DNSServer) preparedQueryLookup(network, datacenter, query string, remoteAddr net.Addr, req, resp *dns.Msg, maxRecursionLevel int) { func (d *DNSServer) preparedQueryLookup(cfg *dnsConfig, network, datacenter, query string, remoteAddr net.Addr, req, resp *dns.Msg, maxRecursionLevel int) {
// Execute the prepared query. // Execute the prepared query.
args := structs.PreparedQueryExecuteRequest{ args := structs.PreparedQueryExecuteRequest{
Datacenter: datacenter, Datacenter: datacenter,
QueryIDOrName: query, QueryIDOrName: query,
QueryOptions: structs.QueryOptions{ QueryOptions: structs.QueryOptions{
Token: d.agent.tokens.UserToken(), Token: d.agent.tokens.UserToken(),
AllowStale: d.config.AllowStale, AllowStale: cfg.AllowStale,
MaxAge: d.config.CacheMaxAge, MaxAge: cfg.CacheMaxAge,
}, },
// Always pass the local agent through. In the DNS interface, there // Always pass the local agent through. In the DNS interface, there
@ -1201,13 +1241,13 @@ func (d *DNSServer) preparedQueryLookup(network, datacenter, query string, remot
} }
} }
out, err := d.lookupPreparedQuery(args) out, err := d.lookupPreparedQuery(cfg, args)
// If they give a bogus query name, treat that as a name error, // If they give a bogus query name, treat that as a name error,
// not a full on server error. We have to use a string compare // not a full on server error. We have to use a string compare
// here since the RPC layer loses the type information. // here since the RPC layer loses the type information.
if err != nil && err.Error() == consul.ErrQueryNotFound.Error() { if err != nil && err.Error() == consul.ErrQueryNotFound.Error() {
d.addSOA(resp) d.addSOA(cfg, resp)
resp.SetRcode(req, dns.RcodeNameError) resp.SetRcode(req, dns.RcodeNameError)
return return
} else if err != nil { } else if err != nil {
@ -1234,13 +1274,13 @@ func (d *DNSServer) preparedQueryLookup(network, datacenter, query string, remot
if err != nil { if err != nil {
d.logger.Printf("[WARN] dns: Failed to parse TTL '%s' for prepared query '%s', ignoring", out.DNS.TTL, query) d.logger.Printf("[WARN] dns: Failed to parse TTL '%s' for prepared query '%s', ignoring", out.DNS.TTL, query)
} }
} else if d.config.ServiceTTL != nil { } else {
ttl, _ = d.GetTTLForService(out.Service) ttl, _ = cfg.GetTTLForService(out.Service)
} }
// If we have no nodes, return not found! // If we have no nodes, return not found!
if len(out.Nodes) == 0 { if len(out.Nodes) == 0 {
d.addSOA(resp) d.addSOA(cfg, resp)
resp.SetRcode(req, dns.RcodeNameError) resp.SetRcode(req, dns.RcodeNameError)
return return
} }
@ -1248,25 +1288,25 @@ func (d *DNSServer) preparedQueryLookup(network, datacenter, query string, remot
// Add various responses depending on the request. // Add various responses depending on the request.
qType := req.Question[0].Qtype qType := req.Question[0].Qtype
if qType == dns.TypeSRV { if qType == dns.TypeSRV {
d.serviceSRVRecords(out.Datacenter, out.Nodes, req, resp, ttl, maxRecursionLevel) d.serviceSRVRecords(cfg, out.Datacenter, out.Nodes, req, resp, ttl, maxRecursionLevel)
} else { } else {
d.serviceNodeRecords(out.Datacenter, out.Nodes, req, resp, ttl, maxRecursionLevel) d.serviceNodeRecords(cfg, out.Datacenter, out.Nodes, req, resp, ttl, maxRecursionLevel)
} }
d.trimDNSResponse(network, req, resp) d.trimDNSResponse(cfg, network, req, resp)
// If the answer is empty and the response isn't truncated, return not found // If the answer is empty and the response isn't truncated, return not found
if len(resp.Answer) == 0 && !resp.Truncated { if len(resp.Answer) == 0 && !resp.Truncated {
d.addSOA(resp) d.addSOA(cfg, resp)
return return
} }
} }
func (d *DNSServer) lookupPreparedQuery(args structs.PreparedQueryExecuteRequest) (*structs.PreparedQueryExecuteResponse, error) { func (d *DNSServer) lookupPreparedQuery(cfg *dnsConfig, args structs.PreparedQueryExecuteRequest) (*structs.PreparedQueryExecuteResponse, error) {
var out structs.PreparedQueryExecuteResponse var out structs.PreparedQueryExecuteResponse
RPC: RPC:
if d.config.UseCache { if cfg.UseCache {
raw, m, err := d.agent.cache.Get(cachetype.PreparedQueryName, &args) raw, m, err := d.agent.cache.Get(cachetype.PreparedQueryName, &args)
if err != nil { if err != nil {
return nil, err return nil, err
@ -1288,7 +1328,7 @@ RPC:
// Verify that request is not too stale, redo the request. // Verify that request is not too stale, redo the request.
if args.AllowStale { if args.AllowStale {
if out.LastContact > d.config.MaxStale { if out.LastContact > cfg.MaxStale {
args.AllowStale = false args.AllowStale = false
d.logger.Printf("[WARN] dns: Query results too stale, re-requesting") d.logger.Printf("[WARN] dns: Query results too stale, re-requesting")
goto RPC goto RPC
@ -1301,7 +1341,7 @@ RPC:
} }
// serviceNodeRecords is used to add the node records for a service lookup // serviceNodeRecords is used to add the node records for a service lookup
func (d *DNSServer) serviceNodeRecords(dc string, nodes structs.CheckServiceNodes, req, resp *dns.Msg, ttl time.Duration, maxRecursionLevel int) { func (d *DNSServer) serviceNodeRecords(cfg *dnsConfig, dc string, nodes structs.CheckServiceNodes, req, resp *dns.Msg, ttl time.Duration, maxRecursionLevel int) {
qName := req.Question[0].Name qName := req.Question[0].Name
qType := req.Question[0].Qtype qType := req.Question[0].Qtype
handled := make(map[string]struct{}) handled := make(map[string]struct{})
@ -1335,13 +1375,13 @@ func (d *DNSServer) serviceNodeRecords(dc string, nodes structs.CheckServiceNode
if qType == dns.TypeANY || qType == dns.TypeTXT { if qType == dns.TypeANY || qType == dns.TypeTXT {
generateMeta = true generateMeta = true
metaInAnswer = true metaInAnswer = true
} else if d.config.NodeMetaTXT { } else if cfg.NodeMetaTXT {
generateMeta = true generateMeta = true
} }
// Add the node record // Add the node record
had_answer := false had_answer := false
records, meta := d.formatNodeRecord(node.Node, addr, qName, qType, ttl, edns, maxRecursionLevel, generateMeta) records, meta := d.formatNodeRecord(cfg, node.Node, addr, qName, qType, ttl, edns, maxRecursionLevel, generateMeta)
if records != nil { if records != nil {
switch records[0].(type) { switch records[0].(type) {
case *dns.CNAME: case *dns.CNAME:
@ -1365,7 +1405,7 @@ func (d *DNSServer) serviceNodeRecords(dc string, nodes structs.CheckServiceNode
if had_answer { if had_answer {
count++ count++
if count == d.config.ARecordLimit { if count == cfg.ARecordLimit {
// We stop only if greater than 0 or we reached the limit // We stop only if greater than 0 or we reached the limit
return return
} }
@ -1423,7 +1463,7 @@ func findWeight(node structs.CheckServiceNode) int {
} }
// serviceARecords is used to add the SRV records for a service lookup // serviceARecords is used to add the SRV records for a service lookup
func (d *DNSServer) serviceSRVRecords(dc string, nodes structs.CheckServiceNodes, req, resp *dns.Msg, ttl time.Duration, maxRecursionLevel int) { func (d *DNSServer) serviceSRVRecords(cfg *dnsConfig, dc string, nodes structs.CheckServiceNodes, req, resp *dns.Msg, ttl time.Duration, maxRecursionLevel int) {
handled := make(map[string]struct{}) handled := make(map[string]struct{})
edns := req.IsEdns0() != nil edns := req.IsEdns0() != nil
@ -1460,7 +1500,7 @@ func (d *DNSServer) serviceSRVRecords(dc string, nodes structs.CheckServiceNodes
} }
// Add the extra record // Add the extra record
records, meta := d.formatNodeRecord(node.Node, addr, srvRec.Target, dns.TypeANY, ttl, edns, maxRecursionLevel, d.config.NodeMetaTXT) records, meta := d.formatNodeRecord(cfg, node.Node, addr, srvRec.Target, dns.TypeANY, ttl, edns, maxRecursionLevel, cfg.NodeMetaTXT)
if len(records) > 0 { if len(records) > 0 {
// Use the node address if it doesn't differ from the service address // Use the node address if it doesn't differ from the service address
if addr == node.Node.Address { if addr == node.Node.Address {
@ -1491,7 +1531,7 @@ func (d *DNSServer) serviceSRVRecords(dc string, nodes structs.CheckServiceNodes
} }
} }
if meta != nil && d.config.NodeMetaTXT { if meta != nil && cfg.NodeMetaTXT {
resp.Extra = append(resp.Extra, meta...) resp.Extra = append(resp.Extra, meta...)
} }
} }
@ -1500,6 +1540,8 @@ func (d *DNSServer) serviceSRVRecords(dc string, nodes structs.CheckServiceNodes
// handleRecurse is used to handle recursive DNS queries // handleRecurse is used to handle recursive DNS queries
func (d *DNSServer) handleRecurse(resp dns.ResponseWriter, req *dns.Msg) { func (d *DNSServer) handleRecurse(resp dns.ResponseWriter, req *dns.Msg) {
cfg := d.config.Load().(*dnsConfig)
q := req.Question[0] q := req.Question[0]
network := "udp" network := "udp"
defer func(s time.Time) { defer func(s time.Time) {
@ -1514,11 +1556,11 @@ func (d *DNSServer) handleRecurse(resp dns.ResponseWriter, req *dns.Msg) {
} }
// Recursively resolve // Recursively resolve
c := &dns.Client{Net: network, Timeout: d.config.RecursorTimeout} c := &dns.Client{Net: network, Timeout: cfg.RecursorTimeout}
var r *dns.Msg var r *dns.Msg
var rtt time.Duration var rtt time.Duration
var err error var err error
for _, recursor := range d.recursors { for _, recursor := range cfg.Recursors {
r, rtt, err = c.Exchange(req, recursor) r, rtt, err = c.Exchange(req, recursor)
// Check if the response is valid and has the desired Response code // Check if the response is valid and has the desired Response code
if r != nil && (r.Rcode != dns.RcodeSuccess && r.Rcode != dns.RcodeNameError) { if r != nil && (r.Rcode != dns.RcodeSuccess && r.Rcode != dns.RcodeNameError) {
@ -1530,7 +1572,7 @@ func (d *DNSServer) handleRecurse(resp dns.ResponseWriter, req *dns.Msg) {
// Compress the response; we don't know if the incoming // Compress the response; we don't know if the incoming
// response was compressed or not, so by not compressing // response was compressed or not, so by not compressing
// we might generate an invalid packet on the way out. // we might generate an invalid packet on the way out.
r.Compress = !d.disableCompression.Load().(bool) r.Compress = !cfg.DisableCompression
// Forward the response // Forward the response
d.logger.Printf("[DEBUG] dns: recurse RTT for %v (%v) Recursor queried: %v", q, rtt, recursor) d.logger.Printf("[DEBUG] dns: recurse RTT for %v (%v) Recursor queried: %v", q, rtt, recursor)
@ -1547,7 +1589,7 @@ func (d *DNSServer) handleRecurse(resp dns.ResponseWriter, req *dns.Msg) {
q, resp.RemoteAddr().String(), resp.RemoteAddr().Network()) q, resp.RemoteAddr().String(), resp.RemoteAddr().Network())
m := &dns.Msg{} m := &dns.Msg{}
m.SetReply(req) m.SetReply(req)
m.Compress = !d.disableCompression.Load().(bool) m.Compress = !cfg.DisableCompression
m.RecursionAvailable = true m.RecursionAvailable = true
m.SetRcode(req, dns.RcodeServerFailure) m.SetRcode(req, dns.RcodeServerFailure)
if edns := req.IsEdns0(); edns != nil { if edns := req.IsEdns0(); edns != nil {
@ -1557,7 +1599,7 @@ func (d *DNSServer) handleRecurse(resp dns.ResponseWriter, req *dns.Msg) {
} }
// resolveCNAME is used to recursively resolve CNAME records // resolveCNAME is used to recursively resolve CNAME records
func (d *DNSServer) resolveCNAME(name string, maxRecursionLevel int) []dns.RR { func (d *DNSServer) resolveCNAME(cfg *dnsConfig, name string, maxRecursionLevel int) []dns.RR {
// If the CNAME record points to a Consul address, resolve it internally // If the CNAME record points to a Consul address, resolve it internally
// Convert query to lowercase because DNS is case insensitive; d.domain is // Convert query to lowercase because DNS is case insensitive; d.domain is
// already converted // already converted
@ -1577,7 +1619,7 @@ func (d *DNSServer) resolveCNAME(name string, maxRecursionLevel int) []dns.RR {
} }
// Do nothing if we don't have a recursor // Do nothing if we don't have a recursor
if len(d.recursors) == 0 { if len(cfg.Recursors) == 0 {
return nil return nil
} }
@ -1586,11 +1628,11 @@ func (d *DNSServer) resolveCNAME(name string, maxRecursionLevel int) []dns.RR {
m.SetQuestion(name, dns.TypeA) m.SetQuestion(name, dns.TypeA)
// Make a DNS lookup request // Make a DNS lookup request
c := &dns.Client{Net: "udp", Timeout: d.config.RecursorTimeout} c := &dns.Client{Net: "udp", Timeout: cfg.RecursorTimeout}
var r *dns.Msg var r *dns.Msg
var rtt time.Duration var rtt time.Duration
var err error var err error
for _, recursor := range d.recursors { for _, recursor := range cfg.Recursors {
r, rtt, err = c.Exchange(m, recursor) r, rtt, err = c.Exchange(m, recursor)
if err == nil { if err == nil {
d.logger.Printf("[DEBUG] dns: cname recurse RTT for %v (%v)", name, rtt) d.logger.Printf("[DEBUG] dns: cname recurse RTT for %v (%v)", name, rtt)

View File

@ -5,6 +5,7 @@ import (
"math/rand" "math/rand"
"net" "net"
"reflect" "reflect"
"sort"
"strings" "strings"
"testing" "testing"
"time" "time"
@ -3740,6 +3741,24 @@ func TestDNS_ServiceLookup_OnlyPassing(t *testing.T) {
t.Fatalf("Bad: %#v", in.Answer[0]) t.Fatalf("Bad: %#v", in.Answer[0])
} }
} }
newCfg := *a.Config
newCfg.DNSOnlyPassing = false
err := a.ReloadConfig(&newCfg)
require.NoError(t, err)
// only_passing is now false. we should now get two nodes
m := new(dns.Msg)
m.SetQuestion("db.service.consul.", dns.TypeANY)
c := new(dns.Client)
in, _, err := c.Exchange(m, a.DNSAddr())
require.NoError(t, err)
require.Equal(t, 2, len(in.Answer))
ips := []string{in.Answer[0].(*dns.A).A.String(), in.Answer[1].(*dns.A).A.String()}
sort.Strings(ips)
require.Equal(t, []string{"127.0.0.1", "127.0.0.2"}, ips)
} }
func TestDNS_ServiceLookup_Randomize(t *testing.T) { func TestDNS_ServiceLookup_Randomize(t *testing.T) {
@ -5190,7 +5209,6 @@ func TestDNS_ServiceLookup_FilterACL(t *testing.T) {
}) })
} }
} }
func TestDNS_ServiceLookup_MetaTXT(t *testing.T) { func TestDNS_ServiceLookup_MetaTXT(t *testing.T) {
a := NewTestAgent(t, t.Name(), `dns_config = { enable_additional_node_meta_txt = true }`) a := NewTestAgent(t, t.Name(), `dns_config = { enable_additional_node_meta_txt = true }`)
defer a.Shutdown() defer a.Shutdown()
@ -6341,11 +6359,177 @@ func TestDNS_formatNodeRecord(t *testing.T) {
}, },
} }
records, meta := s.formatNodeRecord(node, "198.18.0.1", "test.node.consul", dns.TypeA, 5*time.Minute, false, 3, false) records, meta := s.formatNodeRecord(&dnsConfig{}, node, "198.18.0.1", "test.node.consul", dns.TypeA, 5*time.Minute, false, 3, false)
require.Len(t, records, 1) require.Len(t, records, 1)
require.Len(t, meta, 0) require.Len(t, meta, 0)
records, meta = s.formatNodeRecord(node, "198.18.0.1", "test.node.consul", dns.TypeA, 5*time.Minute, false, 3, true) records, meta = s.formatNodeRecord(&dnsConfig{}, node, "198.18.0.1", "test.node.consul", dns.TypeA, 5*time.Minute, false, 3, true)
require.Len(t, records, 1) require.Len(t, records, 1)
require.Len(t, meta, 2) require.Len(t, meta, 2)
} }
func TestDNS_ConfigReload(t *testing.T) {
t.Parallel()
a := NewTestAgent(t, t.Name(), `
recursors = ["8.8.8.8:53"]
dns_config = {
allow_stale = false
max_stale = "20s"
node_ttl = "10s"
service_ttl = {
"my_services*" = "5s"
"my_specific_service" = "30s"
}
enable_truncate = false
only_passing = false
recursor_timeout = "15s"
disable_compression = false
a_record_limit = 1
enable_additional_node_meta_txt = false
soa = {
refresh = 1
retry = 2
expire = 3
min_ttl = 4
}
}
`)
defer a.Shutdown()
testrpc.WaitForLeader(t, a.RPC, "dc1")
for _, s := range a.dnsServers {
cfg := s.config.Load().(*dnsConfig)
require.Equal(t, []string{"8.8.8.8:53"}, cfg.Recursors)
require.False(t, cfg.AllowStale)
require.Equal(t, 20*time.Second, cfg.MaxStale)
require.Equal(t, 10*time.Second, cfg.NodeTTL)
ttl, _ := cfg.GetTTLForService("my_services_1")
require.Equal(t, 5*time.Second, ttl)
ttl, _ = cfg.GetTTLForService("my_specific_service")
require.Equal(t, 30*time.Second, ttl)
require.False(t, cfg.EnableTruncate)
require.False(t, cfg.OnlyPassing)
require.Equal(t, 15*time.Second, cfg.RecursorTimeout)
require.False(t, cfg.DisableCompression)
require.Equal(t, 1, cfg.ARecordLimit)
require.False(t, cfg.NodeMetaTXT)
require.Equal(t, uint32(1), cfg.SOAConfig.Refresh)
require.Equal(t, uint32(2), cfg.SOAConfig.Retry)
require.Equal(t, uint32(3), cfg.SOAConfig.Expire)
require.Equal(t, uint32(4), cfg.SOAConfig.Minttl)
}
newCfg := *a.Config
newCfg.DNSRecursors = []string{"1.1.1.1:53"}
newCfg.DNSAllowStale = true
newCfg.DNSMaxStale = 21 * time.Second
newCfg.DNSNodeTTL = 11 * time.Second
newCfg.DNSServiceTTL = map[string]time.Duration{
"2_my_services*": 6 * time.Second,
"2_my_specific_service": 31 * time.Second,
}
newCfg.DNSEnableTruncate = true
newCfg.DNSOnlyPassing = true
newCfg.DNSRecursorTimeout = 16 * time.Second
newCfg.DNSDisableCompression = true
newCfg.DNSARecordLimit = 2
newCfg.DNSNodeMetaTXT = true
newCfg.DNSSOA.Refresh = 10
newCfg.DNSSOA.Retry = 20
newCfg.DNSSOA.Expire = 30
newCfg.DNSSOA.Minttl = 40
err := a.ReloadConfig(&newCfg)
require.NoError(t, err)
for _, s := range a.dnsServers {
cfg := s.config.Load().(*dnsConfig)
require.Equal(t, []string{"1.1.1.1:53"}, cfg.Recursors)
require.True(t, cfg.AllowStale)
require.Equal(t, 21*time.Second, cfg.MaxStale)
require.Equal(t, 11*time.Second, cfg.NodeTTL)
ttl, _ := cfg.GetTTLForService("my_services_1")
require.Equal(t, time.Duration(0), ttl)
ttl, _ = cfg.GetTTLForService("2_my_services_1")
require.Equal(t, 6*time.Second, ttl)
ttl, _ = cfg.GetTTLForService("my_specific_service")
require.Equal(t, time.Duration(0), ttl)
ttl, _ = cfg.GetTTLForService("2_my_specific_service")
require.Equal(t, 31*time.Second, ttl)
require.True(t, cfg.EnableTruncate)
require.True(t, cfg.OnlyPassing)
require.Equal(t, 16*time.Second, cfg.RecursorTimeout)
require.True(t, cfg.DisableCompression)
require.Equal(t, 2, cfg.ARecordLimit)
require.True(t, cfg.NodeMetaTXT)
require.Equal(t, uint32(10), cfg.SOAConfig.Refresh)
require.Equal(t, uint32(20), cfg.SOAConfig.Retry)
require.Equal(t, uint32(30), cfg.SOAConfig.Expire)
require.Equal(t, uint32(40), cfg.SOAConfig.Minttl)
}
}
func TestDNS_ReloadConfig_DuringQuery(t *testing.T) {
t.Parallel()
a := NewTestAgent(t, t.Name(), "")
defer a.Shutdown()
testrpc.WaitForLeader(t, a.RPC, "dc1")
m := MockPreparedQuery{
executeFn: func(args *structs.PreparedQueryExecuteRequest, reply *structs.PreparedQueryExecuteResponse) error {
time.Sleep(100 * time.Millisecond)
reply.Nodes = structs.CheckServiceNodes{
{
Node: &structs.Node{
ID: "my_node",
Address: "127.0.0.1",
},
Service: &structs.NodeService{
Address: "127.0.0.1",
Port: 8080,
},
},
}
return nil
},
}
err := a.registerEndpoint("PreparedQuery", &m)
require.NoError(t, err)
{
m := new(dns.Msg)
m.SetQuestion("nope.query.consul.", dns.TypeA)
timeout := time.NewTimer(time.Second)
res := make(chan *dns.Msg)
errs := make(chan error)
go func() {
c := new(dns.Client)
in, _, err := c.Exchange(m, a.DNSAddr())
if err != nil {
errs <- err
return
}
res <- in
}()
time.Sleep(50 * time.Millisecond)
// reload the config halfway through, that should not affect the ongoing query
newCfg := *a.Config
newCfg.DNSAllowStale = true
a.ReloadConfig(&newCfg)
select {
case in := <-res:
require.Equal(t, "127.0.0.1", in.Answer[0].(*dns.A).A.String())
case err := <-errs:
require.NoError(t, err)
case <-timeout.C:
require.FailNow(t, "timeout")
}
}
}

View File

@ -279,7 +279,8 @@ func (a *TestAgent) Client() *api.Client {
// DNSDisableCompression disables compression for all started DNS servers. // DNSDisableCompression disables compression for all started DNS servers.
func (a *TestAgent) DNSDisableCompression(b bool) { func (a *TestAgent) DNSDisableCompression(b bool) {
for _, srv := range a.dnsServers { for _, srv := range a.dnsServers {
srv.disableCompression.Store(b) cfg := srv.config.Load().(*dnsConfig)
cfg.DisableCompression = b
} }
} }