mirror of https://github.com/hashicorp/consul
feat(v2dns): add grpc DNS support (#20296)
parent
6d9e8fdd05
commit
97ae244d8a
|
@ -1139,16 +1139,19 @@ func (a *Agent) listenAndServeV2DNS() error {
|
|||
}(addr)
|
||||
}
|
||||
|
||||
// TODO(v2-dns): implement a new grpcDNS proxy that takes in the new Router object.
|
||||
//s, _ := dns.NewServer(cfg)
|
||||
//
|
||||
//grpcDNS.NewServer(grpcDNS.Config{
|
||||
// Logger: a.logger.Named("grpc-api.dns"),
|
||||
// DNSServeMux: s.mux,
|
||||
// LocalAddr: grpcDNS.LocalAddr{IP: net.IPv4(127, 0, 0, 1), Port: a.config.GRPCPort},
|
||||
//}).Register(a.externalGRPCServer)
|
||||
//
|
||||
//a.dnsServers = append(a.dnsServers, s)
|
||||
s, err := dns.NewServer(cfg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create grpc dns server: %w", err)
|
||||
}
|
||||
|
||||
// Create a v2 compatible grpc dns server
|
||||
grpcDNS.NewServerV2(grpcDNS.ConfigV2{
|
||||
Logger: a.logger.Named("grpc-api.dns"),
|
||||
DNSRouter: s.Router,
|
||||
TokenFunc: a.getTokenFunc(),
|
||||
}).Register(a.externalGRPCServer)
|
||||
|
||||
a.dnsServers = append(a.dnsServers, s)
|
||||
|
||||
// wait for servers to be up
|
||||
timeout := time.After(time.Second)
|
||||
|
|
|
@ -0,0 +1,68 @@
|
|||
// Code generated by mockery v2.37.1. DO NOT EDIT.
|
||||
|
||||
package dns
|
||||
|
||||
import (
|
||||
config "github.com/hashicorp/consul/agent/config"
|
||||
discovery "github.com/hashicorp/consul/agent/discovery"
|
||||
|
||||
miekgdns "github.com/miekg/dns"
|
||||
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
|
||||
net "net"
|
||||
)
|
||||
|
||||
// MockDNSRouter is an autogenerated mock type for the DNSRouter type
|
||||
type MockDNSRouter struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
// HandleRequest provides a mock function with given fields: req, reqCtx, remoteAddress
|
||||
func (_m *MockDNSRouter) HandleRequest(req *miekgdns.Msg, reqCtx discovery.Context, remoteAddress net.Addr) *miekgdns.Msg {
|
||||
ret := _m.Called(req, reqCtx, remoteAddress)
|
||||
|
||||
var r0 *miekgdns.Msg
|
||||
if rf, ok := ret.Get(0).(func(*miekgdns.Msg, discovery.Context, net.Addr) *miekgdns.Msg); ok {
|
||||
r0 = rf(req, reqCtx, remoteAddress)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(*miekgdns.Msg)
|
||||
}
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// ReloadConfig provides a mock function with given fields: newCfg
|
||||
func (_m *MockDNSRouter) ReloadConfig(newCfg *config.RuntimeConfig) error {
|
||||
ret := _m.Called(newCfg)
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(*config.RuntimeConfig) error); ok {
|
||||
r0 = rf(newCfg)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// ServeDNS provides a mock function with given fields: w, req
|
||||
func (_m *MockDNSRouter) ServeDNS(w miekgdns.ResponseWriter, req *miekgdns.Msg) {
|
||||
_m.Called(w, req)
|
||||
}
|
||||
|
||||
// NewMockDNSRouter creates a new instance of MockDNSRouter. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
// The first argument is typically a *testing.T value.
|
||||
func NewMockDNSRouter(t interface {
|
||||
mock.TestingT
|
||||
Cleanup(func())
|
||||
}) *MockDNSRouter {
|
||||
mock := &MockDNSRouter{}
|
||||
mock.Mock.Test(t)
|
||||
|
||||
t.Cleanup(func() { mock.AssertExpectations(t) })
|
||||
|
||||
return mock
|
||||
}
|
|
@ -104,6 +104,7 @@ type Router struct {
|
|||
}
|
||||
|
||||
var _ = dns.Handler(&Router{})
|
||||
var _ = DNSRouter(&Router{})
|
||||
|
||||
func NewRouter(cfg Config) (*Router, error) {
|
||||
// Make sure domains are FQDN, make them case-insensitive for DNSRequestRouter
|
||||
|
|
|
@ -5,20 +5,31 @@ package dns
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/hashicorp/go-hclog"
|
||||
"github.com/miekg/dns"
|
||||
|
||||
"github.com/hashicorp/consul/acl"
|
||||
"github.com/hashicorp/consul/agent/config"
|
||||
"github.com/hashicorp/consul/agent/discovery"
|
||||
"github.com/hashicorp/consul/logging"
|
||||
)
|
||||
|
||||
// DNSRouter is a mock for Router that can be used for testing.
|
||||
//
|
||||
//go:generate mockery --name DNSRouter --inpackage
|
||||
type DNSRouter interface {
|
||||
HandleRequest(req *dns.Msg, reqCtx discovery.Context, remoteAddress net.Addr) *dns.Msg
|
||||
ServeDNS(w dns.ResponseWriter, req *dns.Msg)
|
||||
ReloadConfig(newCfg *config.RuntimeConfig) error
|
||||
}
|
||||
|
||||
// Server is used to expose service discovery queries using a DNS interface.
|
||||
// It implements the agent.dnsServer interface.
|
||||
type Server struct {
|
||||
*dns.Server // Used for setting up listeners
|
||||
Router *Router // Used to routes and parse DNS requests
|
||||
*dns.Server // Used for setting up listeners
|
||||
Router DNSRouter // Used to routes and parse DNS requests
|
||||
|
||||
logger hclog.Logger
|
||||
}
|
||||
|
|
|
@ -33,7 +33,7 @@ func helloServer(w dns.ResponseWriter, req *dns.Msg) {
|
|||
w.WriteMsg(m)
|
||||
}
|
||||
|
||||
func testClient(t *testing.T, server *Server) pbdns.DNSServiceClient {
|
||||
func testClient(t *testing.T, server testutils.GRPCService) pbdns.DNSServiceClient {
|
||||
t.Helper()
|
||||
|
||||
addr := testutils.RunTestServer(t, server)
|
||||
|
|
|
@ -0,0 +1,89 @@
|
|||
// Copyright (c) HashiCorp, Inc.
|
||||
// SPDX-License-Identifier: BUSL-1.1
|
||||
|
||||
package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/hashicorp/go-hclog"
|
||||
"github.com/miekg/dns"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/peer"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/hashicorp/consul/agent/discovery"
|
||||
agentdns "github.com/hashicorp/consul/agent/dns"
|
||||
"github.com/hashicorp/consul/proto-public/pbdns"
|
||||
)
|
||||
|
||||
type ConfigV2 struct {
|
||||
DNSRouter agentdns.DNSRouter
|
||||
Logger hclog.Logger
|
||||
TokenFunc func() string
|
||||
}
|
||||
|
||||
var _ pbdns.DNSServiceServer = (*ServerV2)(nil)
|
||||
|
||||
// ServerV2 is a gRPC server that implements pbdns.DNSServiceServer.
|
||||
// It is compatible with the refactored V2 DNS server and suitable for
|
||||
// passing additional metadata along the grpc connection to catalog queries.
|
||||
type ServerV2 struct {
|
||||
ConfigV2
|
||||
}
|
||||
|
||||
func NewServerV2(cfg ConfigV2) *ServerV2 {
|
||||
return &ServerV2{cfg}
|
||||
}
|
||||
|
||||
func (s *ServerV2) Register(registrar grpc.ServiceRegistrar) {
|
||||
pbdns.RegisterDNSServiceServer(registrar, s)
|
||||
}
|
||||
|
||||
// Query is a gRPC endpoint that will serve dns requests. It will be consumed primarily by the
|
||||
// consul dataplane to proxy dns requests to consul.
|
||||
func (s *ServerV2) Query(ctx context.Context, req *pbdns.QueryRequest) (*pbdns.QueryResponse, error) {
|
||||
pr, ok := peer.FromContext(ctx)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("error retrieving peer information from context")
|
||||
}
|
||||
|
||||
var remote net.Addr
|
||||
// We do this so that we switch to udp/tcp when handling the request since it will be proxied
|
||||
// through consul through gRPC, and we need to 'fake' the protocol so that the message is trimmed
|
||||
// according to whether it is UDP or TCP.
|
||||
switch req.GetProtocol() {
|
||||
case pbdns.Protocol_PROTOCOL_TCP:
|
||||
remote = pr.Addr
|
||||
case pbdns.Protocol_PROTOCOL_UDP:
|
||||
remoteAddr := pr.Addr.(*net.TCPAddr)
|
||||
remote = &net.UDPAddr{IP: remoteAddr.IP, Port: remoteAddr.Port}
|
||||
default:
|
||||
return nil, status.Error(codes.InvalidArgument, fmt.Sprintf("error protocol type not set: %v", req.GetProtocol()))
|
||||
}
|
||||
|
||||
msg := &dns.Msg{}
|
||||
err := msg.Unpack(req.Msg)
|
||||
if err != nil {
|
||||
s.Logger.Error("error unpacking message", "err", err)
|
||||
return nil, status.Error(codes.Internal, fmt.Sprintf("failure decoding dns request: %s", err.Error()))
|
||||
}
|
||||
|
||||
// TODO (v2-dns): parse token and other context metadata from the grpc request/metadata
|
||||
reqCtx := discovery.Context{
|
||||
Token: s.TokenFunc(),
|
||||
}
|
||||
|
||||
resp := s.DNSRouter.HandleRequest(msg, reqCtx, remote)
|
||||
data, err := resp.Pack()
|
||||
if err != nil {
|
||||
s.Logger.Error("error packing message", "err", err)
|
||||
return nil, status.Error(codes.Internal, fmt.Sprintf("failure encoding dns request: %s", err.Error()))
|
||||
}
|
||||
|
||||
queryResponse := &pbdns.QueryResponse{Msg: data}
|
||||
return queryResponse, nil
|
||||
}
|
|
@ -0,0 +1,130 @@
|
|||
// Copyright (c) HashiCorp, Inc.
|
||||
// SPDX-License-Identifier: BUSL-1.1
|
||||
|
||||
package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/hashicorp/go-hclog"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/mock"
|
||||
|
||||
agentdns "github.com/hashicorp/consul/agent/dns"
|
||||
"github.com/hashicorp/consul/proto-public/pbdns"
|
||||
)
|
||||
|
||||
func basicResponse() *dns.Msg {
|
||||
return &dns.Msg{
|
||||
MsgHdr: dns.MsgHdr{
|
||||
Opcode: dns.OpcodeQuery,
|
||||
Response: true,
|
||||
Authoritative: true,
|
||||
},
|
||||
Compress: true,
|
||||
Question: []dns.Question{
|
||||
{
|
||||
Name: "abc.com.",
|
||||
Qtype: dns.TypeANY,
|
||||
Qclass: dns.ClassINET,
|
||||
},
|
||||
},
|
||||
Extra: []dns.RR{
|
||||
&dns.TXT{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: "abc.com.",
|
||||
Rrtype: dns.TypeTXT,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 0,
|
||||
},
|
||||
Txt: txtRR,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (s *DNSTestSuite) TestProxy_V2Success() {
|
||||
|
||||
testCases := map[string]struct {
|
||||
question string
|
||||
configureRouter func(router *agentdns.MockDNSRouter)
|
||||
clientQuery func(qR *pbdns.QueryRequest)
|
||||
expectedErr error
|
||||
}{
|
||||
|
||||
"happy path udp": {
|
||||
question: "abc.com.",
|
||||
configureRouter: func(router *agentdns.MockDNSRouter) {
|
||||
router.On("HandleRequest", mock.Anything, mock.Anything, mock.Anything).
|
||||
Return(basicResponse(), nil)
|
||||
},
|
||||
clientQuery: func(qR *pbdns.QueryRequest) {
|
||||
qR.Protocol = pbdns.Protocol_PROTOCOL_UDP
|
||||
},
|
||||
},
|
||||
"happy path tcp": {
|
||||
question: "abc.com.",
|
||||
configureRouter: func(router *agentdns.MockDNSRouter) {
|
||||
router.On("HandleRequest", mock.Anything, mock.Anything, mock.Anything).
|
||||
Return(basicResponse(), nil)
|
||||
},
|
||||
clientQuery: func(qR *pbdns.QueryRequest) {
|
||||
qR.Protocol = pbdns.Protocol_PROTOCOL_TCP
|
||||
},
|
||||
},
|
||||
"No protocol set": {
|
||||
question: "abc.com.",
|
||||
clientQuery: func(qR *pbdns.QueryRequest) {},
|
||||
expectedErr: errors.New("error protocol type not set: PROTOCOL_UNSET_UNSPECIFIED"),
|
||||
},
|
||||
"Invalid question": {
|
||||
question: "notvalid",
|
||||
clientQuery: func(qR *pbdns.QueryRequest) {
|
||||
qR.Protocol = pbdns.Protocol_PROTOCOL_UDP
|
||||
},
|
||||
expectedErr: errors.New("failure decoding dns request"),
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
s.Run(name, func() {
|
||||
router := agentdns.NewMockDNSRouter(s.T())
|
||||
|
||||
if tc.configureRouter != nil {
|
||||
tc.configureRouter(router)
|
||||
}
|
||||
|
||||
server := NewServerV2(ConfigV2{
|
||||
Logger: hclog.Default(),
|
||||
DNSRouter: router,
|
||||
TokenFunc: func() string { return "" },
|
||||
})
|
||||
|
||||
client := testClient(s.T(), server)
|
||||
|
||||
req := dns.Msg{}
|
||||
req.SetQuestion(tc.question, dns.TypeA)
|
||||
|
||||
bytes, _ := req.Pack()
|
||||
|
||||
clientReq := &pbdns.QueryRequest{Msg: bytes}
|
||||
tc.clientQuery(clientReq)
|
||||
clientResp, err := client.Query(context.Background(), clientReq)
|
||||
if tc.expectedErr != nil {
|
||||
s.Require().Error(err, "no errror calling gRPC endpoint")
|
||||
s.Require().ErrorContains(err, tc.expectedErr.Error())
|
||||
} else {
|
||||
s.Require().NoError(err, "error calling gRPC endpoint")
|
||||
|
||||
resp := clientResp.GetMsg()
|
||||
var dnsResp dns.Msg
|
||||
|
||||
err = dnsResp.Unpack(resp)
|
||||
s.Require().NoError(err, "error unpacking dns response")
|
||||
rr := dnsResp.Extra[0].(*dns.TXT)
|
||||
s.Require().EqualValues(rr.Txt, txtRR)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue