// Mirror of https://github.com/k3s-io/k3s — 233 lines, 6.4 KiB, Go.
/*
|
|
Copyright 2015 The Kubernetes Authors.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
*/
|
|
|
|
package tests
|
|
|
|
import (
|
|
"bytes"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/url"
|
|
"os"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"k8s.io/apimachinery/pkg/types"
|
|
restclient "k8s.io/client-go/rest"
|
|
. "k8s.io/client-go/tools/portforward"
|
|
"k8s.io/client-go/transport/spdy"
|
|
"k8s.io/kubernetes/pkg/kubelet/server/portforward"
|
|
)
|
|
|
|
// fakePortForwarder simulates port forwarding for testing. It implements
|
|
// portforward.PortForwarder.
|
|
type fakePortForwarder struct {
|
|
lock sync.Mutex
|
|
// stores data expected from the stream per port
|
|
expected map[int32]string
|
|
// stores data received from the stream per port
|
|
received map[int32]string
|
|
// data to be sent to the stream per port
|
|
send map[int32]string
|
|
}
|
|
|
|
var _ portforward.PortForwarder = &fakePortForwarder{}
|
|
|
|
func (pf *fakePortForwarder) PortForward(name string, uid types.UID, port int32, stream io.ReadWriteCloser) error {
|
|
defer stream.Close()
|
|
|
|
// read from the client
|
|
received := make([]byte, len(pf.expected[port]))
|
|
n, err := stream.Read(received)
|
|
if err != nil {
|
|
return fmt.Errorf("error reading from client for port %d: %v", port, err)
|
|
}
|
|
if n != len(pf.expected[port]) {
|
|
return fmt.Errorf("unexpected length read from client for port %d: got %d, expected %d. data=%q", port, n, len(pf.expected[port]), string(received))
|
|
}
|
|
|
|
// store the received content
|
|
pf.lock.Lock()
|
|
pf.received[port] = string(received)
|
|
pf.lock.Unlock()
|
|
|
|
// send the hardcoded data to the client
|
|
io.Copy(stream, strings.NewReader(pf.send[port]))
|
|
|
|
return nil
|
|
}
|
|
|
|
// fakePortForwardServer creates an HTTP server that can handle port forwarding
|
|
// requests.
|
|
func fakePortForwardServer(t *testing.T, testName string, serverSends, expectedFromClient map[int32]string) http.HandlerFunc {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
|
pf := &fakePortForwarder{
|
|
expected: expectedFromClient,
|
|
received: make(map[int32]string),
|
|
send: serverSends,
|
|
}
|
|
portforward.ServePortForward(w, req, pf, "pod", "uid", nil, 0, 10*time.Second, portforward.SupportedProtocols)
|
|
|
|
for port, expected := range expectedFromClient {
|
|
actual, ok := pf.received[port]
|
|
if !ok {
|
|
t.Errorf("%s: server didn't receive any data for port %d", testName, port)
|
|
continue
|
|
}
|
|
|
|
if expected != actual {
|
|
t.Errorf("%s: server expected to receive %q, got %q for port %d", testName, expected, actual, port)
|
|
}
|
|
}
|
|
|
|
for port, actual := range pf.received {
|
|
if _, ok := expectedFromClient[port]; !ok {
|
|
t.Errorf("%s: server unexpectedly received %q for port %d", testName, actual, port)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestForwardPorts(t *testing.T) {
|
|
tests := map[string]struct {
|
|
ports []string
|
|
clientSends map[int32]string
|
|
serverSends map[int32]string
|
|
}{
|
|
"forward 1 port with no data either direction": {
|
|
ports: []string{"5000"},
|
|
},
|
|
"forward 2 ports with bidirectional data": {
|
|
ports: []string{"5001", "6000"},
|
|
clientSends: map[int32]string{
|
|
5001: "abcd",
|
|
6000: "ghij",
|
|
},
|
|
serverSends: map[int32]string{
|
|
5001: "1234",
|
|
6000: "5678",
|
|
},
|
|
},
|
|
}
|
|
|
|
for testName, test := range tests {
|
|
server := httptest.NewServer(fakePortForwardServer(t, testName, test.serverSends, test.clientSends))
|
|
|
|
transport, upgrader, err := spdy.RoundTripperFor(&restclient.Config{})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
url, _ := url.Parse(server.URL)
|
|
dialer := spdy.NewDialer(upgrader, &http.Client{Transport: transport}, "POST", url)
|
|
|
|
stopChan := make(chan struct{}, 1)
|
|
readyChan := make(chan struct{})
|
|
|
|
pf, err := New(dialer, test.ports, stopChan, readyChan, os.Stdout, os.Stderr)
|
|
if err != nil {
|
|
t.Fatalf("%s: unexpected error calling New: %v", testName, err)
|
|
}
|
|
|
|
doneChan := make(chan error)
|
|
go func() {
|
|
doneChan <- pf.ForwardPorts()
|
|
}()
|
|
<-pf.Ready
|
|
|
|
for port, data := range test.clientSends {
|
|
clientConn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
|
|
if err != nil {
|
|
t.Errorf("%s: error dialing %d: %s", testName, port, err)
|
|
server.Close()
|
|
continue
|
|
}
|
|
defer clientConn.Close()
|
|
|
|
n, err := clientConn.Write([]byte(data))
|
|
if err != nil && err != io.EOF {
|
|
t.Errorf("%s: Error sending data '%s': %s", testName, data, err)
|
|
server.Close()
|
|
continue
|
|
}
|
|
if n == 0 {
|
|
t.Errorf("%s: unexpected write of 0 bytes", testName)
|
|
server.Close()
|
|
continue
|
|
}
|
|
b := make([]byte, 4)
|
|
n, err = clientConn.Read(b)
|
|
if err != nil && err != io.EOF {
|
|
t.Errorf("%s: Error reading data: %s", testName, err)
|
|
server.Close()
|
|
continue
|
|
}
|
|
if !bytes.Equal([]byte(test.serverSends[port]), b) {
|
|
t.Errorf("%s: expected to read '%s', got '%s'", testName, test.serverSends[port], b)
|
|
server.Close()
|
|
continue
|
|
}
|
|
}
|
|
// tell r.ForwardPorts to stop
|
|
close(stopChan)
|
|
|
|
// wait for r.ForwardPorts to actually return
|
|
err = <-doneChan
|
|
if err != nil {
|
|
t.Errorf("%s: unexpected error: %s", testName, err)
|
|
}
|
|
server.Close()
|
|
}
|
|
|
|
}
|
|
|
|
func TestForwardPortsReturnsErrorWhenAllBindsFailed(t *testing.T) {
|
|
server := httptest.NewServer(fakePortForwardServer(t, "allBindsFailed", nil, nil))
|
|
defer server.Close()
|
|
|
|
transport, upgrader, err := spdy.RoundTripperFor(&restclient.Config{})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
url, _ := url.Parse(server.URL)
|
|
dialer := spdy.NewDialer(upgrader, &http.Client{Transport: transport}, "POST", url)
|
|
|
|
stopChan1 := make(chan struct{}, 1)
|
|
defer close(stopChan1)
|
|
readyChan1 := make(chan struct{})
|
|
|
|
pf1, err := New(dialer, []string{"5555"}, stopChan1, readyChan1, os.Stdout, os.Stderr)
|
|
if err != nil {
|
|
t.Fatalf("error creating pf1: %v", err)
|
|
}
|
|
go pf1.ForwardPorts()
|
|
<-pf1.Ready
|
|
|
|
stopChan2 := make(chan struct{}, 1)
|
|
readyChan2 := make(chan struct{})
|
|
pf2, err := New(dialer, []string{"5555"}, stopChan2, readyChan2, os.Stdout, os.Stderr)
|
|
if err != nil {
|
|
t.Fatalf("error creating pf2: %v", err)
|
|
}
|
|
if err := pf2.ForwardPorts(); err == nil {
|
|
t.Fatal("expected non-nil error for pf2.ForwardPorts")
|
|
}
|
|
}
|