From 93f9b54cab80dd0f49d1041141b60133a82baaca Mon Sep 17 00:00:00 2001 From: bprashanth Date: Mon, 26 Sep 2016 11:24:35 -0700 Subject: [PATCH 1/2] NodePorts understand OnlyLocal --- pkg/proxy/iptables/proxier.go | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/pkg/proxy/iptables/proxier.go b/pkg/proxy/iptables/proxier.go index 36a5cbfb1d..3799004ee6 100644 --- a/pkg/proxy/iptables/proxier.go +++ b/pkg/proxy/iptables/proxier.go @@ -1070,10 +1070,14 @@ func (proxier *Proxier) syncProxyRules() { "-m", protocol, "-p", protocol, "--dport", fmt.Sprintf("%d", svcInfo.nodePort), } - // Nodeports need SNAT. - writeLine(natRules, append(args, "-j", string(KubeMarkMasqChain))...) - // Jump to the service chain. - writeLine(natRules, append(args, "-j", string(svcChain))...) + if !svcInfo.onlyNodeLocalEndpoints { + // Nodeports need SNAT, unless they're local. + writeLine(natRules, append(args, "-j", string(KubeMarkMasqChain))...) + // Jump to the service chain. + writeLine(natRules, append(args, "-j", string(svcChain))...) + } else { + writeLine(natRules, append(args, "-j", string(svcXlbChain))...) + } } // If the service has no endpoints then reject packets. @@ -1173,6 +1177,16 @@ func (proxier *Proxier) syncProxyRules() { localEndpointChains = append(localEndpointChains, endpointChains[i]) } } + // First rule in the chain redirects all pod -> external vip traffic to the + // Service's ClusterIP instead. This happens whether or not we have local + // endpoints. + args = []string{ + "-A", string(svcXlbChain), + "-m", "comment", "--comment", + fmt.Sprintf(`"Redirect pods trying to reach external loadbalancer VIP to clusterIP"`), + } + writeLine(natRules, append(args, "-s", proxier.clusterCIDR, "-j", string(svcChain))...) + numLocalEndpoints := len(localEndpointChains) if numLocalEndpoints == 0 { // Blackhole all traffic since there are no local endpoints From 06cbb36a1ff2f792de0e99d83c5541b8110b8084 Mon Sep 17 00:00:00 2001 From: bprashanth Date: Mon, 26 Sep 2016 19:48:21 -0700 Subject: [PATCH 2/2] Proxier unittests --- pkg/proxy/iptables/proxier.go | 22 ++- pkg/proxy/iptables/proxier_test.go | 276 ++++++++++++++++++++++++++++- pkg/util/iptables/testing/fake.go | 85 +++++++-- 3 files changed, 354 insertions(+), 29 deletions(-) diff --git a/pkg/proxy/iptables/proxier.go b/pkg/proxy/iptables/proxier.go index 3799004ee6..0feda54e01 100644 --- a/pkg/proxy/iptables/proxier.go +++ b/pkg/proxy/iptables/proxier.go @@ -177,6 +177,7 @@ type Proxier struct { clusterCIDR string hostname string nodeIP net.IP + portMapper portOpener } type localPort struct { @@ -194,6 +195,20 @@ type closeable interface { Close() error } +// portOpener is an interface around port opening/closing. +// Abstracted out for testing. +type portOpener interface { + OpenLocalPort(lp *localPort) (closeable, error) +} + +// listenPortOpener opens ports by calling bind() and listen(). +type listenPortOpener struct{} + +// OpenLocalPort holds the given local port open. +func (l *listenPortOpener) OpenLocalPort(lp *localPort) (closeable, error) { + return openLocalPort(lp) +} + // Proxier implements ProxyProvider var _ proxy.ProxyProvider = &Proxier{} @@ -241,6 +256,7 @@ func NewProxier(ipt utiliptables.Interface, sysctl utilsysctl.Interface, exec ut clusterCIDR: clusterCIDR, hostname: hostname, nodeIP: nodeIP, + portMapper: &listenPortOpener{}, }, nil } @@ -941,7 +957,7 @@ func (proxier *Proxier) syncProxyRules() { glog.V(4).Infof("Port %s was open before and is still needed", lp.String()) replacementPortsMap[lp] = proxier.portsMap[lp] } else { - socket, err := openLocalPort(&lp) + socket, err := proxier.portMapper.OpenLocalPort(&lp) if err != nil { glog.Errorf("can't open %s, skipping this externalIP: %v", lp.String(), err) continue @@ -1056,7 +1072,7 @@ func (proxier *Proxier) syncProxyRules() { glog.V(4).Infof("Port %s was open before and is still needed", lp.String()) replacementPortsMap[lp] = proxier.portsMap[lp] } else { - socket, err := openLocalPort(&lp) + socket, err := proxier.portMapper.OpenLocalPort(&lp) if err != nil { glog.Errorf("can't open %s, skipping this nodePort: %v", lp.String(), err) continue @@ -1076,6 +1092,8 @@ func (proxier *Proxier) syncProxyRules() { // Jump to the service chain. writeLine(natRules, append(args, "-j", string(svcChain))...) } else { + // TODO: Make all nodePorts jump to the firewall chain. + // Currently we only create it for loadbalancers (#33586). writeLine(natRules, append(args, "-j", string(svcXlbChain))...) } } diff --git a/pkg/proxy/iptables/proxier_test.go b/pkg/proxy/iptables/proxier_test.go index c9c1fb36d3..ad0a028f10 100644 --- a/pkg/proxy/iptables/proxier_test.go +++ b/pkg/proxy/iptables/proxier_test.go @@ -28,6 +28,7 @@ import ( "k8s.io/kubernetes/pkg/types" "k8s.io/kubernetes/pkg/util/exec" utiliptables "k8s.io/kubernetes/pkg/util/iptables" + iptablestest "k8s.io/kubernetes/pkg/util/iptables/testing" ) func checkAllLines(t *testing.T, table utiliptables.Table, save []byte, expectedLines map[utiliptables.Chain]string) { @@ -256,12 +257,13 @@ func TestExecConntrackTool(t *testing.T) { } } -func newFakeServiceInfo(service proxy.ServicePortName, ip net.IP, protocol api.Protocol) *serviceInfo { +func newFakeServiceInfo(service proxy.ServicePortName, ip net.IP, protocol api.Protocol, onlyNodeLocalEndpoints bool) *serviceInfo { return &serviceInfo{ - sessionAffinityType: api.ServiceAffinityNone, // default - stickyMaxAgeSeconds: 180, // TODO: paramaterize this in the API. - clusterIP: ip, - protocol: protocol, + sessionAffinityType: api.ServiceAffinityNone, // default + stickyMaxAgeSeconds: 180, // TODO: paramaterize this in the API. + clusterIP: ip, + protocol: protocol, + onlyNodeLocalEndpoints: onlyNodeLocalEndpoints, } } @@ -285,8 +287,8 @@ func TestDeleteEndpointConnections(t *testing.T) { serviceMap := make(map[proxy.ServicePortName]*serviceInfo) svc1 := proxy.ServicePortName{NamespacedName: types.NamespacedName{Namespace: "ns1", Name: "svc1"}, Port: ""} svc2 := proxy.ServicePortName{NamespacedName: types.NamespacedName{Namespace: "ns1", Name: "svc2"}, Port: ""} - serviceMap[svc1] = newFakeServiceInfo(svc1, net.IPv4(10, 20, 30, 40), api.ProtocolUDP) - serviceMap[svc2] = newFakeServiceInfo(svc1, net.IPv4(10, 20, 30, 41), api.ProtocolTCP) + serviceMap[svc1] = newFakeServiceInfo(svc1, net.IPv4(10, 20, 30, 40), api.ProtocolUDP, false) + serviceMap[svc2] = newFakeServiceInfo(svc1, net.IPv4(10, 20, 30, 41), api.ProtocolTCP, false) fakeProxier := Proxier{exec: &fexec, serviceMap: serviceMap} @@ -473,4 +475,262 @@ func TestRevertPorts(t *testing.T) { } -// TODO(thockin): add a test for syncProxyRules() or break it down further and test the pieces. +// fakePortOpener implements portOpener. +type fakePortOpener struct { + openPorts []*localPort +} + +// OpenLocalPort fakes out the listen() and bind() used by syncProxyRules +// to lock a local port. +func (f *fakePortOpener) OpenLocalPort(lp *localPort) (closeable, error) { + f.openPorts = append(f.openPorts, lp) + return nil, nil +} + +func NewFakeProxier(ipt utiliptables.Interface) *Proxier { + // TODO: Call NewProxier after refactoring out the goroutine + // invocation into a Run() method. + return &Proxier{ + exec: &exec.FakeExec{}, + serviceMap: make(map[proxy.ServicePortName]*serviceInfo), + iptables: ipt, + endpointsMap: make(map[proxy.ServicePortName][]*endpointsInfo), + clusterCIDR: "10.0.0.0/24", + haveReceivedEndpointsUpdate: true, + haveReceivedServiceUpdate: true, + hostname: "test-hostname", + portsMap: make(map[localPort]closeable), + portMapper: &fakePortOpener{[]*localPort{}}, + } +} + +func hasJump(rules []iptablestest.Rule, destChain, destIP, destPort string) bool { + for _, r := range rules { + if r[iptablestest.Jump] == destChain { + if destIP != "" { + return strings.Contains(r[iptablestest.Destination], destIP) + } + if destPort != "" { + return strings.Contains(r[iptablestest.DPort], destPort) + } + return true + } + } + return false +} + +func hasDNAT(rules []iptablestest.Rule, endpoint string) bool { + for _, r := range rules { + if r[iptablestest.ToDest] == endpoint { + return true + } + } + return false +} + +func errorf(msg string, rules []iptablestest.Rule, t *testing.T) { + for _, r := range rules { + t.Logf("%v", r) + } + t.Errorf("%v", msg) +} + +func TestClusterIPReject(t *testing.T) { + ipt := iptablestest.NewFake() + fp := NewFakeProxier(ipt) + svcName := "svc1" + svcIP := net.IPv4(10, 20, 30, 41) + + svc := proxy.ServicePortName{NamespacedName: types.NamespacedName{Namespace: "ns1", Name: svcName}, Port: ""} + fp.serviceMap[svc] = newFakeServiceInfo(svc, svcIP, api.ProtocolTCP, false) + fp.syncProxyRules() + + svcChain := string(servicePortChainName(svc, strings.ToLower(string(api.ProtocolTCP)))) + svcRules := ipt.GetRules(svcChain) + if len(svcRules) != 0 { + errorf(fmt.Sprintf("Unexpected rule for chain %v service %v without endpoints", svcChain, svcName), svcRules, t) + } + kubeSvcRules := ipt.GetRules(string(kubeServicesChain)) + if !hasJump(kubeSvcRules, iptablestest.Reject, svcIP.String(), "") { + errorf(fmt.Sprintf("Failed to find a %v rule for service %v with no endpoints", iptablestest.Reject, svcName), kubeSvcRules, t) + } +} + +func TestClusterIPEndpointsJump(t *testing.T) { + ipt := iptablestest.NewFake() + fp := NewFakeProxier(ipt) + svcName := "svc1" + svcIP := net.IPv4(10, 20, 30, 41) + + svc := proxy.ServicePortName{NamespacedName: types.NamespacedName{Namespace: "ns1", Name: svcName}, Port: "80"} + fp.serviceMap[svc] = newFakeServiceInfo(svc, svcIP, api.ProtocolTCP, true) + ep := "10.180.0.1:80" + fp.endpointsMap[svc] = []*endpointsInfo{{ep, false}} + + fp.syncProxyRules() + + svcChain := string(servicePortChainName(svc, strings.ToLower(string(api.ProtocolTCP)))) + epChain := string(servicePortEndpointChainName(svc, strings.ToLower(string(api.ProtocolTCP)), ep)) + + kubeSvcRules := ipt.GetRules(string(kubeServicesChain)) + if !hasJump(kubeSvcRules, svcChain, svcIP.String(), "") { + errorf(fmt.Sprintf("Failed to find jump from KUBE-SERVICES to %v chain", svcChain), kubeSvcRules, t) + } + + svcRules := ipt.GetRules(svcChain) + if !hasJump(svcRules, epChain, "", "") { + errorf(fmt.Sprintf("Failed to jump to ep chain %v", epChain), svcRules, t) + } + epRules := ipt.GetRules(epChain) + if !hasDNAT(epRules, ep) { + errorf(fmt.Sprintf("Endpoint chain %v lacks DNAT to %v", epChain, ep), epRules, t) + } +} + +func typeLoadBalancer(svcInfo *serviceInfo) *serviceInfo { + svcInfo.nodePort = 3001 + svcInfo.loadBalancerStatus = api.LoadBalancerStatus{ + Ingress: []api.LoadBalancerIngress{{IP: "1.2.3.4"}}, + } + return svcInfo +} + +func TestLoadBalancer(t *testing.T) { + ipt := iptablestest.NewFake() + fp := NewFakeProxier(ipt) + svcName := "svc1" + svcIP := net.IPv4(10, 20, 30, 41) + + svc := proxy.ServicePortName{NamespacedName: types.NamespacedName{Namespace: "ns1", Name: svcName}, Port: "80"} + svcInfo := newFakeServiceInfo(svc, svcIP, api.ProtocolTCP, false) + fp.serviceMap[svc] = typeLoadBalancer(svcInfo) + + ep1 := "10.180.0.1:80" + fp.endpointsMap[svc] = []*endpointsInfo{{ep1, false}} + + fp.syncProxyRules() + + proto := strings.ToLower(string(api.ProtocolTCP)) + fwChain := string(serviceFirewallChainName(svc, proto)) + svcChain := string(servicePortChainName(svc, strings.ToLower(string(api.ProtocolTCP)))) + //lbChain := string(serviceLBChainName(svc, proto)) + + kubeSvcRules := ipt.GetRules(string(kubeServicesChain)) + if !hasJump(kubeSvcRules, fwChain, svcInfo.loadBalancerStatus.Ingress[0].IP, "") { + errorf(fmt.Sprintf("Failed to find jump to firewall chain %v", fwChain), kubeSvcRules, t) + } + + fwRules := ipt.GetRules(fwChain) + if !hasJump(fwRules, svcChain, "", "") || !hasJump(fwRules, string(KubeMarkMasqChain), "", "") { + errorf(fmt.Sprintf("Failed to find jump from firewall chain %v to svc chain %v", fwChain, svcChain), fwRules, t) + } +} + +func TestNodePort(t *testing.T) { + ipt := iptablestest.NewFake() + fp := NewFakeProxier(ipt) + svcName := "svc1" + svcIP := net.IPv4(10, 20, 30, 41) + + svc := proxy.ServicePortName{NamespacedName: types.NamespacedName{Namespace: "ns1", Name: svcName}, Port: "80"} + svcInfo := newFakeServiceInfo(svc, svcIP, api.ProtocolTCP, false) + svcInfo.nodePort = 3001 + fp.serviceMap[svc] = svcInfo + + ep1 := "10.180.0.1:80" + fp.endpointsMap[svc] = []*endpointsInfo{{ep1, false}} + + fp.syncProxyRules() + + proto := strings.ToLower(string(api.ProtocolTCP)) + svcChain := string(servicePortChainName(svc, strings.ToLower(proto))) + + kubeNodePortRules := ipt.GetRules(string(kubeNodePortsChain)) + if !hasJump(kubeNodePortRules, svcChain, "", fmt.Sprintf("%v", svcInfo.nodePort)) { + errorf(fmt.Sprintf("Failed to find jump to svc chain %v", svcChain), kubeNodePortRules, t) + } +} + +func TestOnlyLocalLoadBalancing(t *testing.T) { + ipt := iptablestest.NewFake() + fp := NewFakeProxier(ipt) + svcName := "svc1" + svcIP := net.IPv4(10, 20, 30, 41) + + svc := proxy.ServicePortName{NamespacedName: types.NamespacedName{Namespace: "ns1", Name: svcName}, Port: "80"} + svcInfo := newFakeServiceInfo(svc, svcIP, api.ProtocolTCP, true) + fp.serviceMap[svc] = typeLoadBalancer(svcInfo) + + nonLocalEp := "10.180.0.1:80" + localEp := "10.180.2.1:80" + fp.endpointsMap[svc] = []*endpointsInfo{{nonLocalEp, false}, {localEp, true}} + + fp.syncProxyRules() + + proto := strings.ToLower(string(api.ProtocolTCP)) + fwChain := string(serviceFirewallChainName(svc, proto)) + lbChain := string(serviceLBChainName(svc, proto)) + + nonLocalEpChain := string(servicePortEndpointChainName(svc, strings.ToLower(string(api.ProtocolTCP)), nonLocalEp)) + localEpChain := string(servicePortEndpointChainName(svc, strings.ToLower(string(api.ProtocolTCP)), localEp)) + + kubeSvcRules := ipt.GetRules(string(kubeServicesChain)) + if !hasJump(kubeSvcRules, fwChain, svcInfo.loadBalancerStatus.Ingress[0].IP, "") { + errorf(fmt.Sprintf("Failed to find jump to firewall chain %v", fwChain), kubeSvcRules, t) + } + + fwRules := ipt.GetRules(fwChain) + if !hasJump(fwRules, lbChain, "", "") { + errorf(fmt.Sprintf("Failed to find jump from firewall chain %v to svc chain %v", fwChain, lbChain), fwRules, t) + } + if hasJump(fwRules, string(KubeMarkMasqChain), "", "") { + errorf(fmt.Sprintf("Found jump from fw chain %v to MASQUERADE", fwChain), fwRules, t) + } + + lbRules := ipt.GetRules(lbChain) + if hasJump(lbRules, nonLocalEpChain, "", "") { + errorf(fmt.Sprintf("Found jump from lb chain %v to non-local ep %v", lbChain, nonLocalEp), lbRules, t) + } + if !hasJump(lbRules, localEpChain, "", "") { + errorf(fmt.Sprintf("Didn't find jump from lb chain %v to local ep %v", lbChain, nonLocalEp), lbRules, t) + } +} + +func TestOnlyLocalNodePorts(t *testing.T) { + ipt := iptablestest.NewFake() + fp := NewFakeProxier(ipt) + svcName := "svc1" + svcIP := net.IPv4(10, 20, 30, 41) + + svc := proxy.ServicePortName{NamespacedName: types.NamespacedName{Namespace: "ns1", Name: svcName}, Port: "80"} + svcInfo := newFakeServiceInfo(svc, svcIP, api.ProtocolTCP, true) + svcInfo.nodePort = 3001 + fp.serviceMap[svc] = svcInfo + + nonLocalEp := "10.180.0.1:80" + localEp := "10.180.2.1:80" + fp.endpointsMap[svc] = []*endpointsInfo{{nonLocalEp, false}, {localEp, true}} + + fp.syncProxyRules() + + proto := strings.ToLower(string(api.ProtocolTCP)) + lbChain := string(serviceLBChainName(svc, proto)) + + nonLocalEpChain := string(servicePortEndpointChainName(svc, strings.ToLower(string(api.ProtocolTCP)), nonLocalEp)) + localEpChain := string(servicePortEndpointChainName(svc, strings.ToLower(string(api.ProtocolTCP)), localEp)) + + kubeNodePortRules := ipt.GetRules(string(kubeNodePortsChain)) + if !hasJump(kubeNodePortRules, lbChain, "", fmt.Sprintf("%v", svcInfo.nodePort)) { + errorf(fmt.Sprintf("Failed to find jump to lb chain %v", lbChain), kubeNodePortRules, t) + } + + lbRules := ipt.GetRules(lbChain) + if hasJump(lbRules, nonLocalEpChain, "", "") { + errorf(fmt.Sprintf("Found jump from lb chain %v to non-local ep %v", lbChain, nonLocalEp), lbRules, t) + } + if !hasJump(lbRules, localEpChain, "", "") { + errorf(fmt.Sprintf("Didn't find jump from lb chain %v to local ep %v", lbChain, nonLocalEp), lbRules, t) + } +} + +// TODO(thockin): add *more* tests for syncProxyRules() or break it down further and test the pieces. diff --git a/pkg/util/iptables/testing/fake.go b/pkg/util/iptables/testing/fake.go index 885dea5840..f3b0be57cf 100644 --- a/pkg/util/iptables/testing/fake.go +++ b/pkg/util/iptables/testing/fake.go @@ -16,60 +16,107 @@ limitations under the License. package testing -import "k8s.io/kubernetes/pkg/util/iptables" +import ( + "fmt" + "strings" + + "k8s.io/kubernetes/pkg/util/iptables" +) + +const ( + Destination = "-d " + Source = "-s " + DPort = "--dport " + Protocol = "-p " + Jump = "-j " + Reject = "REJECT" + ToDest = "--to-destination " +) + +type Rule map[string]string // no-op implementation of iptables Interface -type fake struct{} - -func NewFake() *fake { - return &fake{} +type FakeIPTables struct { + Lines []byte } -func (*fake) GetVersion() (string, error) { +func NewFake() *FakeIPTables { + return &FakeIPTables{} +} + +func (*FakeIPTables) GetVersion() (string, error) { return "0.0.0", nil } -func (*fake) EnsureChain(table iptables.Table, chain iptables.Chain) (bool, error) { +func (*FakeIPTables) EnsureChain(table iptables.Table, chain iptables.Chain) (bool, error) { return true, nil } -func (*fake) FlushChain(table iptables.Table, chain iptables.Chain) error { +func (*FakeIPTables) FlushChain(table iptables.Table, chain iptables.Chain) error { return nil } -func (*fake) DeleteChain(table iptables.Table, chain iptables.Chain) error { +func (*FakeIPTables) DeleteChain(table iptables.Table, chain iptables.Chain) error { return nil } -func (*fake) EnsureRule(position iptables.RulePosition, table iptables.Table, chain iptables.Chain, args ...string) (bool, error) { +func (*FakeIPTables) EnsureRule(position iptables.RulePosition, table iptables.Table, chain iptables.Chain, args ...string) (bool, error) { return true, nil } -func (*fake) DeleteRule(table iptables.Table, chain iptables.Chain, args ...string) error { +func (*FakeIPTables) DeleteRule(table iptables.Table, chain iptables.Chain, args ...string) error { return nil } -func (*fake) IsIpv6() bool { +func (*FakeIPTables) IsIpv6() bool { return false } -func (*fake) Save(table iptables.Table) ([]byte, error) { +func (*FakeIPTables) Save(table iptables.Table) ([]byte, error) { return make([]byte, 0), nil } -func (*fake) SaveAll() ([]byte, error) { +func (*FakeIPTables) SaveAll() ([]byte, error) { return make([]byte, 0), nil } -func (*fake) Restore(table iptables.Table, data []byte, flush iptables.FlushFlag, counters iptables.RestoreCountersFlag) error { +func (*FakeIPTables) Restore(table iptables.Table, data []byte, flush iptables.FlushFlag, counters iptables.RestoreCountersFlag) error { return nil } -func (*fake) RestoreAll(data []byte, flush iptables.FlushFlag, counters iptables.RestoreCountersFlag) error { +func (f *FakeIPTables) RestoreAll(data []byte, flush iptables.FlushFlag, counters iptables.RestoreCountersFlag) error { + f.Lines = data return nil } -func (*fake) AddReloadFunc(reloadFunc func()) {} +func (*FakeIPTables) AddReloadFunc(reloadFunc func()) {} -func (*fake) Destroy() {} +func (*FakeIPTables) Destroy() {} -var _ = iptables.Interface(&fake{}) +func getToken(line, seperator string) string { + tokens := strings.Split(line, seperator) + if len(tokens) == 2 { + return strings.Split(tokens[1], " ")[0] + } + return "" +} + +// GetChain returns a list of rules for the givne chain. +// The chain name must match exactly. +// The matching is pretty dumb, don't rely on it for anything but testing. +func (f *FakeIPTables) GetRules(chainName string) (rules []Rule) { + for _, l := range strings.Split(string(f.Lines), "\n") { + if strings.Contains(l, fmt.Sprintf("-A %v", chainName)) { + newRule := Rule(map[string]string{}) + for _, arg := range []string{Destination, Source, DPort, Protocol, Jump, ToDest} { + tok := getToken(l, arg) + if tok != "" { + newRule[arg] = tok + } + } + rules = append(rules, newRule) + } + } + return +} + +var _ = iptables.Interface(&FakeIPTables{})