mirror of https://github.com/hashicorp/consul
Dan Stough
10 months ago
committed by
GitHub
7 changed files with 315 additions and 13 deletions
@ -0,0 +1,68 @@
|
||||
// Code generated by mockery v2.37.1. DO NOT EDIT.
|
||||
|
||||
package dns |
||||
|
||||
import ( |
||||
config "github.com/hashicorp/consul/agent/config" |
||||
discovery "github.com/hashicorp/consul/agent/discovery" |
||||
|
||||
miekgdns "github.com/miekg/dns" |
||||
|
||||
mock "github.com/stretchr/testify/mock" |
||||
|
||||
net "net" |
||||
) |
||||
|
||||
// MockDNSRouter is an autogenerated mock type for the DNSRouter type
|
||||
type MockDNSRouter struct { |
||||
mock.Mock |
||||
} |
||||
|
||||
// HandleRequest provides a mock function with given fields: req, reqCtx, remoteAddress
|
||||
func (_m *MockDNSRouter) HandleRequest(req *miekgdns.Msg, reqCtx discovery.Context, remoteAddress net.Addr) *miekgdns.Msg { |
||||
ret := _m.Called(req, reqCtx, remoteAddress) |
||||
|
||||
var r0 *miekgdns.Msg |
||||
if rf, ok := ret.Get(0).(func(*miekgdns.Msg, discovery.Context, net.Addr) *miekgdns.Msg); ok { |
||||
r0 = rf(req, reqCtx, remoteAddress) |
||||
} else { |
||||
if ret.Get(0) != nil { |
||||
r0 = ret.Get(0).(*miekgdns.Msg) |
||||
} |
||||
} |
||||
|
||||
return r0 |
||||
} |
||||
|
||||
// ReloadConfig provides a mock function with given fields: newCfg
|
||||
func (_m *MockDNSRouter) ReloadConfig(newCfg *config.RuntimeConfig) error { |
||||
ret := _m.Called(newCfg) |
||||
|
||||
var r0 error |
||||
if rf, ok := ret.Get(0).(func(*config.RuntimeConfig) error); ok { |
||||
r0 = rf(newCfg) |
||||
} else { |
||||
r0 = ret.Error(0) |
||||
} |
||||
|
||||
return r0 |
||||
} |
||||
|
||||
// ServeDNS provides a mock function with given fields: w, req
|
||||
func (_m *MockDNSRouter) ServeDNS(w miekgdns.ResponseWriter, req *miekgdns.Msg) { |
||||
_m.Called(w, req) |
||||
} |
||||
|
||||
// NewMockDNSRouter creates a new instance of MockDNSRouter. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
// The first argument is typically a *testing.T value.
|
||||
func NewMockDNSRouter(t interface { |
||||
mock.TestingT |
||||
Cleanup(func()) |
||||
}) *MockDNSRouter { |
||||
mock := &MockDNSRouter{} |
||||
mock.Mock.Test(t) |
||||
|
||||
t.Cleanup(func() { mock.AssertExpectations(t) }) |
||||
|
||||
return mock |
||||
} |
@ -0,0 +1,89 @@
|
||||
// Copyright (c) HashiCorp, Inc.
|
||||
// SPDX-License-Identifier: BUSL-1.1
|
||||
|
||||
package dns |
||||
|
||||
import ( |
||||
"context" |
||||
"fmt" |
||||
"net" |
||||
|
||||
"github.com/hashicorp/go-hclog" |
||||
"github.com/miekg/dns" |
||||
"google.golang.org/grpc" |
||||
"google.golang.org/grpc/codes" |
||||
"google.golang.org/grpc/peer" |
||||
"google.golang.org/grpc/status" |
||||
|
||||
"github.com/hashicorp/consul/agent/discovery" |
||||
agentdns "github.com/hashicorp/consul/agent/dns" |
||||
"github.com/hashicorp/consul/proto-public/pbdns" |
||||
) |
||||
|
||||
type ConfigV2 struct { |
||||
DNSRouter agentdns.DNSRouter |
||||
Logger hclog.Logger |
||||
TokenFunc func() string |
||||
} |
||||
|
||||
var _ pbdns.DNSServiceServer = (*ServerV2)(nil) |
||||
|
||||
// ServerV2 is a gRPC server that implements pbdns.DNSServiceServer.
|
||||
// It is compatible with the refactored V2 DNS server and suitable for
|
||||
// passing additional metadata along the grpc connection to catalog queries.
|
||||
type ServerV2 struct { |
||||
ConfigV2 |
||||
} |
||||
|
||||
func NewServerV2(cfg ConfigV2) *ServerV2 { |
||||
return &ServerV2{cfg} |
||||
} |
||||
|
||||
func (s *ServerV2) Register(registrar grpc.ServiceRegistrar) { |
||||
pbdns.RegisterDNSServiceServer(registrar, s) |
||||
} |
||||
|
||||
// Query is a gRPC endpoint that will serve dns requests. It will be consumed primarily by the
|
||||
// consul dataplane to proxy dns requests to consul.
|
||||
func (s *ServerV2) Query(ctx context.Context, req *pbdns.QueryRequest) (*pbdns.QueryResponse, error) { |
||||
pr, ok := peer.FromContext(ctx) |
||||
if !ok { |
||||
return nil, fmt.Errorf("error retrieving peer information from context") |
||||
} |
||||
|
||||
var remote net.Addr |
||||
// We do this so that we switch to udp/tcp when handling the request since it will be proxied
|
||||
// through consul through gRPC, and we need to 'fake' the protocol so that the message is trimmed
|
||||
// according to whether it is UDP or TCP.
|
||||
switch req.GetProtocol() { |
||||
case pbdns.Protocol_PROTOCOL_TCP: |
||||
remote = pr.Addr |
||||
case pbdns.Protocol_PROTOCOL_UDP: |
||||
remoteAddr := pr.Addr.(*net.TCPAddr) |
||||
remote = &net.UDPAddr{IP: remoteAddr.IP, Port: remoteAddr.Port} |
||||
default: |
||||
return nil, status.Error(codes.InvalidArgument, fmt.Sprintf("error protocol type not set: %v", req.GetProtocol())) |
||||
} |
||||
|
||||
msg := &dns.Msg{} |
||||
err := msg.Unpack(req.Msg) |
||||
if err != nil { |
||||
s.Logger.Error("error unpacking message", "err", err) |
||||
return nil, status.Error(codes.Internal, fmt.Sprintf("failure decoding dns request: %s", err.Error())) |
||||
} |
||||
|
||||
// TODO (v2-dns): parse token and other context metadata from the grpc request/metadata
|
||||
reqCtx := discovery.Context{ |
||||
Token: s.TokenFunc(), |
||||
} |
||||
|
||||
resp := s.DNSRouter.HandleRequest(msg, reqCtx, remote) |
||||
data, err := resp.Pack() |
||||
if err != nil { |
||||
s.Logger.Error("error packing message", "err", err) |
||||
return nil, status.Error(codes.Internal, fmt.Sprintf("failure encoding dns request: %s", err.Error())) |
||||
} |
||||
|
||||
queryResponse := &pbdns.QueryResponse{Msg: data} |
||||
return queryResponse, nil |
||||
} |
@ -0,0 +1,130 @@
|
||||
// Copyright (c) HashiCorp, Inc.
|
||||
// SPDX-License-Identifier: BUSL-1.1
|
||||
|
||||
package dns |
||||
|
||||
import ( |
||||
"context" |
||||
"errors" |
||||
|
||||
"github.com/hashicorp/go-hclog" |
||||
"github.com/miekg/dns" |
||||
"github.com/stretchr/testify/mock" |
||||
|
||||
agentdns "github.com/hashicorp/consul/agent/dns" |
||||
"github.com/hashicorp/consul/proto-public/pbdns" |
||||
) |
||||
|
||||
func basicResponse() *dns.Msg { |
||||
return &dns.Msg{ |
||||
MsgHdr: dns.MsgHdr{ |
||||
Opcode: dns.OpcodeQuery, |
||||
Response: true, |
||||
Authoritative: true, |
||||
}, |
||||
Compress: true, |
||||
Question: []dns.Question{ |
||||
{ |
||||
Name: "abc.com.", |
||||
Qtype: dns.TypeANY, |
||||
Qclass: dns.ClassINET, |
||||
}, |
||||
}, |
||||
Extra: []dns.RR{ |
||||
&dns.TXT{ |
||||
Hdr: dns.RR_Header{ |
||||
Name: "abc.com.", |
||||
Rrtype: dns.TypeTXT, |
||||
Class: dns.ClassINET, |
||||
Ttl: 0, |
||||
}, |
||||
Txt: txtRR, |
||||
}, |
||||
}, |
||||
} |
||||
} |
||||
|
||||
func (s *DNSTestSuite) TestProxy_V2Success() { |
||||
|
||||
testCases := map[string]struct { |
||||
question string |
||||
configureRouter func(router *agentdns.MockDNSRouter) |
||||
clientQuery func(qR *pbdns.QueryRequest) |
||||
expectedErr error |
||||
}{ |
||||
|
||||
"happy path udp": { |
||||
question: "abc.com.", |
||||
configureRouter: func(router *agentdns.MockDNSRouter) { |
||||
router.On("HandleRequest", mock.Anything, mock.Anything, mock.Anything). |
||||
Return(basicResponse(), nil) |
||||
}, |
||||
clientQuery: func(qR *pbdns.QueryRequest) { |
||||
qR.Protocol = pbdns.Protocol_PROTOCOL_UDP |
||||
}, |
||||
}, |
||||
"happy path tcp": { |
||||
question: "abc.com.", |
||||
configureRouter: func(router *agentdns.MockDNSRouter) { |
||||
router.On("HandleRequest", mock.Anything, mock.Anything, mock.Anything). |
||||
Return(basicResponse(), nil) |
||||
}, |
||||
clientQuery: func(qR *pbdns.QueryRequest) { |
||||
qR.Protocol = pbdns.Protocol_PROTOCOL_TCP |
||||
}, |
||||
}, |
||||
"No protocol set": { |
||||
question: "abc.com.", |
||||
clientQuery: func(qR *pbdns.QueryRequest) {}, |
||||
expectedErr: errors.New("error protocol type not set: PROTOCOL_UNSET_UNSPECIFIED"), |
||||
}, |
||||
"Invalid question": { |
||||
question: "notvalid", |
||||
clientQuery: func(qR *pbdns.QueryRequest) { |
||||
qR.Protocol = pbdns.Protocol_PROTOCOL_UDP |
||||
}, |
||||
expectedErr: errors.New("failure decoding dns request"), |
||||
}, |
||||
} |
||||
|
||||
for name, tc := range testCases { |
||||
s.Run(name, func() { |
||||
router := agentdns.NewMockDNSRouter(s.T()) |
||||
|
||||
if tc.configureRouter != nil { |
||||
tc.configureRouter(router) |
||||
} |
||||
|
||||
server := NewServerV2(ConfigV2{ |
||||
Logger: hclog.Default(), |
||||
DNSRouter: router, |
||||
TokenFunc: func() string { return "" }, |
||||
}) |
||||
|
||||
client := testClient(s.T(), server) |
||||
|
||||
req := dns.Msg{} |
||||
req.SetQuestion(tc.question, dns.TypeA) |
||||
|
||||
bytes, _ := req.Pack() |
||||
|
||||
clientReq := &pbdns.QueryRequest{Msg: bytes} |
||||
tc.clientQuery(clientReq) |
||||
clientResp, err := client.Query(context.Background(), clientReq) |
||||
if tc.expectedErr != nil { |
||||
s.Require().Error(err, "no errror calling gRPC endpoint") |
||||
s.Require().ErrorContains(err, tc.expectedErr.Error()) |
||||
} else { |
||||
s.Require().NoError(err, "error calling gRPC endpoint") |
||||
|
||||
resp := clientResp.GetMsg() |
||||
var dnsResp dns.Msg |
||||
|
||||
err = dnsResp.Unpack(resp) |
||||
s.Require().NoError(err, "error unpacking dns response") |
||||
rr := dnsResp.Extra[0].(*dns.TXT) |
||||
s.Require().EqualValues(rr.Txt, txtRR) |
||||
} |
||||
}) |
||||
} |
||||
} |
Loading…
Reference in new issue