mirror of https://github.com/hashicorp/consul
165 lines
4.3 KiB
Go
165 lines
4.3 KiB
Go
// Copyright (c) HashiCorp, Inc.
|
|
// SPDX-License-Identifier: BUSL-1.1
|
|
|
|
package dns
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
|
|
"github.com/hashicorp/go-hclog"
|
|
"github.com/miekg/dns"
|
|
"github.com/stretchr/testify/mock"
|
|
"github.com/stretchr/testify/require"
|
|
"google.golang.org/grpc/metadata"
|
|
|
|
agentdns "github.com/hashicorp/consul/agent/dns"
|
|
"github.com/hashicorp/consul/proto-public/pbdns"
|
|
)
|
|
|
|
func basicResponse() *dns.Msg {
|
|
return &dns.Msg{
|
|
MsgHdr: dns.MsgHdr{
|
|
Opcode: dns.OpcodeQuery,
|
|
Response: true,
|
|
Authoritative: true,
|
|
},
|
|
Compress: true,
|
|
Question: []dns.Question{
|
|
{
|
|
Name: "abc.com.",
|
|
Qtype: dns.TypeANY,
|
|
Qclass: dns.ClassINET,
|
|
},
|
|
},
|
|
Extra: []dns.RR{
|
|
&dns.TXT{
|
|
Hdr: dns.RR_Header{
|
|
Name: "abc.com.",
|
|
Rrtype: dns.TypeTXT,
|
|
Class: dns.ClassINET,
|
|
Ttl: 0,
|
|
},
|
|
Txt: txtRR,
|
|
},
|
|
},
|
|
}
|
|
}
|
|
|
|
func (s *DNSTestSuite) TestProxy_V2Success() {
|
|
|
|
testCases := map[string]struct {
|
|
question string
|
|
configureRouter func(router *agentdns.MockDNSRouter)
|
|
clientQuery func(qR *pbdns.QueryRequest)
|
|
metadata map[string]string
|
|
expectedErr error
|
|
}{
|
|
|
|
"happy path udp": {
|
|
question: "abc.com.",
|
|
configureRouter: func(router *agentdns.MockDNSRouter) {
|
|
router.On("HandleRequest", mock.Anything, mock.Anything, mock.Anything).
|
|
Return(basicResponse(), nil)
|
|
},
|
|
clientQuery: func(qR *pbdns.QueryRequest) {
|
|
qR.Protocol = pbdns.Protocol_PROTOCOL_UDP
|
|
},
|
|
},
|
|
"happy path tcp": {
|
|
question: "abc.com.",
|
|
configureRouter: func(router *agentdns.MockDNSRouter) {
|
|
router.On("HandleRequest", mock.Anything, mock.Anything, mock.Anything).
|
|
Return(basicResponse(), nil)
|
|
},
|
|
clientQuery: func(qR *pbdns.QueryRequest) {
|
|
qR.Protocol = pbdns.Protocol_PROTOCOL_TCP
|
|
},
|
|
},
|
|
"happy path with context variables set": {
|
|
question: "abc.com.",
|
|
configureRouter: func(router *agentdns.MockDNSRouter) {
|
|
router.On("HandleRequest", mock.Anything, mock.Anything, mock.Anything).
|
|
Run(func(args mock.Arguments) {
|
|
ctx, ok := args.Get(1).(agentdns.Context)
|
|
require.True(s.T(), ok, "error casting to agentdns.Context")
|
|
require.Equal(s.T(), "test-token", ctx.Token, "token not set in context")
|
|
require.Equal(s.T(), "test-namespace", ctx.DefaultNamespace, "namespace not set in context")
|
|
require.Equal(s.T(), "test-partition", ctx.DefaultPartition, "partition not set in context")
|
|
}).
|
|
Return(basicResponse(), nil)
|
|
},
|
|
clientQuery: func(qR *pbdns.QueryRequest) {
|
|
qR.Protocol = pbdns.Protocol_PROTOCOL_UDP
|
|
},
|
|
metadata: map[string]string{
|
|
"x-consul-token": "test-token",
|
|
"x-consul-namespace": "test-namespace",
|
|
"x-consul-partition": "test-partition",
|
|
},
|
|
},
|
|
"No protocol set": {
|
|
question: "abc.com.",
|
|
clientQuery: func(qR *pbdns.QueryRequest) {},
|
|
expectedErr: errors.New("error protocol type not set: PROTOCOL_UNSET_UNSPECIFIED"),
|
|
},
|
|
"Invalid question": {
|
|
question: "notvalid",
|
|
clientQuery: func(qR *pbdns.QueryRequest) {
|
|
qR.Protocol = pbdns.Protocol_PROTOCOL_UDP
|
|
},
|
|
expectedErr: errors.New("failure decoding dns request"),
|
|
},
|
|
}
|
|
|
|
for name, tc := range testCases {
|
|
s.Run(name, func() {
|
|
router := agentdns.NewMockDNSRouter(s.T())
|
|
|
|
if tc.configureRouter != nil {
|
|
tc.configureRouter(router)
|
|
}
|
|
|
|
server := NewServerV2(ConfigV2{
|
|
Logger: hclog.Default(),
|
|
DNSRouter: router,
|
|
TokenFunc: func() string { return "" },
|
|
})
|
|
|
|
client := testClient(s.T(), server)
|
|
|
|
req := dns.Msg{}
|
|
req.SetQuestion(tc.question, dns.TypeA)
|
|
|
|
bytes, _ := req.Pack()
|
|
|
|
ctx := context.Background()
|
|
if len(tc.metadata) > 0 {
|
|
md := metadata.MD{}
|
|
for k, v := range tc.metadata {
|
|
md.Set(k, v)
|
|
}
|
|
ctx = metadata.NewOutgoingContext(ctx, md)
|
|
}
|
|
|
|
clientReq := &pbdns.QueryRequest{Msg: bytes}
|
|
tc.clientQuery(clientReq)
|
|
clientResp, err := client.Query(ctx, clientReq)
|
|
if tc.expectedErr != nil {
|
|
s.Require().Error(err, "no errror calling gRPC endpoint")
|
|
s.Require().ErrorContains(err, tc.expectedErr.Error())
|
|
} else {
|
|
s.Require().NoError(err, "error calling gRPC endpoint")
|
|
|
|
resp := clientResp.GetMsg()
|
|
var dnsResp dns.Msg
|
|
|
|
err = dnsResp.Unpack(resp)
|
|
s.Require().NoError(err, "error unpacking dns response")
|
|
rr := dnsResp.Extra[0].(*dns.TXT)
|
|
s.Require().EqualValues(rr.Txt, txtRR)
|
|
}
|
|
})
|
|
}
|
|
}
|