diff --git a/agent/dns.go b/agent/dns.go index 72e6223246..a9063e26f4 100644 --- a/agent/dns.go +++ b/agent/dns.go @@ -592,8 +592,7 @@ func (d *DNSServer) parseDatacenter(labels []string, datacenter *string) bool { // doDispatch is used to parse a request and invoke the correct handler. // parameter maxRecursionLevel will handle whether recursive call can be performed -func (d *DNSServer) doDispatch(network string, remoteAddr net.Addr, req, resp *dns.Msg, maxRecursionLevel int) (ecsGlobal bool) { - ecsGlobal = true +func (d *DNSServer) doDispatch(network string, remoteAddr net.Addr, req, resp *dns.Msg, maxRecursionLevel int) bool { // By default the query is in the default datacenter datacenter := d.agent.config.Datacenter @@ -633,19 +632,26 @@ func (d *DNSServer) doDispatch(network string, remoteAddr net.Addr, req, resp *d } } + invalid := func() bool { + d.logger.Warn("QName invalid", "qname", qName) + d.addSOA(cfg, resp) + resp.SetRcode(req, dns.RcodeNameError) + return true + } + if queryKind == "" { - goto INVALID + return invalid() } switch queryKind { case "service": n := len(queryParts) if n < 1 { - goto INVALID + return invalid() } if !d.parseDatacenterAndEnterpriseMeta(querySuffixes, cfg, &datacenter, &entMeta) { - goto INVALID + return invalid() } lookup := serviceLookup{ @@ -689,11 +695,11 @@ func (d *DNSServer) doDispatch(network string, remoteAddr net.Addr, req, resp *d } case "connect": if len(queryParts) < 1 { - goto INVALID + return invalid() } if !d.parseDatacenterAndEnterpriseMeta(querySuffixes, cfg, &datacenter, &entMeta) { - goto INVALID + return invalid() } lookup := serviceLookup{ @@ -709,11 +715,11 @@ func (d *DNSServer) doDispatch(network string, remoteAddr net.Addr, req, resp *d d.serviceLookup(cfg, lookup, req, resp) case "ingress": if len(queryParts) < 1 { - goto INVALID + return invalid() } if !d.parseDatacenterAndEnterpriseMeta(querySuffixes, cfg, &datacenter, &entMeta) { - goto INVALID + return invalid() } lookup := serviceLookup{ @@ -729,11 +735,11 @@ func (d *DNSServer) doDispatch(network string, remoteAddr net.Addr, req, resp *d d.serviceLookup(cfg, lookup, req, resp) case "node": if len(queryParts) < 1 { - goto INVALID + return invalid() } if !d.parseDatacenter(querySuffixes, &datacenter) { - goto INVALID + return invalid() } // Allow a "." in the node name, just join all the parts @@ -742,22 +748,22 @@ func (d *DNSServer) doDispatch(network string, remoteAddr net.Addr, req, resp *d case "query": // ensure we have a query name if len(queryParts) < 1 { - goto INVALID + return invalid() } if !d.parseDatacenter(querySuffixes, &datacenter) { - goto INVALID + return invalid() } // Allow a "." in the query name, just join all the parts. query := strings.Join(queryParts, ".") - ecsGlobal = false d.preparedQueryLookup(cfg, network, datacenter, query, remoteAddr, req, resp, maxRecursionLevel) + return false case "addr": //
.addr.. - addr must be the second label, datacenter is optional if len(queryParts) != 1 { - goto INVALID + return invalid() } switch len(queryParts[0]) / 2 { @@ -765,7 +771,7 @@ func (d *DNSServer) doDispatch(network string, remoteAddr net.Addr, req, resp *d case 4: ip, err := hex.DecodeString(queryParts[0]) if err != nil { - goto INVALID + return invalid() } resp.Answer = append(resp.Answer, &dns.A{ @@ -781,7 +787,7 @@ func (d *DNSServer) doDispatch(network string, remoteAddr net.Addr, req, resp *d case 16: ip, err := hex.DecodeString(queryParts[0]) if err != nil { - goto INVALID + return invalid() } resp.Answer = append(resp.Answer, &dns.AAAA{ @@ -795,14 +801,7 @@ func (d *DNSServer) doDispatch(network string, remoteAddr net.Addr, req, resp *d }) } } - // early return without error - return - -INVALID: - d.logger.Warn("QName invalid", "qname", qName) - d.addSOA(cfg, resp) - resp.SetRcode(req, dns.RcodeNameError) - return + return true } func (d *DNSServer) trimDomain(query string) string {