diff --git a/pkg/proxy/ipvs/proxier_test.go b/pkg/proxy/ipvs/proxier_test.go index 904c5c051f..4a6c294c35 100644 --- a/pkg/proxy/ipvs/proxier_test.go +++ b/pkg/proxy/ipvs/proxier_test.go @@ -21,6 +21,7 @@ import ( "fmt" "net" "reflect" + "strings" "testing" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -894,6 +895,126 @@ func TestOnlyLocalNodePorts(t *testing.T) { t.Errorf("Expect node port type service, got none") } } +func TestLoadBalanceSourceRanges(t *testing.T) { + ipt := iptablestest.NewFake() + ipvs := ipvstest.NewFake() + ipset := ipsettest.NewFake(testIPSetVersion) + fp := NewFakeProxier(ipt, ipvs, ipset, nil) + svcIP := "10.20.30.41" + svcPort := 80 + svcLBIP := "1.2.3.4" + svcLBSource := "10.0.0.0/8" + svcPortName := proxy.ServicePortName{ + NamespacedName: makeNSN("ns1", "svc1"), + Port: "p80", + } + epIP := "10.180.0.1" + + makeServiceMap(fp, + makeTestService(svcPortName.Namespace, svcPortName.Name, func(svc *api.Service) { + svc.Spec.Type = "LoadBalancer" + svc.Spec.ClusterIP = svcIP + svc.Spec.Ports = []api.ServicePort{{ + Name: svcPortName.Port, + Port: int32(svcPort), + Protocol: api.ProtocolTCP, + }} + svc.Status.LoadBalancer.Ingress = []api.LoadBalancerIngress{{ + IP: svcLBIP, + }} + svc.Spec.LoadBalancerSourceRanges = []string{ + svcLBSource, + } + }), + ) + makeEndpointsMap(fp, + makeTestEndpoints(svcPortName.Namespace, svcPortName.Name, func(ept *api.Endpoints) { + ept.Subsets = []api.EndpointSubset{{ + Addresses: []api.EndpointAddress{{ + IP: epIP, + NodeName: strPtr(testHostname), + }}, + Ports: []api.EndpointPort{{ + Name: svcPortName.Port, + Port: int32(svcPort), + }}, + }} + }), + ) + + fp.syncProxyRules() + + // Check ipvs service and destinations + services, err := ipvs.GetVirtualServers() + if err != nil { + t.Errorf("Failed to get ipvs services, err: %v", err) + } + found := false + for _, svc := range services { + fmt.Printf("address: %s:%d, %s", svc.Address.String(), svc.Port, svc.Protocol) + if svc.Address.Equal(net.ParseIP(svcLBIP)) && svc.Port == uint16(svcPort) && svc.Protocol == string(api.ProtocolTCP) { + destinations, _ := ipvs.GetRealServers(svc) + if len(destinations) != 1 { + t.Errorf("Unexpected %d destinations, expect 0 destinations", len(destinations)) + } + for _, ep := range destinations { + if ep.Address.String() == epIP && ep.Port == uint16(svcPort) { + found = true + } + } + } + } + if !found { + t.Errorf("Did not got expected loadbalance service") + } + + // Check ipset entry + expectIPSet := map[string]*utilipset.Entry{ + KubeLoadBalancerSet: { + IP: svcLBIP, + Port: svcPort, + Protocol: strings.ToLower(string(api.ProtocolTCP)), + SetType: utilipset.HashIPPort, + }, + KubeLoadBalancerMasqSet: { + IP: svcLBIP, + Port: svcPort, + Protocol: strings.ToLower(string(api.ProtocolTCP)), + SetType: utilipset.HashIPPort, + }, + KubeLoadBalancerSourceCIDRSet: { + IP: svcLBIP, + Port: svcPort, + Protocol: strings.ToLower(string(api.ProtocolTCP)), + Net: svcLBSource, + SetType: utilipset.HashIPPortNet, + }, + } + for set, entry := range expectIPSet { + ents, err := ipset.ListEntries(set) + if err != nil || len(ents) != 1 { + t.Errorf("Check ipset entries failed for ipset: %q", set) + continue + } + if ents[0] != entry.String() { + t.Errorf("Check ipset entries failed for ipset: %q", set) + } + } + + // Check iptables chain and rules + kubeSvcRules := ipt.GetRules(string(kubeServicesChain)) + kubeFWRules := ipt.GetRules(string(KubeFireWallChain)) + if !hasJump(kubeSvcRules, string(KubeMarkMasqChain), KubeLoadBalancerMasqSet) { + t.Errorf("Didn't find jump from chain %v match set %v to MASQUERADE", kubeServicesChain, KubeLoadBalancerMasqSet) + } + if !hasJump(kubeSvcRules, string(KubeFireWallChain), KubeLoadBalancerSet) { + t.Errorf("Didn't find jump from chain %v match set %v to %v", kubeServicesChain, + KubeLoadBalancerSet, KubeFireWallChain) + } + if !hasJump(kubeFWRules, "ACCEPT", KubeLoadBalancerSourceCIDRSet) { + t.Errorf("Didn't find jump from chain %v match set %v to ACCEPT", kubeServicesChain, KubeLoadBalancerSourceCIDRSet) + } +} func TestOnlyLocalLoadBalancing(t *testing.T) { ipt := iptablestest.NewFake() @@ -2277,3 +2398,19 @@ func Test_syncService(t *testing.T) { } } } + +func hasJump(rules []iptablestest.Rule, destChain, ipSet string) bool { + match := false + for _, r := range rules { + if r[iptablestest.Jump] == destChain { + match = true + if ipSet != "" { + if strings.Contains(r[iptablestest.MatchSet], ipSet) { + return true + } + match = false + } + } + } + return match +} diff --git a/pkg/util/iptables/testing/fake.go b/pkg/util/iptables/testing/fake.go index 3a69efa675..cb504f9047 100644 --- a/pkg/util/iptables/testing/fake.go +++ b/pkg/util/iptables/testing/fake.go @@ -33,6 +33,7 @@ const ( Reject = "REJECT" ToDest = "--to-destination " Recent = "recent " + MatchSet = "--match-set " ) type Rule map[string]string @@ -112,7 +113,7 @@ func (f *FakeIPTables) GetRules(chainName string) (rules []Rule) { for _, l := range strings.Split(string(f.Lines), "\n") { if strings.Contains(l, fmt.Sprintf("-A %v", chainName)) { newRule := Rule(map[string]string{}) - for _, arg := range []string{Destination, Source, DPort, Protocol, Jump, ToDest, Recent} { + for _, arg := range []string{Destination, Source, DPort, Protocol, Jump, ToDest, Recent, MatchSet} { tok := getToken(l, arg) if tok != "" { newRule[arg] = tok