diff --git a/agent/consul/catalog_endpoint.go b/agent/consul/catalog_endpoint.go index 52ba5fb1b7..0c1cbe3de7 100644 --- a/agent/consul/catalog_endpoint.go +++ b/agent/consul/catalog_endpoint.go @@ -240,7 +240,7 @@ func (c *Catalog) ServiceNodes(args *structs.ServiceSpecificRequest, reply *stru } // Verify the arguments - if args.ServiceName == "" { + if args.ServiceName == "" && args.ServiceAddress == "" { return fmt.Errorf("Must provide service name") } @@ -256,6 +256,9 @@ func (c *Catalog) ServiceNodes(args *structs.ServiceSpecificRequest, reply *stru } else { index, services, err = state.ServiceNodes(ws, args.ServiceName) } + if args.ServiceAddress != "" { + index, services, err = state.ServiceAddressNodes(ws, args.ServiceAddress) + } if err != nil { return err } diff --git a/agent/consul/state/catalog.go b/agent/consul/state/catalog.go index 3c18f9fc9f..2a81c10713 100644 --- a/agent/consul/state/catalog.go +++ b/agent/consul/state/catalog.go @@ -855,6 +855,36 @@ func serviceTagFilter(sn *structs.ServiceNode, tag string) bool { return true } +// ServiceAddressNodes returns the nodes associated with a given service, filtering +// out services that don't match the given serviceAddress +func (s *Store) ServiceAddressNodes(ws memdb.WatchSet, address string) (uint64, structs.ServiceNodes, error) { + tx := s.db.Txn(false) + defer tx.Abort() + + // List all the services. + services, err := tx.Get("services", "id") + if err != nil { + return 0, nil, fmt.Errorf("failed service lookup: %s", err) + } + ws.Add(services.WatchCh()) + + // Gather all the services and apply the tag filter. + var results structs.ServiceNodes + for service := services.Next(); service != nil; service = services.Next() { + svc := service.(*structs.ServiceNode) + if svc.ServiceAddress == address { + results = append(results, svc) + } + } + + // Fill in the node details. + results, err = s.parseServiceNodes(tx, ws, results) + if err != nil { + return 0, nil, fmt.Errorf("failed parsing service nodes: %s", err) + } + return 0, results, nil +} + // parseServiceNodes iterates over a services query and fills in the node details, // returning a ServiceNodes slice. func (s *Store) parseServiceNodes(tx *memdb.Txn, ws memdb.WatchSet, services structs.ServiceNodes) (structs.ServiceNodes, error) { diff --git a/agent/dns.go b/agent/dns.go index f1c0d8bda7..7e4b816f96 100644 --- a/agent/dns.go +++ b/agent/dns.go @@ -12,6 +12,7 @@ import ( "regexp" "github.com/armon/go-metrics" + "github.com/coredns/coredns/plugin/pkg/dnsutil" "github.com/hashicorp/consul/agent/config" "github.com/hashicorp/consul/agent/consul" "github.com/hashicorp/consul/agent/structs" @@ -207,6 +208,31 @@ func (d *DNSServer) handlePtr(resp dns.ResponseWriter, req *dns.Msg) { } } + // lookup the service address + serviceAddress := dnsutil.ExtractAddressFromReverse(qName) + sargs := structs.ServiceSpecificRequest{ + Datacenter: datacenter, + QueryOptions: structs.QueryOptions{ + Token: d.agent.tokens.UserToken(), + AllowStale: d.config.AllowStale, + }, + ServiceAddress: serviceAddress, + } + + var sout structs.IndexedServiceNodes + if err := d.agent.RPC("Catalog.ServiceNodes", &sargs, &sout); err == nil { + for _, n := range sout.ServiceNodes { + if n.ServiceAddress == serviceAddress { + ptr := &dns.PTR{ + Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypePTR, Class: dns.ClassINET, Ttl: 0}, + Ptr: fmt.Sprintf("%s.service.%s", n.ServiceName, d.domain), + } + m.Answer = append(m.Answer, ptr) + break + } + } + } + // nothing found locally, recurse if len(m.Answer) == 0 { d.handleRecurse(resp, req) diff --git a/agent/structs/structs.go b/agent/structs/structs.go index c5b5942136..77075b3e32 100644 --- a/agent/structs/structs.go +++ b/agent/structs/structs.go @@ -279,6 +279,7 @@ type ServiceSpecificRequest struct { NodeMetaFilters map[string]string ServiceName string ServiceTag string + ServiceAddress string TagFilter bool // Controls tag filtering Source QuerySource QueryOptions