Browse Source

feat(v2dns): add grpc DNS support (#20296)

pull/20310/head
Dan Stough 10 months ago committed by GitHub
parent
commit
97ae244d8a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 23
      agent/agent.go
  2. 68
      agent/dns/mock_DNSRouter.go
  3. 1
      agent/dns/router.go
  4. 15
      agent/dns/server.go
  5. 2
      agent/grpc-external/services/dns/server_test.go
  6. 89
      agent/grpc-external/services/dns/server_v2.go
  7. 130
      agent/grpc-external/services/dns/server_v2_test.go

23
agent/agent.go

@ -1139,16 +1139,19 @@ func (a *Agent) listenAndServeV2DNS() error {
}(addr)
}
// TODO(v2-dns): implement a new grpcDNS proxy that takes in the new Router object.
//s, _ := dns.NewServer(cfg)
//
//grpcDNS.NewServer(grpcDNS.Config{
// Logger: a.logger.Named("grpc-api.dns"),
// DNSServeMux: s.mux,
// LocalAddr: grpcDNS.LocalAddr{IP: net.IPv4(127, 0, 0, 1), Port: a.config.GRPCPort},
//}).Register(a.externalGRPCServer)
//
//a.dnsServers = append(a.dnsServers, s)
s, err := dns.NewServer(cfg)
if err != nil {
return fmt.Errorf("failed to create grpc dns server: %w", err)
}
// Create a v2 compatible grpc dns server
grpcDNS.NewServerV2(grpcDNS.ConfigV2{
Logger: a.logger.Named("grpc-api.dns"),
DNSRouter: s.Router,
TokenFunc: a.getTokenFunc(),
}).Register(a.externalGRPCServer)
a.dnsServers = append(a.dnsServers, s)
// wait for servers to be up
timeout := time.After(time.Second)

68
agent/dns/mock_DNSRouter.go

@ -0,0 +1,68 @@
// Code generated by mockery v2.37.1. DO NOT EDIT.
package dns
import (
config "github.com/hashicorp/consul/agent/config"
discovery "github.com/hashicorp/consul/agent/discovery"
miekgdns "github.com/miekg/dns"
mock "github.com/stretchr/testify/mock"
net "net"
)
// MockDNSRouter is an autogenerated mock type for the DNSRouter type
type MockDNSRouter struct {
mock.Mock
}
// HandleRequest provides a mock function with given fields: req, reqCtx, remoteAddress
func (_m *MockDNSRouter) HandleRequest(req *miekgdns.Msg, reqCtx discovery.Context, remoteAddress net.Addr) *miekgdns.Msg {
ret := _m.Called(req, reqCtx, remoteAddress)
var r0 *miekgdns.Msg
if rf, ok := ret.Get(0).(func(*miekgdns.Msg, discovery.Context, net.Addr) *miekgdns.Msg); ok {
r0 = rf(req, reqCtx, remoteAddress)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*miekgdns.Msg)
}
}
return r0
}
// ReloadConfig provides a mock function with given fields: newCfg
func (_m *MockDNSRouter) ReloadConfig(newCfg *config.RuntimeConfig) error {
ret := _m.Called(newCfg)
var r0 error
if rf, ok := ret.Get(0).(func(*config.RuntimeConfig) error); ok {
r0 = rf(newCfg)
} else {
r0 = ret.Error(0)
}
return r0
}
// ServeDNS provides a mock function with given fields: w, req
func (_m *MockDNSRouter) ServeDNS(w miekgdns.ResponseWriter, req *miekgdns.Msg) {
_m.Called(w, req)
}
// NewMockDNSRouter creates a new instance of MockDNSRouter. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewMockDNSRouter(t interface {
mock.TestingT
Cleanup(func())
}) *MockDNSRouter {
mock := &MockDNSRouter{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}

1
agent/dns/router.go

@ -104,6 +104,7 @@ type Router struct {
}
var _ = dns.Handler(&Router{})
var _ = DNSRouter(&Router{})
func NewRouter(cfg Config) (*Router, error) {
// Make sure domains are FQDN, make them case-insensitive for DNSRequestRouter

15
agent/dns/server.go

@ -5,20 +5,31 @@ package dns
import (
"fmt"
"net"
"github.com/hashicorp/go-hclog"
"github.com/miekg/dns"
"github.com/hashicorp/consul/acl"
"github.com/hashicorp/consul/agent/config"
"github.com/hashicorp/consul/agent/discovery"
"github.com/hashicorp/consul/logging"
)
// DNSRouter is a mock for Router that can be used for testing.
//
//go:generate mockery --name DNSRouter --inpackage
type DNSRouter interface {
HandleRequest(req *dns.Msg, reqCtx discovery.Context, remoteAddress net.Addr) *dns.Msg
ServeDNS(w dns.ResponseWriter, req *dns.Msg)
ReloadConfig(newCfg *config.RuntimeConfig) error
}
// Server is used to expose service discovery queries using a DNS interface.
// It implements the agent.dnsServer interface.
type Server struct {
*dns.Server // Used for setting up listeners
Router *Router // Used to routes and parse DNS requests
*dns.Server // Used for setting up listeners
Router DNSRouter // Used to routes and parse DNS requests
logger hclog.Logger
}

2
agent/grpc-external/services/dns/server_test.go vendored

@ -33,7 +33,7 @@ func helloServer(w dns.ResponseWriter, req *dns.Msg) {
w.WriteMsg(m)
}
func testClient(t *testing.T, server *Server) pbdns.DNSServiceClient {
func testClient(t *testing.T, server testutils.GRPCService) pbdns.DNSServiceClient {
t.Helper()
addr := testutils.RunTestServer(t, server)

89
agent/grpc-external/services/dns/server_v2.go vendored

@ -0,0 +1,89 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package dns
import (
"context"
"fmt"
"net"
"github.com/hashicorp/go-hclog"
"github.com/miekg/dns"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/status"
"github.com/hashicorp/consul/agent/discovery"
agentdns "github.com/hashicorp/consul/agent/dns"
"github.com/hashicorp/consul/proto-public/pbdns"
)
type ConfigV2 struct {
DNSRouter agentdns.DNSRouter
Logger hclog.Logger
TokenFunc func() string
}
var _ pbdns.DNSServiceServer = (*ServerV2)(nil)
// ServerV2 is a gRPC server that implements pbdns.DNSServiceServer.
// It is compatible with the refactored V2 DNS server and suitable for
// passing additional metadata along the grpc connection to catalog queries.
type ServerV2 struct {
ConfigV2
}
func NewServerV2(cfg ConfigV2) *ServerV2 {
return &ServerV2{cfg}
}
func (s *ServerV2) Register(registrar grpc.ServiceRegistrar) {
pbdns.RegisterDNSServiceServer(registrar, s)
}
// Query is a gRPC endpoint that will serve dns requests. It will be consumed primarily by the
// consul dataplane to proxy dns requests to consul.
func (s *ServerV2) Query(ctx context.Context, req *pbdns.QueryRequest) (*pbdns.QueryResponse, error) {
pr, ok := peer.FromContext(ctx)
if !ok {
return nil, fmt.Errorf("error retrieving peer information from context")
}
var remote net.Addr
// We do this so that we switch to udp/tcp when handling the request since it will be proxied
// through consul through gRPC, and we need to 'fake' the protocol so that the message is trimmed
// according to whether it is UDP or TCP.
switch req.GetProtocol() {
case pbdns.Protocol_PROTOCOL_TCP:
remote = pr.Addr
case pbdns.Protocol_PROTOCOL_UDP:
remoteAddr := pr.Addr.(*net.TCPAddr)
remote = &net.UDPAddr{IP: remoteAddr.IP, Port: remoteAddr.Port}
default:
return nil, status.Error(codes.InvalidArgument, fmt.Sprintf("error protocol type not set: %v", req.GetProtocol()))
}
msg := &dns.Msg{}
err := msg.Unpack(req.Msg)
if err != nil {
s.Logger.Error("error unpacking message", "err", err)
return nil, status.Error(codes.Internal, fmt.Sprintf("failure decoding dns request: %s", err.Error()))
}
// TODO (v2-dns): parse token and other context metadata from the grpc request/metadata
reqCtx := discovery.Context{
Token: s.TokenFunc(),
}
resp := s.DNSRouter.HandleRequest(msg, reqCtx, remote)
data, err := resp.Pack()
if err != nil {
s.Logger.Error("error packing message", "err", err)
return nil, status.Error(codes.Internal, fmt.Sprintf("failure encoding dns request: %s", err.Error()))
}
queryResponse := &pbdns.QueryResponse{Msg: data}
return queryResponse, nil
}

130
agent/grpc-external/services/dns/server_v2_test.go vendored

@ -0,0 +1,130 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package dns
import (
"context"
"errors"
"github.com/hashicorp/go-hclog"
"github.com/miekg/dns"
"github.com/stretchr/testify/mock"
agentdns "github.com/hashicorp/consul/agent/dns"
"github.com/hashicorp/consul/proto-public/pbdns"
)
func basicResponse() *dns.Msg {
return &dns.Msg{
MsgHdr: dns.MsgHdr{
Opcode: dns.OpcodeQuery,
Response: true,
Authoritative: true,
},
Compress: true,
Question: []dns.Question{
{
Name: "abc.com.",
Qtype: dns.TypeANY,
Qclass: dns.ClassINET,
},
},
Extra: []dns.RR{
&dns.TXT{
Hdr: dns.RR_Header{
Name: "abc.com.",
Rrtype: dns.TypeTXT,
Class: dns.ClassINET,
Ttl: 0,
},
Txt: txtRR,
},
},
}
}
func (s *DNSTestSuite) TestProxy_V2Success() {
testCases := map[string]struct {
question string
configureRouter func(router *agentdns.MockDNSRouter)
clientQuery func(qR *pbdns.QueryRequest)
expectedErr error
}{
"happy path udp": {
question: "abc.com.",
configureRouter: func(router *agentdns.MockDNSRouter) {
router.On("HandleRequest", mock.Anything, mock.Anything, mock.Anything).
Return(basicResponse(), nil)
},
clientQuery: func(qR *pbdns.QueryRequest) {
qR.Protocol = pbdns.Protocol_PROTOCOL_UDP
},
},
"happy path tcp": {
question: "abc.com.",
configureRouter: func(router *agentdns.MockDNSRouter) {
router.On("HandleRequest", mock.Anything, mock.Anything, mock.Anything).
Return(basicResponse(), nil)
},
clientQuery: func(qR *pbdns.QueryRequest) {
qR.Protocol = pbdns.Protocol_PROTOCOL_TCP
},
},
"No protocol set": {
question: "abc.com.",
clientQuery: func(qR *pbdns.QueryRequest) {},
expectedErr: errors.New("error protocol type not set: PROTOCOL_UNSET_UNSPECIFIED"),
},
"Invalid question": {
question: "notvalid",
clientQuery: func(qR *pbdns.QueryRequest) {
qR.Protocol = pbdns.Protocol_PROTOCOL_UDP
},
expectedErr: errors.New("failure decoding dns request"),
},
}
for name, tc := range testCases {
s.Run(name, func() {
router := agentdns.NewMockDNSRouter(s.T())
if tc.configureRouter != nil {
tc.configureRouter(router)
}
server := NewServerV2(ConfigV2{
Logger: hclog.Default(),
DNSRouter: router,
TokenFunc: func() string { return "" },
})
client := testClient(s.T(), server)
req := dns.Msg{}
req.SetQuestion(tc.question, dns.TypeA)
bytes, _ := req.Pack()
clientReq := &pbdns.QueryRequest{Msg: bytes}
tc.clientQuery(clientReq)
clientResp, err := client.Query(context.Background(), clientReq)
if tc.expectedErr != nil {
s.Require().Error(err, "no errror calling gRPC endpoint")
s.Require().ErrorContains(err, tc.expectedErr.Error())
} else {
s.Require().NoError(err, "error calling gRPC endpoint")
resp := clientResp.GetMsg()
var dnsResp dns.Msg
err = dnsResp.Unpack(resp)
s.Require().NoError(err, "error unpacking dns response")
rr := dnsResp.Extra[0].(*dns.TXT)
s.Require().EqualValues(rr.Txt, txtRR)
}
})
}
}
Loading…
Cancel
Save