diff --git a/pkg/proxy/iptables/proxier_test.go b/pkg/proxy/iptables/proxier_test.go index d8c079f135..f742bc3a5c 100644 --- a/pkg/proxy/iptables/proxier_test.go +++ b/pkg/proxy/iptables/proxier_test.go @@ -257,11 +257,12 @@ func TestExecConntrackTool(t *testing.T) { } } -func newFakeServiceInfo(service proxy.ServicePortName, ip net.IP, protocol api.Protocol, onlyNodeLocalEndpoints bool) *serviceInfo { +func newFakeServiceInfo(service proxy.ServicePortName, ip net.IP, port int, protocol api.Protocol, onlyNodeLocalEndpoints bool) *serviceInfo { return &serviceInfo{ sessionAffinityType: api.ServiceAffinityNone, // default stickyMaxAgeMinutes: 180, // TODO: paramaterize this in the API. clusterIP: ip, + port: port, protocol: protocol, onlyNodeLocalEndpoints: onlyNodeLocalEndpoints, } @@ -285,10 +286,10 @@ func TestDeleteEndpointConnections(t *testing.T) { } serviceMap := make(map[proxy.ServicePortName]*serviceInfo) - svc1 := proxy.ServicePortName{NamespacedName: types.NamespacedName{Namespace: "ns1", Name: "svc1"}, Port: ""} - svc2 := proxy.ServicePortName{NamespacedName: types.NamespacedName{Namespace: "ns1", Name: "svc2"}, Port: ""} - serviceMap[svc1] = newFakeServiceInfo(svc1, net.IPv4(10, 20, 30, 40), api.ProtocolUDP, false) - serviceMap[svc2] = newFakeServiceInfo(svc1, net.IPv4(10, 20, 30, 41), api.ProtocolTCP, false) + svc1 := proxy.ServicePortName{NamespacedName: types.NamespacedName{Namespace: "ns1", Name: "svc1"}, Port: "80"} + svc2 := proxy.ServicePortName{NamespacedName: types.NamespacedName{Namespace: "ns1", Name: "svc2"}, Port: "80"} + serviceMap[svc1] = newFakeServiceInfo(svc1, net.IPv4(10, 20, 30, 40), 80, api.ProtocolUDP, false) + serviceMap[svc2] = newFakeServiceInfo(svc1, net.IPv4(10, 20, 30, 41), 80, api.ProtocolTCP, false) fakeProxier := Proxier{exec: &fexec, serviceMap: serviceMap} @@ -505,18 +506,146 @@ func NewFakeProxier(ipt utiliptables.Interface) *Proxier { } func hasJump(rules []iptablestest.Rule, destChain, destIP, destPort string) bool { + match := false for _, r := range rules { if r[iptablestest.Jump] == destChain { + match = true if destIP != "" { - return strings.Contains(r[iptablestest.Destination], destIP) + if strings.Contains(r[iptablestest.Destination], destIP) && (strings.Contains(r[iptablestest.DPort], destPort) || r[iptablestest.DPort] == "") { + return true + } + match = false } if destPort != "" { - return strings.Contains(r[iptablestest.DPort], destPort) + if strings.Contains(r[iptablestest.DPort], destPort) && (strings.Contains(r[iptablestest.Destination], destIP) || r[iptablestest.Destination] == "") { + return true + } + match = false } - return true } } - return false + return match +} + +func TestHasJump(t *testing.T) { + testCases := map[string]struct { + rules []iptablestest.Rule + destChain string + destIP string + destPort string + expected bool + }{ + "case 1": { + // Match the 1st rule(both dest IP and dest Port) + rules: []iptablestest.Rule{ + {"-d ": "10.20.30.41/32", "--dport ": "80", "-p ": "tcp", "-j ": "REJECT"}, + {"--dport ": "3001", "-p ": "tcp", "-j ": "KUBE-MARK-MASQ"}, + }, + destChain: "REJECT", + destIP: "10.20.30.41", + destPort: "80", + expected: true, + }, + "case 2": { + // Match the 2nd rule(dest Port) + rules: []iptablestest.Rule{ + {"-d ": "10.20.30.41/32", "-p ": "tcp", "-j ": "REJECT"}, + {"--dport ": "3001", "-p ": "tcp", "-j ": "REJECT"}, + }, + destChain: "REJECT", + destIP: "", + destPort: "3001", + expected: true, + }, + "case 3": { + // Match both dest IP and dest Port + rules: []iptablestest.Rule{ + {"-d ": "1.2.3.4/32", "--dport ": "80", "-p ": "tcp", "-j ": "KUBE-XLB-GF53O3C2HZEXL2XN"}, + }, + destChain: "KUBE-XLB-GF53O3C2HZEXL2XN", + destIP: "1.2.3.4", + destPort: "80", + expected: true, + }, + "case 4": { + // Match dest IP but doesn't match dest Port + rules: []iptablestest.Rule{ + {"-d ": "1.2.3.4/32", "--dport ": "80", "-p ": "tcp", "-j ": "KUBE-XLB-GF53O3C2HZEXL2XN"}, + }, + destChain: "KUBE-XLB-GF53O3C2HZEXL2XN", + destIP: "1.2.3.4", + destPort: "8080", + expected: false, + }, + "case 5": { + // Match dest Port but doesn't match dest IP + rules: []iptablestest.Rule{ + {"-d ": "1.2.3.4/32", "--dport ": "80", "-p ": "tcp", "-j ": "KUBE-XLB-GF53O3C2HZEXL2XN"}, + }, + destChain: "KUBE-XLB-GF53O3C2HZEXL2XN", + destIP: "10.20.30.40", + destPort: "80", + expected: false, + }, + "case 6": { + // Match the 2nd rule(dest IP) + rules: []iptablestest.Rule{ + {"-d ": "10.20.30.41/32", "-p ": "tcp", "-j ": "REJECT"}, + {"-d ": "1.2.3.4/32", "-p ": "tcp", "-j ": "REJECT"}, + {"--dport ": "3001", "-p ": "tcp", "-j ": "REJECT"}, + }, + destChain: "REJECT", + destIP: "1.2.3.4", + destPort: "8080", + expected: true, + }, + "case 7": { + // Match the 2nd rule(dest Port) + rules: []iptablestest.Rule{ + {"-d ": "10.20.30.41/32", "-p ": "tcp", "-j ": "REJECT"}, + {"--dport ": "3001", "-p ": "tcp", "-j ": "REJECT"}, + }, + destChain: "REJECT", + destIP: "1.2.3.4", + destPort: "3001", + expected: true, + }, + "case 8": { + // Match the 1st rule(dest IP) + rules: []iptablestest.Rule{ + {"-d ": "10.20.30.41/32", "-p ": "tcp", "-j ": "REJECT"}, + {"--dport ": "3001", "-p ": "tcp", "-j ": "REJECT"}, + }, + destChain: "REJECT", + destIP: "10.20.30.41", + destPort: "8080", + expected: true, + }, + "case 9": { + rules: []iptablestest.Rule{ + {"-j ": "KUBE-SEP-LWSOSDSHMKPJHHJV"}, + }, + destChain: "KUBE-SEP-LWSOSDSHMKPJHHJV", + destIP: "", + destPort: "", + expected: true, + }, + "case 10": { + rules: []iptablestest.Rule{ + {"-j ": "KUBE-SEP-FOO"}, + }, + destChain: "KUBE-SEP-BAR", + destIP: "", + destPort: "", + expected: false, + }, + } + + for k, tc := range testCases { + if got := hasJump(tc.rules, tc.destChain, tc.destIP, tc.destPort); got != tc.expected { + t.Errorf("%v: expected %v, got %v", k, tc.expected, got) + } + } } func hasDNAT(rules []iptablestest.Rule, endpoint string) bool { @@ -541,8 +670,8 @@ func TestClusterIPReject(t *testing.T) { svcName := "svc1" svcIP := net.IPv4(10, 20, 30, 41) - svc := proxy.ServicePortName{NamespacedName: types.NamespacedName{Namespace: "ns1", Name: svcName}, Port: ""} - fp.serviceMap[svc] = newFakeServiceInfo(svc, svcIP, api.ProtocolTCP, false) + svc := proxy.ServicePortName{NamespacedName: types.NamespacedName{Namespace: "ns1", Name: svcName}, Port: "80"} + fp.serviceMap[svc] = newFakeServiceInfo(svc, svcIP, 80, api.ProtocolTCP, false) fp.syncProxyRules() svcChain := string(servicePortChainName(svc, strings.ToLower(string(api.ProtocolTCP)))) @@ -551,7 +680,7 @@ func TestClusterIPReject(t *testing.T) { errorf(fmt.Sprintf("Unexpected rule for chain %v service %v without endpoints", svcChain, svcName), svcRules, t) } kubeSvcRules := ipt.GetRules(string(kubeServicesChain)) - if !hasJump(kubeSvcRules, iptablestest.Reject, svcIP.String(), "") { + if !hasJump(kubeSvcRules, iptablestest.Reject, svcIP.String(), "80") { errorf(fmt.Sprintf("Failed to find a %v rule for service %v with no endpoints", iptablestest.Reject, svcName), kubeSvcRules, t) } } @@ -563,7 +692,7 @@ func TestClusterIPEndpointsJump(t *testing.T) { svcIP := net.IPv4(10, 20, 30, 41) svc := proxy.ServicePortName{NamespacedName: types.NamespacedName{Namespace: "ns1", Name: svcName}, Port: "80"} - fp.serviceMap[svc] = newFakeServiceInfo(svc, svcIP, api.ProtocolTCP, true) + fp.serviceMap[svc] = newFakeServiceInfo(svc, svcIP, 80, api.ProtocolTCP, true) ep := "10.180.0.1:80" fp.endpointsMap[svc] = []*endpointsInfo{{ep, false}} @@ -573,7 +702,7 @@ func TestClusterIPEndpointsJump(t *testing.T) { epChain := string(servicePortEndpointChainName(svc, strings.ToLower(string(api.ProtocolTCP)), ep)) kubeSvcRules := ipt.GetRules(string(kubeServicesChain)) - if !hasJump(kubeSvcRules, svcChain, svcIP.String(), "") { + if !hasJump(kubeSvcRules, svcChain, svcIP.String(), "80") { errorf(fmt.Sprintf("Failed to find jump from KUBE-SERVICES to %v chain", svcChain), kubeSvcRules, t) } @@ -602,7 +731,7 @@ func TestLoadBalancer(t *testing.T) { svcIP := net.IPv4(10, 20, 30, 41) svc := proxy.ServicePortName{NamespacedName: types.NamespacedName{Namespace: "ns1", Name: svcName}, Port: "80"} - svcInfo := newFakeServiceInfo(svc, svcIP, api.ProtocolTCP, false) + svcInfo := newFakeServiceInfo(svc, svcIP, 80, api.ProtocolTCP, false) fp.serviceMap[svc] = typeLoadBalancer(svcInfo) ep1 := "10.180.0.1:80" @@ -616,7 +745,7 @@ func TestLoadBalancer(t *testing.T) { //lbChain := string(serviceLBChainName(svc, proto)) kubeSvcRules := ipt.GetRules(string(kubeServicesChain)) - if !hasJump(kubeSvcRules, fwChain, svcInfo.loadBalancerStatus.Ingress[0].IP, "") { + if !hasJump(kubeSvcRules, fwChain, svcInfo.loadBalancerStatus.Ingress[0].IP, "80") { errorf(fmt.Sprintf("Failed to find jump to firewall chain %v", fwChain), kubeSvcRules, t) } @@ -633,7 +762,7 @@ func TestNodePort(t *testing.T) { svcIP := net.IPv4(10, 20, 30, 41) svc := proxy.ServicePortName{NamespacedName: types.NamespacedName{Namespace: "ns1", Name: svcName}, Port: "80"} - svcInfo := newFakeServiceInfo(svc, svcIP, api.ProtocolTCP, false) + svcInfo := newFakeServiceInfo(svc, svcIP, 80, api.ProtocolTCP, false) svcInfo.nodePort = 3001 fp.serviceMap[svc] = svcInfo @@ -658,7 +787,7 @@ func TestOnlyLocalLoadBalancing(t *testing.T) { svcIP := net.IPv4(10, 20, 30, 41) svc := proxy.ServicePortName{NamespacedName: types.NamespacedName{Namespace: "ns1", Name: svcName}, Port: "80"} - svcInfo := newFakeServiceInfo(svc, svcIP, api.ProtocolTCP, true) + svcInfo := newFakeServiceInfo(svc, svcIP, 80, api.ProtocolTCP, true) fp.serviceMap[svc] = typeLoadBalancer(svcInfo) nonLocalEp := "10.180.0.1:80" @@ -716,7 +845,7 @@ func onlyLocalNodePorts(t *testing.T, fp *Proxier, ipt *iptablestest.FakeIPTable svcIP := net.IPv4(10, 20, 30, 41) svc := proxy.ServicePortName{NamespacedName: types.NamespacedName{Namespace: "ns1", Name: svcName}, Port: "80"} - svcInfo := newFakeServiceInfo(svc, svcIP, api.ProtocolTCP, true) + svcInfo := newFakeServiceInfo(svc, svcIP, 80, api.ProtocolTCP, true) svcInfo.nodePort = 3001 fp.serviceMap[svc] = svcInfo diff --git a/pkg/util/iptables/testing/fake.go b/pkg/util/iptables/testing/fake.go index f3b0be57cf..863f4e0aed 100644 --- a/pkg/util/iptables/testing/fake.go +++ b/pkg/util/iptables/testing/fake.go @@ -100,7 +100,7 @@ func getToken(line, seperator string) string { return "" } -// GetChain returns a list of rules for the givne chain. +// GetChain returns a list of rules for the given chain. // The chain name must match exactly. // The matching is pretty dumb, don't rely on it for anything but testing. func (f *FakeIPTables) GetRules(chainName string) (rules []Rule) {