/* Copyright 2015 The Kubernetes Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ package tests import ( "bytes" "fmt" "io" "net" "net/http" "net/http/httptest" "net/url" "os" "strings" "sync" "testing" "time" "k8s.io/apimachinery/pkg/types" restclient "k8s.io/client-go/rest" . "k8s.io/client-go/tools/portforward" "k8s.io/client-go/transport/spdy" "k8s.io/kubernetes/pkg/kubelet/server/portforward" ) // fakePortForwarder simulates port forwarding for testing. It implements // portforward.PortForwarder. type fakePortForwarder struct { lock sync.Mutex // stores data expected from the stream per port expected map[int32]string // stores data received from the stream per port received map[int32]string // data to be sent to the stream per port send map[int32]string } var _ portforward.PortForwarder = &fakePortForwarder{} func (pf *fakePortForwarder) PortForward(name string, uid types.UID, port int32, stream io.ReadWriteCloser) error { defer stream.Close() // read from the client received := make([]byte, len(pf.expected[port])) n, err := stream.Read(received) if err != nil { return fmt.Errorf("error reading from client for port %d: %v", port, err) } if n != len(pf.expected[port]) { return fmt.Errorf("unexpected length read from client for port %d: got %d, expected %d. data=%q", port, n, len(pf.expected[port]), string(received)) } // store the received content pf.lock.Lock() pf.received[port] = string(received) pf.lock.Unlock() // send the hardcoded data to the client io.Copy(stream, strings.NewReader(pf.send[port])) return nil } // fakePortForwardServer creates an HTTP server that can handle port forwarding // requests. func fakePortForwardServer(t *testing.T, testName string, serverSends, expectedFromClient map[int32]string) http.HandlerFunc { return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { pf := &fakePortForwarder{ expected: expectedFromClient, received: make(map[int32]string), send: serverSends, } portforward.ServePortForward(w, req, pf, "pod", "uid", nil, 0, 10*time.Second, portforward.SupportedProtocols) for port, expected := range expectedFromClient { actual, ok := pf.received[port] if !ok { t.Errorf("%s: server didn't receive any data for port %d", testName, port) continue } if expected != actual { t.Errorf("%s: server expected to receive %q, got %q for port %d", testName, expected, actual, port) } } for port, actual := range pf.received { if _, ok := expectedFromClient[port]; !ok { t.Errorf("%s: server unexpectedly received %q for port %d", testName, actual, port) } } }) } func TestForwardPorts(t *testing.T) { tests := map[string]struct { ports []string clientSends map[int32]string serverSends map[int32]string }{ "forward 1 port with no data either direction": { ports: []string{"5000"}, }, "forward 2 ports with bidirectional data": { ports: []string{"5001", "6000"}, clientSends: map[int32]string{ 5001: "abcd", 6000: "ghij", }, serverSends: map[int32]string{ 5001: "1234", 6000: "5678", }, }, } for testName, test := range tests { server := httptest.NewServer(fakePortForwardServer(t, testName, test.serverSends, test.clientSends)) transport, upgrader, err := spdy.RoundTripperFor(&restclient.Config{}) if err != nil { t.Fatal(err) } url, _ := url.Parse(server.URL) dialer := spdy.NewDialer(upgrader, &http.Client{Transport: transport}, "POST", url) stopChan := make(chan struct{}, 1) readyChan := make(chan struct{}) pf, err := New(dialer, test.ports, stopChan, readyChan, os.Stdout, os.Stderr) if err != nil { t.Fatalf("%s: unexpected error calling New: %v", testName, err) } doneChan := make(chan error) go func() { doneChan <- pf.ForwardPorts() }() <-pf.Ready for port, data := range test.clientSends { clientConn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) if err != nil { t.Errorf("%s: error dialing %d: %s", testName, port, err) server.Close() continue } defer clientConn.Close() n, err := clientConn.Write([]byte(data)) if err != nil && err != io.EOF { t.Errorf("%s: Error sending data '%s': %s", testName, data, err) server.Close() continue } if n == 0 { t.Errorf("%s: unexpected write of 0 bytes", testName) server.Close() continue } b := make([]byte, 4) n, err = clientConn.Read(b) if err != nil && err != io.EOF { t.Errorf("%s: Error reading data: %s", testName, err) server.Close() continue } if !bytes.Equal([]byte(test.serverSends[port]), b) { t.Errorf("%s: expected to read '%s', got '%s'", testName, test.serverSends[port], b) server.Close() continue } } // tell r.ForwardPorts to stop close(stopChan) // wait for r.ForwardPorts to actually return err = <-doneChan if err != nil { t.Errorf("%s: unexpected error: %s", testName, err) } server.Close() } } func TestForwardPortsReturnsErrorWhenAllBindsFailed(t *testing.T) { server := httptest.NewServer(fakePortForwardServer(t, "allBindsFailed", nil, nil)) defer server.Close() transport, upgrader, err := spdy.RoundTripperFor(&restclient.Config{}) if err != nil { t.Fatal(err) } url, _ := url.Parse(server.URL) dialer := spdy.NewDialer(upgrader, &http.Client{Transport: transport}, "POST", url) stopChan1 := make(chan struct{}, 1) defer close(stopChan1) readyChan1 := make(chan struct{}) pf1, err := New(dialer, []string{"5555"}, stopChan1, readyChan1, os.Stdout, os.Stderr) if err != nil { t.Fatalf("error creating pf1: %v", err) } go pf1.ForwardPorts() <-pf1.Ready stopChan2 := make(chan struct{}, 1) readyChan2 := make(chan struct{}) pf2, err := New(dialer, []string{"5555"}, stopChan2, readyChan2, os.Stdout, os.Stderr) if err != nil { t.Fatalf("error creating pf2: %v", err) } if err := pf2.ForwardPorts(); err == nil { t.Fatal("expected non-nil error for pf2.ForwardPorts") } }