Fix the issue in unqualified name where DNS client such as ping or iwr validate name in response and original question. Switch to use miekg's DNS library

pull/6/head
Jiangtian Li 2017-04-12 10:59:03 -07:00
parent 9a0f5ccb33
commit 1eda859bf9
4 changed files with 105 additions and 371 deletions

View File

@ -38,7 +38,6 @@ go_test(
name = "go_default_test", name = "go_default_test",
srcs = [ srcs = [
"proxier_test.go", "proxier_test.go",
"proxysocket_test.go",
"roundrobin_test.go", "roundrobin_test.go",
], ],
library = ":go_default_library", library = ":go_default_library",

View File

@ -262,7 +262,7 @@ func (proxier *Proxier) addServicePortPortal(servicePortPortalName ServicePortPo
socket: sock, socket: sock,
timeout: timeout, timeout: timeout,
activeClients: newClientCache(), activeClients: newClientCache(),
dnsClients: newDnsClientCache(), dnsClients: newDNSClientCache(),
sessionAffinityType: api.ServiceAffinityNone, // default sessionAffinityType: api.ServiceAffinityNone, // default
} }
proxier.setServiceInfo(servicePortPortalName, si) proxier.setServiceInfo(servicePortPortalName, si)

View File

@ -17,7 +17,6 @@ limitations under the License.
package winuserspace package winuserspace
import ( import (
"encoding/binary"
"fmt" "fmt"
"io" "io"
"net" "net"
@ -28,6 +27,7 @@ import (
"time" "time"
"github.com/golang/glog" "github.com/golang/glog"
"github.com/miekg/dns"
"k8s.io/apimachinery/pkg/types" "k8s.io/apimachinery/pkg/types"
"k8s.io/apimachinery/pkg/util/runtime" "k8s.io/apimachinery/pkg/util/runtime"
"k8s.io/kubernetes/pkg/api" "k8s.io/kubernetes/pkg/api"
@ -237,190 +237,6 @@ func newClientCache() *clientCache {
return &clientCache{clients: map[string]net.Conn{}} return &clientCache{clients: map[string]net.Conn{}}
} }
// TODO: use Go net dnsmsg library to walk DNS message format
// DNS packet header
type dnsHeader struct {
id uint16
bits uint16
qdCount uint16
anCount uint16
nsCount uint16
arCount uint16
}
// DNS domain name
type dnsDomainName struct {
name string
}
// DNS packet question section
type dnsQuestion struct {
qName dnsDomainName
qType uint16
qClass uint16
}
// DNS message, only interested in question now
type dnsMsg struct {
header dnsHeader
question []dnsQuestion
}
type dnsStruct interface {
walk(f func(field interface{}) (ok bool)) (ok bool)
}
func (header *dnsHeader) walk(f func(field interface{}) bool) bool {
return f(&header.id) &&
f(&header.bits) &&
f(&header.qdCount) &&
f(&header.anCount) &&
f(&header.nsCount) &&
f(&header.arCount)
}
func (question *dnsQuestion) walk(f func(field interface{}) bool) bool {
return f(&question.qName) &&
f(&question.qType) &&
f(&question.qClass)
}
func packDomainName(name string, buffer []byte, index int) (newIndex int, ok bool) {
if name == "" {
buffer[index] = 0
index++
return index, true
}
// one more dot plus trailing 0
if index+len(name)+2 > len(buffer) {
return len(buffer), false
}
domains := strings.Split(name, ".")
for _, domain := range domains {
domainLen := len(domain)
if domainLen == 0 {
return len(buffer), false
}
buffer[index] = byte(domainLen)
index++
copy(buffer[index:index+domainLen], domain)
index += domainLen
}
buffer[index] = 0
index++
return index, true
}
func unpackDomainName(buffer []byte, index int) (name string, newIndex int, ok bool) {
name = ""
for index < len(buffer) {
cnt := int(buffer[index])
index++
if cnt == 0 {
break
}
if index+cnt > len(buffer) {
return "", len(buffer), false
}
if name != "" {
name += "."
}
name += string(buffer[index : index+cnt])
index += cnt
}
if index >= len(buffer) {
return "", len(buffer), false
}
return name, index, true
}
func packStruct(any dnsStruct, buffer []byte, index int) (newIndex int, ok bool) {
ok = any.walk(func(field interface{}) bool {
switch value := field.(type) {
case *uint16:
if index+2 > len(buffer) {
return false
}
binary.BigEndian.PutUint16(buffer[index:index+2], *value)
index += 2
return true
case *dnsDomainName:
index, ok = packDomainName((*value).name, buffer, index)
return ok
default:
return false
}
})
if !ok {
return len(buffer), false
}
return index, true
}
func unpackStruct(any dnsStruct, buffer []byte, index int) (newIndex int, ok bool) {
ok = any.walk(func(field interface{}) bool {
switch value := field.(type) {
case *uint16:
if index+2 > len(buffer) {
return false
}
*value = binary.BigEndian.Uint16(buffer[index : index+2])
index += 2
return true
case *dnsDomainName:
(*value).name, index, ok = unpackDomainName(buffer, index)
return ok
default:
return false
}
})
if !ok {
return len(buffer), false
}
return index, true
}
// Pack the message structure into buffer
func (msg *dnsMsg) packDnsMsg(buffer []byte) (length int, ok bool) {
index := 0
if index, ok = packStruct(&msg.header, buffer, index); !ok {
return len(buffer), false
}
for i := 0; i < len(msg.question); i++ {
if index, ok = packStruct(&msg.question[i], buffer, index); !ok {
return len(buffer), false
}
}
return index, true
}
// Unpack the buffer into the message structure
func (msg *dnsMsg) unpackDnsMsg(buffer []byte) (ok bool) {
index := 0
if index, ok = unpackStruct(&msg.header, buffer, index); !ok {
return false
}
msg.question = make([]dnsQuestion, msg.header.qdCount)
for i := 0; i < len(msg.question); i++ {
if index, ok = unpackStruct(&msg.question[i], buffer, index); !ok {
return false
}
}
return true
}
// DNS query client classified by address and QTYPE // DNS query client classified by address and QTYPE
type dnsClientQuery struct { type dnsClientQuery struct {
clientAddress string clientAddress string
@ -436,44 +252,85 @@ type dnsClientCache struct {
type dnsQueryState struct { type dnsQueryState struct {
searchIndex int32 searchIndex int32
msg *dnsMsg msg *dns.Msg
} }
func newDnsClientCache() *dnsClientCache { func newDNSClientCache() *dnsClientCache {
return &dnsClientCache{clients: map[dnsClientQuery]*dnsQueryState{}} return &dnsClientCache{clients: map[dnsClientQuery]*dnsQueryState{}}
} }
func packetRequiresDnsSuffix(dnsType, dnsClass uint16) bool { func packetRequiresDNSSuffix(dnsType, dnsClass uint16) bool {
return (dnsType == dnsTypeA || dnsType == dnsTypeAAAA) && dnsClass == dnsClassInternet return (dnsType == dnsTypeA || dnsType == dnsTypeAAAA) && dnsClass == dnsClassInternet
} }
func isDnsService(portName string) bool { func isDNSService(portName string) bool {
return portName == dnsPortName return portName == dnsPortName
} }
func appendDnsSuffix(msg *dnsMsg, buffer []byte, length int, dnsSuffix string) int { func appendDNSSuffix(msg *dns.Msg, buffer []byte, length int, dnsSuffix string) int {
if msg == nil || len(msg.question) == 0 { if msg == nil || len(msg.Question) == 0 {
glog.Warning("DNS message parameter is invalid.") glog.Warning("DNS message parameter is invalid.")
return length return length
} }
// Save the original name since it will be reused for next iteration // Save the original name since it will be reused for next iteration
origName := msg.question[0].qName.name origName := msg.Question[0].Name
if dnsSuffix != "" { if dnsSuffix != "" {
msg.question[0].qName.name += "." + dnsSuffix msg.Question[0].Name += dnsSuffix + "."
} }
len, ok := msg.packDnsMsg(buffer) mbuf, err := msg.PackBuffer(buffer)
msg.question[0].qName.name = origName msg.Question[0].Name = origName
if !ok { if err != nil {
glog.Warning("Unable to pack DNS packet.") glog.Warning("Unable to pack DNS packet. Error is: %v", err)
return length return length
} }
return len if &buffer[0] != &mbuf[0] {
glog.Warning("Buffer is too small in packing DNS packet.")
return length
}
return len(mbuf)
} }
func processUnpackedDnsQueryPacket(dnsClients *dnsClientCache, msg *dnsMsg, host string, dnsQType uint16, buffer []byte, length int, dnsSearch []string) int { func recoverDNSQuestion(origName string, msg *dns.Msg, buffer []byte, length int) int {
if msg == nil || len(msg.Question) == 0 {
glog.Warning("DNS message parameter is invalid.")
return length
}
if origName == msg.Question[0].Name {
return length
}
msg.Question[0].Name = origName
if len(msg.Answer) > 0 {
msg.Answer[0].Header().Name = origName
}
mbuf, err := msg.PackBuffer(buffer)
if err != nil {
glog.Warning("Unable to pack DNS packet. Error is: %v", err)
return length
}
if &buffer[0] != &mbuf[0] {
glog.Warning("Buffer is too small in packing DNS packet.")
return length
}
return len(mbuf)
}
func processUnpackedDNSQueryPacket(
dnsClients *dnsClientCache,
msg *dns.Msg,
host string,
dnsQType uint16,
buffer []byte,
length int,
dnsSearch []string) int {
if dnsSearch == nil || len(dnsSearch) == 0 { if dnsSearch == nil || len(dnsSearch) == 0 {
glog.V(1).Infof("DNS search list is not initialized and is empty.") glog.V(1).Infof("DNS search list is not initialized and is empty.")
return length return length
@ -490,22 +347,31 @@ func processUnpackedDnsQueryPacket(dnsClients *dnsClientCache, msg *dnsMsg, host
index := atomic.SwapInt32(&state.searchIndex, state.searchIndex+1) index := atomic.SwapInt32(&state.searchIndex, state.searchIndex+1)
// Also update message ID if the client retries due to previous query time out // Also update message ID if the client retries due to previous query time out
state.msg.header.id = msg.header.id state.msg.MsgHdr.Id = msg.MsgHdr.Id
if index < 0 || index >= int32(len(dnsSearch)) { if index < 0 || index >= int32(len(dnsSearch)) {
glog.V(1).Infof("Search index %d is out of range.", index) glog.V(1).Infof("Search index %d is out of range.", index)
return length return length
} }
length = appendDnsSuffix(msg, buffer, length, dnsSearch[index]) length = appendDNSSuffix(msg, buffer, length, dnsSearch[index])
return length return length
} }
func processUnpackedDnsResponsePacket(svrConn net.Conn, dnsClients *dnsClientCache, rcode uint16, host string, dnsQType uint16, buffer []byte, length int, dnsSearch []string) bool { func processUnpackedDNSResponsePacket(
svrConn net.Conn,
dnsClients *dnsClientCache,
msg *dns.Msg,
rcode int,
host string,
dnsQType uint16,
buffer []byte,
length int,
dnsSearch []string) (bool, int) {
var drop bool var drop bool
if dnsSearch == nil || len(dnsSearch) == 0 { if dnsSearch == nil || len(dnsSearch) == 0 {
glog.V(1).Infof("DNS search list is not initialized and is empty.") glog.V(1).Infof("DNS search list is not initialized and is empty.")
return drop return drop, length
} }
dnsClients.mu.Lock() dnsClients.mu.Lock()
@ -518,7 +384,7 @@ func processUnpackedDnsResponsePacket(svrConn net.Conn, dnsClients *dnsClientCac
// If the reponse has failure and iteration through the search list has not // If the reponse has failure and iteration through the search list has not
// reached the end, retry on behalf of the client using the original query message // reached the end, retry on behalf of the client using the original query message
drop = true drop = true
length = appendDnsSuffix(state.msg, buffer, length, dnsSearch[index]) length = appendDNSSuffix(state.msg, buffer, length, dnsSearch[index])
_, err := svrConn.Write(buffer[0:length]) _, err := svrConn.Write(buffer[0:length])
if err != nil { if err != nil {
@ -527,98 +393,96 @@ func processUnpackedDnsResponsePacket(svrConn net.Conn, dnsClients *dnsClientCac
} }
} }
} else { } else {
length = recoverDNSQuestion(state.msg.Question[0].Name, msg, buffer, length)
dnsClients.mu.Lock() dnsClients.mu.Lock()
delete(dnsClients.clients, dnsClientQuery{host, dnsQType}) delete(dnsClients.clients, dnsClientQuery{host, dnsQType})
dnsClients.mu.Unlock() dnsClients.mu.Unlock()
} }
} }
return drop return drop, length
} }
func processDnsQueryPacket(dnsClients *dnsClientCache, cliAddr net.Addr, buffer []byte, length int, dnsSearch []string) int { func processDNSQueryPacket(dnsClients *dnsClientCache, cliAddr net.Addr, buffer []byte, length int, dnsSearch []string) int {
msg := &dnsMsg{} msg := &dns.Msg{}
if !msg.unpackDnsMsg(buffer[:length]) { if err := msg.Unpack(buffer[:length]); err != nil {
glog.Warning("Unable to unpack DNS packet.") glog.Warning("Unable to unpack DNS packet. Error is: %v", err)
return length return length
} }
// Query - Response bit that specifies whether this message is a query (0) or a response (1). // Query - Response bit that specifies whether this message is a query (0) or a response (1).
qr := msg.header.bits & 0x8000 if msg.MsgHdr.Response == true {
if qr != 0 {
glog.Warning("DNS packet should be a query message.") glog.Warning("DNS packet should be a query message.")
return length return length
} }
// QDCOUNT // QDCOUNT
if msg.header.qdCount != 1 { if len(msg.Question) != 1 {
glog.V(1).Infof("Number of entries in the question section of the DNS packet is: %d", msg.header.qdCount) glog.V(1).Infof("Number of entries in the question section of the DNS packet is: %d", len(msg.Question))
glog.V(1).Infof("DNS suffix appending does not support more than one question.") glog.V(1).Infof("DNS suffix appending does not support more than one question.")
return length return length
} }
// ANCOUNT, NSCOUNT, ARCOUNT // ANCOUNT, NSCOUNT, ARCOUNT
if msg.header.anCount != 0 || msg.header.nsCount != 0 || msg.header.arCount != 0 { if len(msg.Answer) != 0 || len(msg.Ns) != 0 || len(msg.Extra) != 0 {
glog.V(1).Infof("DNS packet contains more than question section.") glog.V(1).Infof("DNS packet contains more than question section.")
return length return length
} }
dnsQType := msg.question[0].qType dnsQType := msg.Question[0].Qtype
dnsQClass := msg.question[0].qClass dnsQClass := msg.Question[0].Qclass
if packetRequiresDnsSuffix(dnsQType, dnsQClass) { if packetRequiresDNSSuffix(dnsQType, dnsQClass) {
host, _, err := net.SplitHostPort(cliAddr.String()) host, _, err := net.SplitHostPort(cliAddr.String())
if err != nil { if err != nil {
glog.V(1).Infof("Failed to get host from client address: %v", err) glog.V(1).Infof("Failed to get host from client address: %v", err)
host = cliAddr.String() host = cliAddr.String()
} }
length = processUnpackedDnsQueryPacket(dnsClients, msg, host, dnsQType, buffer, length, dnsSearch) length = processUnpackedDNSQueryPacket(dnsClients, msg, host, dnsQType, buffer, length, dnsSearch)
} }
return length return length
} }
func processDnsResponsePacket(svrConn net.Conn, dnsClients *dnsClientCache, cliAddr net.Addr, buffer []byte, length int, dnsSearch []string) bool { func processDNSResponsePacket(svrConn net.Conn, dnsClients *dnsClientCache, cliAddr net.Addr, buffer []byte, length int, dnsSearch []string) (bool, int) {
var drop bool var drop bool
msg := &dnsMsg{} msg := &dns.Msg{}
if !msg.unpackDnsMsg(buffer[:length]) { if err := msg.Unpack(buffer[:length]); err != nil {
glog.Warning("Unable to unpack DNS packet.") glog.Warning("Unable to unpack DNS packet. Error is: %v", err)
return drop return drop, length
} }
// Query - Response bit that specifies whether this message is a query (0) or a response (1). // Query - Response bit that specifies whether this message is a query (0) or a response (1).
qr := msg.header.bits & 0x8000 if msg.MsgHdr.Response == false {
if qr == 0 {
glog.Warning("DNS packet should be a response message.") glog.Warning("DNS packet should be a response message.")
return drop return drop, length
} }
// QDCOUNT // QDCOUNT
if msg.header.qdCount != 1 { if len(msg.Question) != 1 {
glog.V(1).Infof("Number of entries in the reponse section of the DNS packet is: %d", msg.header.qdCount) glog.V(1).Infof("Number of entries in the reponse section of the DNS packet is: %d", len(msg.Answer))
return drop return drop, length
} }
dnsQType := msg.question[0].qType dnsQType := msg.Question[0].Qtype
dnsQClass := msg.question[0].qClass dnsQClass := msg.Question[0].Qclass
if packetRequiresDnsSuffix(dnsQType, dnsQClass) { if packetRequiresDNSSuffix(dnsQType, dnsQClass) {
host, _, err := net.SplitHostPort(cliAddr.String()) host, _, err := net.SplitHostPort(cliAddr.String())
if err != nil { if err != nil {
glog.V(1).Infof("Failed to get host from client address: %v", err) glog.V(1).Infof("Failed to get host from client address: %v", err)
host = cliAddr.String() host = cliAddr.String()
} }
rcode := msg.header.bits & 0xf drop, length = processUnpackedDNSResponsePacket(svrConn, dnsClients, msg, msg.MsgHdr.Rcode, host, dnsQType, buffer, length, dnsSearch)
drop = processUnpackedDnsResponsePacket(svrConn, dnsClients, rcode, host, dnsQType, buffer, length, dnsSearch)
} }
return drop return drop, length
} }
func (udp *udpProxySocket) ProxyLoop(service ServicePortPortalName, myInfo *serviceInfo, proxier *Proxier) { func (udp *udpProxySocket) ProxyLoop(service ServicePortPortalName, myInfo *serviceInfo, proxier *Proxier) {
var buffer [4096]byte // 4KiB should be enough for most whole-packets var buffer [4096]byte // 4KiB should be enough for most whole-packets
var dnsSearch []string var dnsSearch []string
if isDnsService(service.Port) { if isDNSService(service.Port) {
dnsSearch = []string{"", namespaceServiceDomain, serviceDomain, clusterDomain} dnsSearch = []string{"", namespaceServiceDomain, serviceDomain, clusterDomain}
execer := exec.New() execer := exec.New()
ipconfigInterface := ipconfig.New(execer) ipconfigInterface := ipconfig.New(execer)
@ -651,8 +515,8 @@ func (udp *udpProxySocket) ProxyLoop(service ServicePortPortalName, myInfo *serv
} }
// If this is DNS query packet // If this is DNS query packet
if isDnsService(service.Port) { if isDNSService(service.Port) {
n = processDnsQueryPacket(myInfo.dnsClients, cliAddr, buffer[:], n, dnsSearch) n = processDNSQueryPacket(myInfo.dnsClients, cliAddr, buffer[:], n, dnsSearch)
} }
// If this is a client we know already, reuse the connection and goroutine. // If this is a client we know already, reuse the connection and goroutine.
@ -720,8 +584,8 @@ func (udp *udpProxySocket) proxyClient(cliAddr net.Addr, svrConn net.Conn, activ
} }
drop := false drop := false
if isDnsService(service.Port) { if isDNSService(service.Port) {
drop = processDnsResponsePacket(svrConn, dnsClients, cliAddr, buffer[:], n, dnsSearch) drop, n = processDNSResponsePacket(svrConn, dnsClients, cliAddr, buffer[:], n, dnsSearch)
} }
if !drop { if !drop {

View File

@ -1,129 +0,0 @@
/*
Copyright 2017 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package winuserspace
import (
"reflect"
"testing"
)
func TestPackUnpackDnsMsgUnqualifiedName(t *testing.T) {
msg := &dnsMsg{}
var buffer [4096]byte
msg.header.id = 1
msg.header.qdCount = 1
msg.question = make([]dnsQuestion, msg.header.qdCount)
msg.question[0].qClass = 0x01
msg.question[0].qType = 0x01
msg.question[0].qName.name = "kubernetes"
length, ok := msg.packDnsMsg(buffer[:])
if !ok {
t.Errorf("Pack DNS message failed.")
}
unpackedMsg := &dnsMsg{}
if !unpackedMsg.unpackDnsMsg(buffer[:length]) {
t.Errorf("Unpack DNS message failed.")
}
if !reflect.DeepEqual(msg, unpackedMsg) {
t.Errorf("Pack and Unpack DNS message are not consistent.")
}
}
func TestPackUnpackDnsMsgFqdn(t *testing.T) {
msg := &dnsMsg{}
var buffer [4096]byte
msg.header.id = 1
msg.header.qdCount = 1
msg.question = make([]dnsQuestion, msg.header.qdCount)
msg.question[0].qClass = 0x01
msg.question[0].qType = 0x01
msg.question[0].qName.name = "kubernetes.default.svc.cluster.local"
length, ok := msg.packDnsMsg(buffer[:])
if !ok {
t.Errorf("Pack DNS message failed.")
}
unpackedMsg := &dnsMsg{}
if !unpackedMsg.unpackDnsMsg(buffer[:length]) {
t.Errorf("Unpack DNS message failed.")
}
if !reflect.DeepEqual(msg, unpackedMsg) {
t.Errorf("Pack and Unpack DNS message are not consistent.")
}
}
func TestPackUnpackDnsMsgEmptyName(t *testing.T) {
msg := &dnsMsg{}
var buffer [4096]byte
msg.header.id = 1
msg.header.qdCount = 1
msg.question = make([]dnsQuestion, msg.header.qdCount)
msg.question[0].qClass = 0x01
msg.question[0].qType = 0x01
msg.question[0].qName.name = ""
length, ok := msg.packDnsMsg(buffer[:])
if !ok {
t.Errorf("Pack DNS message failed.")
}
unpackedMsg := &dnsMsg{}
if !unpackedMsg.unpackDnsMsg(buffer[:length]) {
t.Errorf("Unpack DNS message failed.")
}
if !reflect.DeepEqual(msg, unpackedMsg) {
t.Errorf("Pack and Unpack DNS message are not consistent.")
}
}
func TestPackUnpackDnsMsgMultipleQuestions(t *testing.T) {
msg := &dnsMsg{}
var buffer [4096]byte
msg.header.id = 1
msg.header.qdCount = 2
msg.question = make([]dnsQuestion, msg.header.qdCount)
msg.question[0].qClass = 0x01
msg.question[0].qType = 0x01
msg.question[0].qName.name = "kubernetes"
msg.question[1].qClass = 0x01
msg.question[1].qType = 0x1c
msg.question[1].qName.name = "kubernetes.default"
length, ok := msg.packDnsMsg(buffer[:])
if !ok {
t.Errorf("Pack DNS message failed.")
}
unpackedMsg := &dnsMsg{}
if !unpackedMsg.unpackDnsMsg(buffer[:length]) {
t.Errorf("Unpack DNS message failed.")
}
if !reflect.DeepEqual(msg, unpackedMsg) {
t.Errorf("Pack and Unpack DNS message are not consistent.")
}
}