diff --git a/pkg/kubelet/network/hostport/fake_iptables.go b/pkg/kubelet/network/hostport/fake_iptables.go index d8c05baddc..08faa77dd8 100644 --- a/pkg/kubelet/network/hostport/fake_iptables.go +++ b/pkg/kubelet/network/hostport/fake_iptables.go @@ -228,36 +228,27 @@ func saveChain(chain *fakeChain, data *bytes.Buffer) { } func (f *fakeIPTables) Save(tableName utiliptables.Table) ([]byte, error) { + data := bytes.NewBuffer(nil) + err := f.SaveInto(tableName, data) + return data.Bytes(), err +} + +func (f *fakeIPTables) SaveInto(tableName utiliptables.Table, buffer *bytes.Buffer) error { table, err := f.getTable(tableName) if err != nil { - return nil, err + return err } - data := bytes.NewBuffer(nil) - data.WriteString(fmt.Sprintf("*%s\n", table.name)) + buffer.WriteString(fmt.Sprintf("*%s\n", table.name)) rules := bytes.NewBuffer(nil) for _, chain := range table.chains { - data.WriteString(fmt.Sprintf(":%s - [0:0]\n", string(chain.name))) + buffer.WriteString(fmt.Sprintf(":%s - [0:0]\n", string(chain.name))) saveChain(chain, rules) } - data.Write(rules.Bytes()) - data.WriteString("COMMIT\n") - return data.Bytes(), nil -} - -func (f *fakeIPTables) SaveAll() ([]byte, error) { - data := bytes.NewBuffer(nil) - for _, table := range f.tables { - tableData, err := f.Save(table.name) - if err != nil { - return nil, err - } - if _, err = data.Write(tableData); err != nil { - return nil, err - } - } - return data.Bytes(), nil + buffer.Write(rules.Bytes()) + buffer.WriteString("COMMIT\n") + return nil } func (f *fakeIPTables) restore(restoreTableName utiliptables.Table, data []byte, flush utiliptables.FlushFlag) error { diff --git a/pkg/kubelet/prober/prober.go b/pkg/kubelet/prober/prober.go index acd70bb1ae..827c1dab89 100644 --- a/pkg/kubelet/prober/prober.go +++ b/pkg/kubelet/prober/prober.go @@ -237,6 +237,10 @@ func (pb *prober) newExecInContainer(container v1.Container, containerID kubecon }} } +func (eic execInContainer) Run() error { + return fmt.Errorf("unimplemented") +} + func (eic execInContainer) CombinedOutput() ([]byte, error) { return eic.run() } @@ -257,6 +261,10 @@ func (eic execInContainer) SetStdout(out io.Writer) { //unimplemented } +func (eic execInContainer) SetStderr(out io.Writer) { + //unimplemented +} + func (eic execInContainer) Stop() { //unimplemented } diff --git a/pkg/probe/exec/exec_test.go b/pkg/probe/exec/exec_test.go index bd86777d9e..3621983059 100644 --- a/pkg/probe/exec/exec_test.go +++ b/pkg/probe/exec/exec_test.go @@ -30,6 +30,10 @@ type FakeCmd struct { err error } +func (f *FakeCmd) Run() error { + return nil +} + func (f *FakeCmd) CombinedOutput() ([]byte, error) { return f.out, f.err } @@ -44,6 +48,8 @@ func (f *FakeCmd) SetStdin(in io.Reader) {} func (f *FakeCmd) SetStdout(out io.Writer) {} +func (f *FakeCmd) SetStderr(out io.Writer) {} + func (f *FakeCmd) Stop() {} type fakeExitError struct { diff --git a/pkg/proxy/iptables/proxier.go b/pkg/proxy/iptables/proxier.go index c6aedbf144..1cbc78db5b 100644 --- a/pkg/proxy/iptables/proxier.go +++ b/pkg/proxy/iptables/proxier.go @@ -302,6 +302,14 @@ type Proxier struct { recorder record.EventRecorder healthChecker healthcheck.Server healthzServer healthcheck.HealthzUpdater + + // The following buffers are used to reuse memory and avoid allocations + // that are significantly impacting performance. + iptablesData *bytes.Buffer + filterChains *bytes.Buffer + filterRules *bytes.Buffer + natChains *bytes.Buffer + natRules *bytes.Buffer } type localPort struct { @@ -417,6 +425,11 @@ func NewProxier(ipt utiliptables.Interface, recorder: recorder, healthChecker: healthChecker, healthzServer: healthzServer, + iptablesData: bytes.NewBuffer(nil), + filterChains: bytes.NewBuffer(nil), + filterRules: bytes.NewBuffer(nil), + natChains: bytes.NewBuffer(nil), + natRules: bytes.NewBuffer(nil), }, nil } @@ -976,62 +989,66 @@ func (proxier *Proxier) syncProxyRules(reason syncReason) { // Get iptables-save output so we can check for existing chains and rules. // This will be a map of chain name to chain with rules as stored in iptables-save/iptables-restore existingFilterChains := make(map[utiliptables.Chain]string) - iptablesSaveRaw, err := proxier.iptables.Save(utiliptables.TableFilter) + proxier.iptablesData.Reset() + err := proxier.iptables.SaveInto(utiliptables.TableFilter, proxier.iptablesData) if err != nil { // if we failed to get any rules glog.Errorf("Failed to execute iptables-save, syncing all rules: %v", err) } else { // otherwise parse the output - existingFilterChains = utiliptables.GetChainLines(utiliptables.TableFilter, iptablesSaveRaw) + existingFilterChains = utiliptables.GetChainLines(utiliptables.TableFilter, proxier.iptablesData.Bytes()) } existingNATChains := make(map[utiliptables.Chain]string) - iptablesSaveRaw, err = proxier.iptables.Save(utiliptables.TableNAT) + proxier.iptablesData.Reset() + err = proxier.iptables.SaveInto(utiliptables.TableNAT, proxier.iptablesData) if err != nil { // if we failed to get any rules glog.Errorf("Failed to execute iptables-save, syncing all rules: %v", err) } else { // otherwise parse the output - existingNATChains = utiliptables.GetChainLines(utiliptables.TableNAT, iptablesSaveRaw) + existingNATChains = utiliptables.GetChainLines(utiliptables.TableNAT, proxier.iptablesData.Bytes()) } - filterChains := bytes.NewBuffer(nil) - filterRules := bytes.NewBuffer(nil) - natChains := bytes.NewBuffer(nil) - natRules := bytes.NewBuffer(nil) + // Reset all buffers used later. + // This is to avoid memory reallocations and thus improve performance. + proxier.filterChains.Reset() + proxier.filterRules.Reset() + proxier.natChains.Reset() + proxier.natRules.Reset() // Write table headers. - writeLine(filterChains, "*filter") - writeLine(natChains, "*nat") + writeLine(proxier.filterChains, "*filter") + writeLine(proxier.natChains, "*nat") // Make sure we keep stats for the top-level chains, if they existed // (which most should have because we created them above). if chain, ok := existingFilterChains[kubeServicesChain]; ok { - writeLine(filterChains, chain) + writeLine(proxier.filterChains, chain) } else { - writeLine(filterChains, utiliptables.MakeChainLine(kubeServicesChain)) + writeLine(proxier.filterChains, utiliptables.MakeChainLine(kubeServicesChain)) } if chain, ok := existingNATChains[kubeServicesChain]; ok { - writeLine(natChains, chain) + writeLine(proxier.natChains, chain) } else { - writeLine(natChains, utiliptables.MakeChainLine(kubeServicesChain)) + writeLine(proxier.natChains, utiliptables.MakeChainLine(kubeServicesChain)) } if chain, ok := existingNATChains[kubeNodePortsChain]; ok { - writeLine(natChains, chain) + writeLine(proxier.natChains, chain) } else { - writeLine(natChains, utiliptables.MakeChainLine(kubeNodePortsChain)) + writeLine(proxier.natChains, utiliptables.MakeChainLine(kubeNodePortsChain)) } if chain, ok := existingNATChains[kubePostroutingChain]; ok { - writeLine(natChains, chain) + writeLine(proxier.natChains, chain) } else { - writeLine(natChains, utiliptables.MakeChainLine(kubePostroutingChain)) + writeLine(proxier.natChains, utiliptables.MakeChainLine(kubePostroutingChain)) } if chain, ok := existingNATChains[KubeMarkMasqChain]; ok { - writeLine(natChains, chain) + writeLine(proxier.natChains, chain) } else { - writeLine(natChains, utiliptables.MakeChainLine(KubeMarkMasqChain)) + writeLine(proxier.natChains, utiliptables.MakeChainLine(KubeMarkMasqChain)) } // Install the kubernetes-specific postrouting rules. We use a whole chain for // this so that it is easier to flush and change, for example if the mark // value should ever change. - writeLine(natRules, []string{ + writeLine(proxier.natRules, []string{ "-A", string(kubePostroutingChain), "-m", "comment", "--comment", `"kubernetes service traffic requiring SNAT"`, "-m", "mark", "--mark", proxier.masqueradeMark, @@ -1041,7 +1058,7 @@ func (proxier *Proxier) syncProxyRules(reason syncReason) { // Install the kubernetes-specific masquerade mark rule. We use a whole chain for // this so that it is easier to flush and change, for example if the mark // value should ever change. - writeLine(natRules, []string{ + writeLine(proxier.natRules, []string{ "-A", string(KubeMarkMasqChain), "-j", "MARK", "--set-xmark", proxier.masqueradeMark, }...) @@ -1062,9 +1079,9 @@ func (proxier *Proxier) syncProxyRules(reason syncReason) { // Create the per-service chain, retaining counters if possible. svcChain := servicePortChainName(svcNameString, protocol) if chain, ok := existingNATChains[svcChain]; ok { - writeLine(natChains, chain) + writeLine(proxier.natChains, chain) } else { - writeLine(natChains, utiliptables.MakeChainLine(svcChain)) + writeLine(proxier.natChains, utiliptables.MakeChainLine(svcChain)) } activeNATChains[svcChain] = true @@ -1073,9 +1090,9 @@ func (proxier *Proxier) syncProxyRules(reason syncReason) { // Only for services request OnlyLocal traffic // create the per-service LB chain, retaining counters if possible. if lbChain, ok := existingNATChains[svcXlbChain]; ok { - writeLine(natChains, lbChain) + writeLine(proxier.natChains, lbChain) } else { - writeLine(natChains, utiliptables.MakeChainLine(svcXlbChain)) + writeLine(proxier.natChains, utiliptables.MakeChainLine(svcXlbChain)) } activeNATChains[svcXlbChain] = true } else if activeNATChains[svcXlbChain] { @@ -1092,12 +1109,12 @@ func (proxier *Proxier) syncProxyRules(reason syncReason) { "--dport", fmt.Sprintf("%d", svcInfo.port), } if proxier.masqueradeAll { - writeLine(natRules, append(args, "-j", string(KubeMarkMasqChain))...) + writeLine(proxier.natRules, append(args, "-j", string(KubeMarkMasqChain))...) } if len(proxier.clusterCIDR) > 0 { - writeLine(natRules, append(args, "! -s", proxier.clusterCIDR, "-j", string(KubeMarkMasqChain))...) + writeLine(proxier.natRules, append(args, "! -s", proxier.clusterCIDR, "-j", string(KubeMarkMasqChain))...) } - writeLine(natRules, append(args, "-j", string(svcChain))...) + writeLine(proxier.natRules, append(args, "-j", string(svcChain))...) // Capture externalIPs. for _, externalIP := range svcInfo.externalIPs { @@ -1142,7 +1159,7 @@ func (proxier *Proxier) syncProxyRules(reason syncReason) { "--dport", fmt.Sprintf("%d", svcInfo.port), } // We have to SNAT packets to external IPs. - writeLine(natRules, append(args, "-j", string(KubeMarkMasqChain))...) + writeLine(proxier.natRules, append(args, "-j", string(KubeMarkMasqChain))...) // Allow traffic for external IPs that does not come from a bridge (i.e. not from a container) // nor from a local process to be forwarded to the service. @@ -1151,16 +1168,16 @@ func (proxier *Proxier) syncProxyRules(reason syncReason) { externalTrafficOnlyArgs := append(args, "-m", "physdev", "!", "--physdev-is-in", "-m", "addrtype", "!", "--src-type", "LOCAL") - writeLine(natRules, append(externalTrafficOnlyArgs, "-j", string(svcChain))...) + writeLine(proxier.natRules, append(externalTrafficOnlyArgs, "-j", string(svcChain))...) dstLocalOnlyArgs := append(args, "-m", "addrtype", "--dst-type", "LOCAL") // Allow traffic bound for external IPs that happen to be recognized as local IPs to stay local. // This covers cases like GCE load-balancers which get added to the local routing table. - writeLine(natRules, append(dstLocalOnlyArgs, "-j", string(svcChain))...) + writeLine(proxier.natRules, append(dstLocalOnlyArgs, "-j", string(svcChain))...) // If the service has no endpoints then reject packets coming via externalIP // Install ICMP Reject rule in filter table for destination=externalIP and dport=svcport if len(proxier.endpointsMap[svcName]) == 0 { - writeLine(filterRules, + writeLine(proxier.filterRules, "-A", string(kubeServicesChain), "-m", "comment", "--comment", fmt.Sprintf(`"%s has no endpoints"`, svcNameString), "-m", protocol, "-p", protocol, @@ -1177,9 +1194,9 @@ func (proxier *Proxier) syncProxyRules(reason syncReason) { // create service firewall chain fwChain := serviceFirewallChainName(svcNameString, protocol) if chain, ok := existingNATChains[fwChain]; ok { - writeLine(natChains, chain) + writeLine(proxier.natChains, chain) } else { - writeLine(natChains, utiliptables.MakeChainLine(fwChain)) + writeLine(proxier.natChains, utiliptables.MakeChainLine(fwChain)) } activeNATChains[fwChain] = true // The service firewall rules are created based on ServiceSpec.loadBalancerSourceRanges field. @@ -1194,7 +1211,7 @@ func (proxier *Proxier) syncProxyRules(reason syncReason) { "--dport", fmt.Sprintf("%d", svcInfo.port), } // jump to service firewall chain - writeLine(natRules, append(args, "-j", string(fwChain))...) + writeLine(proxier.natRules, append(args, "-j", string(fwChain))...) args = []string{ "-A", string(fwChain), @@ -1206,18 +1223,18 @@ func (proxier *Proxier) syncProxyRules(reason syncReason) { // If we are proxying globally, we need to masquerade in case we cross nodes. // If we are proxying only locally, we can retain the source IP. if !svcInfo.onlyNodeLocalEndpoints { - writeLine(natRules, append(args, "-j", string(KubeMarkMasqChain))...) + writeLine(proxier.natRules, append(args, "-j", string(KubeMarkMasqChain))...) chosenChain = svcChain } if len(svcInfo.loadBalancerSourceRanges) == 0 { // allow all sources, so jump directly to the KUBE-SVC or KUBE-XLB chain - writeLine(natRules, append(args, "-j", string(chosenChain))...) + writeLine(proxier.natRules, append(args, "-j", string(chosenChain))...) } else { // firewall filter based on each source range allowFromNode := false for _, src := range svcInfo.loadBalancerSourceRanges { - writeLine(natRules, append(args, "-s", src, "-j", string(chosenChain))...) + writeLine(proxier.natRules, append(args, "-s", src, "-j", string(chosenChain))...) // ignore error because it has been validated _, cidr, _ := net.ParseCIDR(src) if cidr.Contains(proxier.nodeIP) { @@ -1228,13 +1245,13 @@ func (proxier *Proxier) syncProxyRules(reason syncReason) { // loadbalancer's backend hosts. In this case, request will not hit the loadbalancer but loop back directly. // Need to add the following rule to allow request on host. if allowFromNode { - writeLine(natRules, append(args, "-s", fmt.Sprintf("%s/32", ingress.IP), "-j", string(chosenChain))...) + writeLine(proxier.natRules, append(args, "-s", fmt.Sprintf("%s/32", ingress.IP), "-j", string(chosenChain))...) } } // If the packet was able to reach the end of firewall chain, then it did not get DNATed. // It means the packet cannot go thru the firewall, then mark it for DROP - writeLine(natRules, append(args, "-j", string(KubeMarkDropChain))...) + writeLine(proxier.natRules, append(args, "-j", string(KubeMarkDropChain))...) } } @@ -1273,13 +1290,13 @@ func (proxier *Proxier) syncProxyRules(reason syncReason) { } if !svcInfo.onlyNodeLocalEndpoints { // Nodeports need SNAT, unless they're local. - writeLine(natRules, append(args, "-j", string(KubeMarkMasqChain))...) + writeLine(proxier.natRules, append(args, "-j", string(KubeMarkMasqChain))...) // Jump to the service chain. - writeLine(natRules, append(args, "-j", string(svcChain))...) + writeLine(proxier.natRules, append(args, "-j", string(svcChain))...) } else { // TODO: Make all nodePorts jump to the firewall chain. // Currently we only create it for loadbalancers (#33586). - writeLine(natRules, append(args, "-j", string(svcXlbChain))...) + writeLine(proxier.natRules, append(args, "-j", string(svcXlbChain))...) } // If the service has no endpoints then reject packets. The filter @@ -1287,7 +1304,7 @@ func (proxier *Proxier) syncProxyRules(reason syncReason) { // the nat table does, so we just stick this into the kube-services // chain. if len(proxier.endpointsMap[svcName]) == 0 { - writeLine(filterRules, + writeLine(proxier.filterRules, "-A", string(kubeServicesChain), "-m", "comment", "--comment", fmt.Sprintf(`"%s has no endpoints"`, svcNameString), "-m", "addrtype", "--dst-type", "LOCAL", @@ -1300,7 +1317,7 @@ func (proxier *Proxier) syncProxyRules(reason syncReason) { // If the service has no endpoints then reject packets. if len(proxier.endpointsMap[svcName]) == 0 { - writeLine(filterRules, + writeLine(proxier.filterRules, "-A", string(kubeServicesChain), "-m", "comment", "--comment", fmt.Sprintf(`"%s has no endpoints"`, svcNameString), "-m", protocol, "-p", protocol, @@ -1325,9 +1342,9 @@ func (proxier *Proxier) syncProxyRules(reason syncReason) { // Create the endpoint chain, retaining counters if possible. if chain, ok := existingNATChains[utiliptables.Chain(endpointChain)]; ok { - writeLine(natChains, chain) + writeLine(proxier.natChains, chain) } else { - writeLine(natChains, utiliptables.MakeChainLine(endpointChain)) + writeLine(proxier.natChains, utiliptables.MakeChainLine(endpointChain)) } activeNATChains[endpointChain] = true } @@ -1335,7 +1352,7 @@ func (proxier *Proxier) syncProxyRules(reason syncReason) { // First write session affinity rules, if applicable. if svcInfo.sessionAffinityType == api.ServiceAffinityClientIP { for _, endpointChain := range endpointChains { - writeLine(natRules, + writeLine(proxier.natRules, "-A", string(svcChain), "-m", "comment", "--comment", svcNameString, "-m", "recent", "--name", string(endpointChain), @@ -1361,7 +1378,7 @@ func (proxier *Proxier) syncProxyRules(reason syncReason) { } // The final (or only if n == 1) rule is a guaranteed match. args = append(args, "-j", string(endpointChain)) - writeLine(natRules, args...) + writeLine(proxier.natRules, args...) // Rules in the per-endpoint chain. args = []string{ @@ -1369,7 +1386,7 @@ func (proxier *Proxier) syncProxyRules(reason syncReason) { "-m", "comment", "--comment", svcNameString, } // Handle traffic that loops back to the originator with SNAT. - writeLine(natRules, append(args, + writeLine(proxier.natRules, append(args, "-s", fmt.Sprintf("%s/32", strings.Split(endpoints[i].endpoint, ":")[0]), "-j", string(KubeMarkMasqChain))...) // Update client-affinity lists. @@ -1378,7 +1395,7 @@ func (proxier *Proxier) syncProxyRules(reason syncReason) { } // DNAT to final destination. args = append(args, "-m", protocol, "-p", protocol, "-j", "DNAT", "--to-destination", endpoints[i].endpoint) - writeLine(natRules, args...) + writeLine(proxier.natRules, args...) } // The logic below this applies only if this service is marked as OnlyLocal @@ -1408,7 +1425,7 @@ func (proxier *Proxier) syncProxyRules(reason syncReason) { "-s", proxier.clusterCIDR, "-j", string(svcChain), } - writeLine(natRules, args...) + writeLine(proxier.natRules, args...) } numLocalEndpoints := len(localEndpointChains) @@ -1421,7 +1438,7 @@ func (proxier *Proxier) syncProxyRules(reason syncReason) { "-j", string(KubeMarkDropChain), } - writeLine(natRules, args...) + writeLine(proxier.natRules, args...) } else { // Setup probability filter rules only over local endpoints for i, endpointChain := range localEndpointChains { @@ -1440,7 +1457,7 @@ func (proxier *Proxier) syncProxyRules(reason syncReason) { } // The final (or only if n == 1) rule is a guaranteed match. args = append(args, "-j", string(endpointChain)) - writeLine(natRules, args...) + writeLine(proxier.natRules, args...) } } } @@ -1456,33 +1473,37 @@ func (proxier *Proxier) syncProxyRules(reason syncReason) { // We must (as per iptables) write a chain-line for it, which has // the nice effect of flushing the chain. Then we can remove the // chain. - writeLine(natChains, existingNATChains[chain]) - writeLine(natRules, "-X", chainString) + writeLine(proxier.natChains, existingNATChains[chain]) + writeLine(proxier.natRules, "-X", chainString) } } // Finally, tail-call to the nodeports chain. This needs to be after all // other service portal rules. - writeLine(natRules, + writeLine(proxier.natRules, "-A", string(kubeServicesChain), "-m", "comment", "--comment", `"kubernetes service nodeports; NOTE: this must be the last rule in this chain"`, "-m", "addrtype", "--dst-type", "LOCAL", "-j", string(kubeNodePortsChain)) // Write the end-of-table markers. - writeLine(filterRules, "COMMIT") - writeLine(natRules, "COMMIT") + writeLine(proxier.filterRules, "COMMIT") + writeLine(proxier.natRules, "COMMIT") // Sync rules. - // NOTE: NoFlushTables is used so we don't flush non-kubernetes chains in the table. - filterLines := append(filterChains.Bytes(), filterRules.Bytes()...) - natLines := append(natChains.Bytes(), natRules.Bytes()...) - lines := append(filterLines, natLines...) + // NOTE: NoFlushTables is used so we don't flush non-kubernetes chains in the table + proxier.iptablesData.Reset() + proxier.iptablesData.Write(proxier.filterChains.Bytes()) + proxier.iptablesData.Write(proxier.filterRules.Bytes()) + proxier.iptablesData.Write(proxier.natChains.Bytes()) + proxier.iptablesData.Write(proxier.natRules.Bytes()) - glog.V(3).Infof("Restoring iptables rules: %s", lines) - err = proxier.iptables.RestoreAll(lines, utiliptables.NoFlushTables, utiliptables.RestoreCounters) + if glog.V(5) { + glog.V(5).Infof("Restoring iptables rules: %s", proxier.iptablesData.Bytes()) + } + err = proxier.iptables.RestoreAll(proxier.iptablesData.Bytes(), utiliptables.NoFlushTables, utiliptables.RestoreCounters) if err != nil { - glog.Errorf("Failed to execute iptables-restore: %v\nRules:\n%s", err, lines) + glog.Errorf("Failed to execute iptables-restore: %v\nRules:\n%s", err, proxier.iptablesData.Bytes()) // Revert new local ports. revertPorts(replacementPortsMap, proxier.portsMap) return @@ -1536,7 +1557,15 @@ func (proxier *Proxier) clearUDPConntrackForPort(port int) { // Join all words with spaces, terminate with newline and write to buf. func writeLine(buf *bytes.Buffer, words ...string) { - buf.WriteString(strings.Join(words, " ") + "\n") + // We avoid strings.Join for performance reasons. + for i := range words { + buf.WriteString(words[i]) + if i < len(words)-1 { + buf.WriteByte(' ') + } else { + buf.WriteByte('\n') + } + } } func isLocalIP(ip string) (bool, error) { diff --git a/pkg/proxy/iptables/proxier_test.go b/pkg/proxy/iptables/proxier_test.go index b49b1929b8..aba2e3c827 100644 --- a/pkg/proxy/iptables/proxier_test.go +++ b/pkg/proxy/iptables/proxier_test.go @@ -17,6 +17,7 @@ limitations under the License. package iptables import ( + "bytes" "reflect" "strconv" "testing" @@ -394,6 +395,11 @@ func NewFakeProxier(ipt utiliptables.Interface) *Proxier { portsMap: make(map[localPort]closeable), portMapper: &fakePortOpener{[]*localPort{}}, healthChecker: newFakeHealthChecker(), + iptablesData: bytes.NewBuffer(nil), + filterChains: bytes.NewBuffer(nil), + filterRules: bytes.NewBuffer(nil), + natChains: bytes.NewBuffer(nil), + natRules: bytes.NewBuffer(nil), } } diff --git a/pkg/util/exec/exec.go b/pkg/util/exec/exec.go index 327ddf5bce..f43bfa7a17 100644 --- a/pkg/util/exec/exec.go +++ b/pkg/util/exec/exec.go @@ -41,6 +41,8 @@ type Interface interface { // As more functionality is needed, this can grow. Since Cmd is a struct, we will have // to replace fields with get/set method pairs. type Cmd interface { + // Run runs the command to the completion. + Run() error // CombinedOutput runs the command and returns its combined standard output // and standard error. This follows the pattern of package os/exec. CombinedOutput() ([]byte, error) @@ -49,6 +51,7 @@ type Cmd interface { SetDir(dir string) SetStdin(in io.Reader) SetStdout(out io.Writer) + SetStderr(out io.Writer) // Stops the command by sending SIGTERM. It is not guaranteed the // process will stop before this function returns. If the process is not // responding, an internal timer function will send a SIGKILL to force @@ -99,6 +102,15 @@ func (cmd *cmdWrapper) SetStdout(out io.Writer) { cmd.Stdout = out } +func (cmd *cmdWrapper) SetStderr(out io.Writer) { + cmd.Stderr = out +} + +// Run is part of the Cmd interface. +func (cmd *cmdWrapper) Run() error { + return (*osexec.Cmd)(cmd).Run() +} + // CombinedOutput is part of the Cmd interface. func (cmd *cmdWrapper) CombinedOutput() ([]byte, error) { out, err := (*osexec.Cmd)(cmd).CombinedOutput() diff --git a/pkg/util/exec/fake_exec.go b/pkg/util/exec/fake_exec.go index b87265099a..e3741dca42 100644 --- a/pkg/util/exec/fake_exec.go +++ b/pkg/util/exec/fake_exec.go @@ -52,6 +52,7 @@ type FakeCmd struct { Dirs []string Stdin io.Reader Stdout io.Writer + Stderr io.Writer } func InitFakeCmd(fake *FakeCmd, cmd string, args ...string) Cmd { @@ -73,6 +74,14 @@ func (fake *FakeCmd) SetStdout(out io.Writer) { fake.Stdout = out } +func (fake *FakeCmd) SetStderr(out io.Writer) { + fake.Stderr = out +} + +func (fake *FakeCmd) Run() error { + return fmt.Errorf("unimplemented") +} + func (fake *FakeCmd) CombinedOutput() ([]byte, error) { if fake.CombinedOutputCalls > len(fake.CombinedOutputScript)-1 { panic("ran out of CombinedOutput() actions") diff --git a/pkg/util/iptables/iptables.go b/pkg/util/iptables/iptables.go index 2d3c24dcac..57edf5f1e5 100644 --- a/pkg/util/iptables/iptables.go +++ b/pkg/util/iptables/iptables.go @@ -56,8 +56,8 @@ type Interface interface { IsIpv6() bool // Save calls `iptables-save` for table. Save(table Table) ([]byte, error) - // SaveAll calls `iptables-save`. - SaveAll() ([]byte, error) + // SaveInto calls `iptables-save` for table and stores result in a given buffer. + SaveInto(table Table, buffer *bytes.Buffer) error // Restore runs `iptables-restore` passing data through []byte. // table is the Table to restore // data should be formatted like the output of Save() @@ -317,14 +317,23 @@ func (runner *runner) Save(table Table) ([]byte, error) { return runner.exec.Command(cmdIPTablesSave, args...).CombinedOutput() } -// SaveAll is part of Interface. -func (runner *runner) SaveAll() ([]byte, error) { +// SaveInto is part of Interface. +func (runner *runner) SaveInto(table Table, buffer *bytes.Buffer) error { runner.mu.Lock() defer runner.mu.Unlock() // run and return - glog.V(4).Infof("running iptables-save") - return runner.exec.Command(cmdIPTablesSave, []string{}...).CombinedOutput() + args := []string{"-t", string(table)} + glog.V(4).Infof("running iptables-save %v", args) + cmd := runner.exec.Command(cmdIPTablesSave, args...) + // Since CombinedOutput() doesn't support redirecting it to a buffer, + // we need to workaround it by redirecting stdout and stderr to buffer + // and explicitly calling Run() [CombinedOutput() underneath itself + // creates a new buffer, redirects stdout and stderr to it and also + // calls Run()]. + cmd.SetStdout(buffer) + cmd.SetStderr(buffer) + return cmd.Run() } // Restore is part of Interface. @@ -393,7 +402,7 @@ func (runner *runner) run(op operation, args []string) ([]byte, error) { fullArgs := append(runner.waitFlag, string(op)) fullArgs = append(fullArgs, args...) - glog.V(4).Infof("running iptables %s %v", string(op), args) + glog.V(5).Infof("running iptables %s %v", string(op), args) return runner.exec.Command(iptablesCmd, fullArgs...).CombinedOutput() // Don't log err here - callers might not think it is an error. } diff --git a/pkg/util/iptables/iptables_test.go b/pkg/util/iptables/iptables_test.go index 5fb921d522..62b8416709 100644 --- a/pkg/util/iptables/iptables_test.go +++ b/pkg/util/iptables/iptables_test.go @@ -884,59 +884,6 @@ COMMIT } } -func TestSaveAll(t *testing.T) { - output := `# Generated by iptables-save v1.6.0 on Thu Jan 19 11:38:09 2017 -*filter -:INPUT ACCEPT [15079:38410730] -:FORWARD ACCEPT [0:0] -:OUTPUT ACCEPT [11045:521562] -COMMIT -# Completed on Thu Jan 19 11:38:09 2017` - - fcmd := exec.FakeCmd{ - CombinedOutputScript: []exec.FakeCombinedOutputAction{ - // iptables version check - func() ([]byte, error) { return []byte("iptables v1.9.22"), nil }, - // iptables-restore version check - func() ([]byte, error) { return []byte("iptables-restore v1.9.22"), nil }, - func() ([]byte, error) { return []byte(output), nil }, - func() ([]byte, error) { return nil, &exec.FakeExitError{Status: 1} }, - }, - } - fexec := exec.FakeExec{ - CommandScript: []exec.FakeCommandAction{ - func(cmd string, args ...string) exec.Cmd { return exec.InitFakeCmd(&fcmd, cmd, args...) }, - func(cmd string, args ...string) exec.Cmd { return exec.InitFakeCmd(&fcmd, cmd, args...) }, - func(cmd string, args ...string) exec.Cmd { return exec.InitFakeCmd(&fcmd, cmd, args...) }, - func(cmd string, args ...string) exec.Cmd { return exec.InitFakeCmd(&fcmd, cmd, args...) }, - }, - } - runner := New(&fexec, dbus.NewFake(nil, nil), ProtocolIpv4) - defer runner.Destroy() - // Success. - o, err := runner.SaveAll() - if err != nil { - t.Fatalf("expected success, got %v", err) - } - - if string(o[:len(output)]) != output { - t.Errorf("expected output to be equal to mocked one, got %v", o) - } - - if fcmd.CombinedOutputCalls != 3 { - t.Errorf("expected 3 CombinedOutput() calls, got %d", fcmd.CombinedOutputCalls) - } - if !sets.NewString(fcmd.CombinedOutputLog[2]...).HasAll("iptables-save") { - t.Errorf("wrong CombinedOutput() log, got %s", fcmd.CombinedOutputLog[2]) - } - - // Failure. - _, err = runner.SaveAll() - if err == nil { - t.Errorf("expected failure") - } -} - func TestRestore(t *testing.T) { fcmd := exec.FakeCmd{ CombinedOutputScript: []exec.FakeCombinedOutputAction{ diff --git a/pkg/util/iptables/testing/fake.go b/pkg/util/iptables/testing/fake.go index 16cd90ba30..8d9ac7c070 100644 --- a/pkg/util/iptables/testing/fake.go +++ b/pkg/util/iptables/testing/fake.go @@ -17,6 +17,7 @@ limitations under the License. package testing import ( + "bytes" "fmt" "strings" @@ -78,8 +79,9 @@ func (f *FakeIPTables) Save(table iptables.Table) ([]byte, error) { return lines, nil } -func (*FakeIPTables) SaveAll() ([]byte, error) { - return make([]byte, 0), nil +func (f *FakeIPTables) SaveInto(table iptables.Table, buffer *bytes.Buffer) error { + buffer.Write(f.Lines) + return nil } func (*FakeIPTables) Restore(table iptables.Table, data []byte, flush iptables.FlushFlag, counters iptables.RestoreCountersFlag) error {