diff --git a/pkg/api/types.go b/pkg/api/types.go index 163e511522..858feb2d50 100644 --- a/pkg/api/types.go +++ b/pkg/api/types.go @@ -1951,14 +1951,24 @@ const ( // Command to run for remote command execution ExecCommandParamm = "command" - StreamType = "streamType" - StreamTypeStdin = "stdin" + // Name of header that specifies stream type + StreamType = "streamType" + // Value for streamType header for stdin stream + StreamTypeStdin = "stdin" + // Value for streamType header for stdout stream StreamTypeStdout = "stdout" + // Value for streamType header for stderr stream StreamTypeStderr = "stderr" - StreamTypeData = "data" - StreamTypeError = "error" + // Value for streamType header for data stream + StreamTypeData = "data" + // Value for streamType header for error stream + StreamTypeError = "error" + // Name of header that specifies the port being forwarded PortHeader = "port" + // Name of header that specifies a request ID used to associate the error + // and data streams for a single forwarded connection + PortForwardRequestIDHeader = "requestID" ) // Similarly to above, these are constants to support HTTP PATCH utilized by diff --git a/pkg/client/unversioned/portforward/portforward.go b/pkg/client/unversioned/portforward/portforward.go index 5efc21565d..925c5de402 100644 --- a/pkg/client/unversioned/portforward/portforward.go +++ b/pkg/client/unversioned/portforward/portforward.go @@ -25,10 +25,12 @@ import ( "net/http" "strconv" "strings" + "sync" "github.com/golang/glog" "k8s.io/kubernetes/pkg/api" client "k8s.io/kubernetes/pkg/client/unversioned" + "k8s.io/kubernetes/pkg/util" "k8s.io/kubernetes/pkg/util/httpstream" "k8s.io/kubernetes/pkg/util/httpstream/spdy" ) @@ -51,10 +53,12 @@ type PortForwarder struct { ports []ForwardedPort stopChan <-chan struct{} - streamConn httpstream.Connection - listeners []io.Closer - upgrader upgrader - Ready chan struct{} + streamConn httpstream.Connection + listeners []io.Closer + upgrader upgrader + Ready chan struct{} + requestIDLock sync.Mutex + requestID int } // ForwardedPort contains a Local:Remote port pairing. @@ -145,7 +149,7 @@ func (pf *PortForwarder) ForwardPorts() error { var err error pf.streamConn, err = pf.upgrader.upgrade(pf.req, pf.config) if err != nil { - return fmt.Errorf("Error upgrading connection: %s", err) + return fmt.Errorf("error upgrading connection: %s", err) } defer pf.streamConn.Close() @@ -179,7 +183,7 @@ func (pf *PortForwarder) forward() error { select { case <-pf.stopChan: case <-pf.streamConn.CloseChan(): - glog.Errorf("Lost connection to pod") + util.HandleError(errors.New("lost connection to pod")) } return nil @@ -213,7 +217,7 @@ func (pf *PortForwarder) listenOnPortAndAddress(port *ForwardedPort, protocol st func (pf *PortForwarder) getListener(protocol string, hostname string, port *ForwardedPort) (net.Listener, error) { listener, err := net.Listen(protocol, fmt.Sprintf("%s:%d", hostname, port.Local)) if err != nil { - glog.Errorf("Unable to create listener: Error %s", err) + util.HandleError(fmt.Errorf("Unable to create listener: Error %s", err)) return nil, err } listenerAddress := listener.Addr().String() @@ -237,7 +241,7 @@ func (pf *PortForwarder) waitForConnection(listener net.Listener, port Forwarded if err != nil { // TODO consider using something like https://github.com/hydrogen18/stoppableListener? if !strings.Contains(strings.ToLower(err.Error()), "use of closed network connection") { - glog.Errorf("Error accepting connection on port %d: %v", port.Local, err) + util.HandleError(fmt.Errorf("Error accepting connection on port %d: %v", port.Local, err)) } return } @@ -245,6 +249,14 @@ func (pf *PortForwarder) waitForConnection(listener net.Listener, port Forwarded } } +func (pf *PortForwarder) nextRequestID() int { + pf.requestIDLock.Lock() + defer pf.requestIDLock.Unlock() + id := pf.requestID + pf.requestID++ + return id +} + // handleConnection copies data between the local connection and the stream to // the remote server. func (pf *PortForwarder) handleConnection(conn net.Conn, port ForwardedPort) { @@ -252,65 +264,76 @@ func (pf *PortForwarder) handleConnection(conn net.Conn, port ForwardedPort) { glog.Infof("Handling connection for %d", port.Local) - errorChan := make(chan error) - doneChan := make(chan struct{}, 2) + requestID := pf.nextRequestID() // create error stream headers := http.Header{} headers.Set(api.StreamType, api.StreamTypeError) headers.Set(api.PortHeader, fmt.Sprintf("%d", port.Remote)) + headers.Set(api.PortForwardRequestIDHeader, strconv.Itoa(requestID)) errorStream, err := pf.streamConn.CreateStream(headers) if err != nil { - glog.Errorf("Error creating error stream for port %d -> %d: %v", port.Local, port.Remote, err) + util.HandleError(fmt.Errorf("error creating error stream for port %d -> %d: %v", port.Local, port.Remote, err)) return } - defer errorStream.Reset() + // we're not writing to this stream + errorStream.Close() + + errorChan := make(chan error) go func() { message, err := ioutil.ReadAll(errorStream) - if err != nil && err != io.EOF { - errorChan <- fmt.Errorf("Error reading from error stream for port %d -> %d: %v", port.Local, port.Remote, err) - } - if len(message) > 0 { - errorChan <- fmt.Errorf("An error occurred forwarding %d -> %d: %v", port.Local, port.Remote, string(message)) + switch { + case err != nil: + errorChan <- fmt.Errorf("error reading from error stream for port %d -> %d: %v", port.Local, port.Remote, err) + case len(message) > 0: + errorChan <- fmt.Errorf("an error occurred forwarding %d -> %d: %v", port.Local, port.Remote, string(message)) } + close(errorChan) }() // create data stream headers.Set(api.StreamType, api.StreamTypeData) dataStream, err := pf.streamConn.CreateStream(headers) if err != nil { - glog.Errorf("Error creating forwarding stream for port %d -> %d: %v", port.Local, port.Remote, err) + util.HandleError(fmt.Errorf("error creating forwarding stream for port %d -> %d: %v", port.Local, port.Remote, err)) return } - // Send a Reset when this function exits to completely tear down the stream here - // and in the remote server. - defer dataStream.Reset() + + localError := make(chan struct{}) + remoteDone := make(chan struct{}) go func() { - // Copy from the remote side to the local port. We won't get an EOF from - // the server as it has no way of knowing when to close the stream. We'll - // take care of closing both ends of the stream with the call to - // stream.Reset() when this function exits. - if _, err := io.Copy(conn, dataStream); err != nil && err != io.EOF && !strings.Contains(err.Error(), "use of closed network connection") { - glog.Errorf("Error copying from remote stream to local connection: %v", err) + // Copy from the remote side to the local port. + if _, err := io.Copy(conn, dataStream); err != nil && !strings.Contains(err.Error(), "use of closed network connection") { + util.HandleError(fmt.Errorf("error copying from remote stream to local connection: %v", err)) } - doneChan <- struct{}{} + + // inform the select below that the remote copy is done + close(remoteDone) }() go func() { - // Copy from the local port to the remote side. Here we will be able to know - // when the Copy gets an EOF from conn, as that will happen as soon as conn is - // closed (i.e. client disconnected). - if _, err := io.Copy(dataStream, conn); err != nil && err != io.EOF && !strings.Contains(err.Error(), "use of closed network connection") { - glog.Errorf("Error copying from local connection to remote stream: %v", err) + // inform server we're not sending any more data after copy unblocks + defer dataStream.Close() + + // Copy from the local port to the remote side. + if _, err := io.Copy(dataStream, conn); err != nil && !strings.Contains(err.Error(), "use of closed network connection") { + util.HandleError(fmt.Errorf("error copying from local connection to remote stream: %v", err)) + // break out of the select below without waiting for the other copy to finish + close(localError) } - doneChan <- struct{}{} }() + // wait for either a local->remote error or for copying from remote->local to finish select { - case err := <-errorChan: - glog.Error(err) - case <-doneChan: + case <-remoteDone: + case <-localError: + } + + // always expect something on errorChan (it may be nil) + err = <-errorChan + if err != nil { + util.HandleError(err) } } @@ -318,7 +341,7 @@ func (pf *PortForwarder) Close() { // stop all listeners for _, l := range pf.listeners { if err := l.Close(); err != nil { - glog.Errorf("Error closing listener: %v", err) + util.HandleError(fmt.Errorf("error closing listener: %v", err)) } } } diff --git a/pkg/client/unversioned/portforward/portforward_test.go b/pkg/client/unversioned/portforward/portforward_test.go index f5e6c639fa..0bd2ed88ad 100644 --- a/pkg/client/unversioned/portforward/portforward_test.go +++ b/pkg/client/unversioned/portforward/portforward_test.go @@ -18,20 +18,21 @@ package portforward import ( "bytes" - "errors" "fmt" "io" "net" "net/http" + "net/http/httptest" + "net/url" "reflect" "strings" "sync" "testing" "time" - "k8s.io/kubernetes/pkg/api" client "k8s.io/kubernetes/pkg/client/unversioned" - "k8s.io/kubernetes/pkg/util/httpstream" + "k8s.io/kubernetes/pkg/kubelet" + "k8s.io/kubernetes/pkg/types" ) func TestParsePortsAndNew(t *testing.T) { @@ -110,109 +111,6 @@ func TestParsePortsAndNew(t *testing.T) { } } -type fakeUpgrader struct { - conn *fakeUpgradeConnection - err error -} - -func (u *fakeUpgrader) upgrade(req *client.Request, config *client.Config) (httpstream.Connection, error) { - return u.conn, u.err -} - -type fakeUpgradeConnection struct { - closeCalled bool - lock sync.Mutex - streams map[string]*fakeUpgradeStream - portData map[string]string -} - -func newFakeUpgradeConnection() *fakeUpgradeConnection { - return &fakeUpgradeConnection{ - streams: make(map[string]*fakeUpgradeStream), - portData: make(map[string]string), - } -} - -func (c *fakeUpgradeConnection) CreateStream(headers http.Header) (httpstream.Stream, error) { - c.lock.Lock() - defer c.lock.Unlock() - - stream := &fakeUpgradeStream{} - c.streams[headers.Get(api.PortHeader)] = stream - // only simulate data on the data stream for now, not the error stream - if headers.Get(api.StreamType) == api.StreamTypeData { - stream.data = c.portData[headers.Get(api.PortHeader)] - } - - return stream, nil -} - -func (c *fakeUpgradeConnection) Close() error { - c.lock.Lock() - defer c.lock.Unlock() - - c.closeCalled = true - return nil -} - -func (c *fakeUpgradeConnection) CloseChan() <-chan bool { - return make(chan bool) -} - -func (c *fakeUpgradeConnection) SetIdleTimeout(timeout time.Duration) { -} - -type fakeUpgradeStream struct { - readCalled bool - writeCalled bool - dataWritten []byte - closeCalled bool - resetCalled bool - data string - lock sync.Mutex -} - -func (s *fakeUpgradeStream) Read(p []byte) (int, error) { - s.lock.Lock() - defer s.lock.Unlock() - s.readCalled = true - b := []byte(s.data) - n := copy(p, b) - // Indicate we returned all the data, and have no more data (EOF) - // Returning an EOF here will cause the port forwarder to immediately terminate, which is correct when we have no more data to send - return n, io.EOF -} - -func (s *fakeUpgradeStream) Write(p []byte) (int, error) { - s.lock.Lock() - defer s.lock.Unlock() - s.writeCalled = true - s.dataWritten = append(s.dataWritten, p...) - // Indicate the stream accepted all the data, and can accept more (no err) - // Returning an EOF here will cause the port forwarder to immediately terminate, which is incorrect, in case someone writes more data - return len(p), nil -} - -func (s *fakeUpgradeStream) Close() error { - s.lock.Lock() - defer s.lock.Unlock() - s.closeCalled = true - return nil -} - -func (s *fakeUpgradeStream) Reset() error { - s.lock.Lock() - defer s.lock.Unlock() - s.resetCalled = true - return nil -} - -func (s *fakeUpgradeStream) Headers() http.Header { - s.lock.Lock() - defer s.lock.Unlock() - return http.Header{} -} - type GetListenerTestCase struct { Hostname string Protocol string @@ -295,55 +193,119 @@ func TestGetListener(t *testing.T) { } } +// fakePortForwarder simulates port forwarding for testing. It implements +// kubelet.PortForwarder. +type fakePortForwarder struct { + lock sync.Mutex + // stores data received from the stream per port + received map[uint16]string + // data to be sent to the stream per port + send map[uint16]string +} + +var _ kubelet.PortForwarder = &fakePortForwarder{} + +func (pf *fakePortForwarder) PortForward(name string, uid types.UID, port uint16, stream io.ReadWriteCloser) error { + defer stream.Close() + + var wg sync.WaitGroup + + // client -> server + wg.Add(1) + go func() { + defer wg.Done() + + // copy from stream into a buffer + received := new(bytes.Buffer) + io.Copy(received, stream) + + // store the received content + pf.lock.Lock() + pf.received[port] = received.String() + pf.lock.Unlock() + }() + + // server -> client + wg.Add(1) + go func() { + defer wg.Done() + + // send the hardcoded data to the stream + io.Copy(stream, strings.NewReader(pf.send[port])) + }() + + wg.Wait() + + return nil +} + +// fakePortForwardServer creates an HTTP server that can handle port forwarding +// requests. +func fakePortForwardServer(t *testing.T, testName string, serverSends, expectedFromClient map[uint16]string) http.HandlerFunc { + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + pf := &fakePortForwarder{ + received: make(map[uint16]string), + send: serverSends, + } + kubelet.ServePortForward(w, req, pf, "pod", "uid", 0, 10*time.Second) + + for port, expected := range expectedFromClient { + actual, ok := pf.received[port] + if !ok { + t.Errorf("%s: server didn't receive any data for port %d", testName, port) + continue + } + + if expected != actual { + t.Errorf("%s: server expected to receive %q, got %q for port %d", testName, expected, actual, port) + } + } + + for port, actual := range pf.received { + if _, ok := expectedFromClient[port]; !ok { + t.Errorf("%s: server unexpectedly received %q for port %d", testName, actual, port) + } + } + }) +} + func TestForwardPorts(t *testing.T) { - testCases := []struct { - Upgrader *fakeUpgrader - Ports []string - Send map[uint16]string - Receive map[uint16]string - Err bool + tests := map[string]struct { + ports []string + clientSends map[uint16]string + serverSends map[uint16]string }{ - { - Upgrader: &fakeUpgrader{err: errors.New("bail")}, - Err: true, + "forward 1 port with no data either direction": { + ports: []string{"5000"}, }, - { - Upgrader: &fakeUpgrader{conn: newFakeUpgradeConnection()}, - Ports: []string{"5000"}, - }, - { - Upgrader: &fakeUpgrader{conn: newFakeUpgradeConnection()}, - Ports: []string{"5001", "6000"}, - Send: map[uint16]string{ + "forward 2 ports with bidirectional data": { + ports: []string{"5001", "6000"}, + clientSends: map[uint16]string{ 5001: "abcd", 6000: "ghij", }, - Receive: map[uint16]string{ + serverSends: map[uint16]string{ 5001: "1234", 6000: "5678", }, }, } - for i, testCase := range testCases { + for testName, test := range tests { + server := httptest.NewServer(fakePortForwardServer(t, testName, test.serverSends, test.clientSends)) + url, _ := url.ParseRequestURI(server.URL) + c := client.NewRESTClient(url, "x", nil, -1, -1) + req := c.Post().Resource("testing") + + conf := &client.Config{ + Host: server.URL, + } + stopChan := make(chan struct{}, 1) - pf, err := New(&client.Request{}, &client.Config{}, testCase.Ports, stopChan) - hasErr := err != nil - if hasErr != testCase.Err { - t.Fatalf("%d: New: expected %t, got %t: %v", i, testCase.Err, hasErr, err) - } - if pf == nil { - continue - } - pf.upgrader = testCase.Upgrader - if testCase.Upgrader.err != nil { - err := pf.ForwardPorts() - hasErr := err != nil - if hasErr != testCase.Err { - t.Fatalf("%d: ForwardPorts: expected %t, got %t: %v", i, testCase.Err, hasErr, err) - } - continue + pf, err := New(req, conf, test.ports, stopChan) + if err != nil { + t.Fatalf("%s: unexpected error calling New: %v", testName, err) } doneChan := make(chan error) @@ -352,65 +314,70 @@ func TestForwardPorts(t *testing.T) { }() <-pf.Ready - conn := testCase.Upgrader.conn - - for port, data := range testCase.Send { - conn.lock.Lock() - conn.portData[fmt.Sprintf("%d", port)] = testCase.Receive[port] - conn.lock.Unlock() - + for port, data := range test.clientSends { clientConn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) if err != nil { - t.Fatalf("%d: error dialing %d: %s", i, port, err) + t.Errorf("%s: error dialing %d: %s", testName, port, err) + server.Close() + continue } defer clientConn.Close() n, err := clientConn.Write([]byte(data)) if err != nil && err != io.EOF { - t.Fatalf("%d: Error sending data '%s': %s", i, data, err) + t.Errorf("%s: Error sending data '%s': %s", testName, data, err) + server.Close() + continue } if n == 0 { - t.Fatalf("%d: unexpected write of 0 bytes", i) + t.Errorf("%s: unexpected write of 0 bytes", testName) + server.Close() + continue } b := make([]byte, 4) n, err = clientConn.Read(b) if err != nil && err != io.EOF { - t.Fatalf("%d: Error reading data: %s", i, err) + t.Errorf("%s: Error reading data: %s", testName, err) + server.Close() + continue } - if !bytes.Equal([]byte(testCase.Receive[port]), b) { - t.Fatalf("%d: expected to read '%s', got '%s'", i, testCase.Receive[port], b) + if !bytes.Equal([]byte(test.serverSends[port]), b) { + t.Errorf("%s: expected to read '%s', got '%s'", testName, test.serverSends[port], b) + server.Close() + continue } } - // tell r.ForwardPorts to stop close(stopChan) // wait for r.ForwardPorts to actually return err = <-doneChan if err != nil { - t.Fatalf("%d: unexpected error: %s", i, err) - } - - if e, a := len(testCase.Send), len(conn.streams); e != a { - t.Fatalf("%d: expected %d streams to be created, got %d", i, e, a) - } - - if !conn.closeCalled { - t.Fatalf("%d: expected conn closure", i) + t.Errorf("%s: unexpected error: %s", testName, err) } + server.Close() } } func TestForwardPortsReturnsErrorWhenAllBindsFailed(t *testing.T) { + server := httptest.NewServer(fakePortForwardServer(t, "allBindsFailed", nil, nil)) + defer server.Close() + url, _ := url.ParseRequestURI(server.URL) + c := client.NewRESTClient(url, "x", nil, -1, -1) + req := c.Post().Resource("testing") + + conf := &client.Config{ + Host: server.URL, + } + stopChan1 := make(chan struct{}, 1) defer close(stopChan1) - pf1, err := New(&client.Request{}, &client.Config{}, []string{"5555"}, stopChan1) + pf1, err := New(req, conf, []string{"5555"}, stopChan1) if err != nil { t.Fatalf("error creating pf1: %v", err) } - pf1.upgrader = &fakeUpgrader{conn: newFakeUpgradeConnection()} go pf1.ForwardPorts() <-pf1.Ready @@ -419,7 +386,6 @@ func TestForwardPortsReturnsErrorWhenAllBindsFailed(t *testing.T) { if err != nil { t.Fatalf("error creating pf2: %v", err) } - pf2.upgrader = &fakeUpgrader{conn: newFakeUpgradeConnection()} if err := pf2.ForwardPorts(); err == nil { t.Fatal("expected non-nil error for pf2.ForwardPorts") } diff --git a/pkg/kubelet/dockertools/manager.go b/pkg/kubelet/dockertools/manager.go index 74631ec452..7edf98c8b2 100644 --- a/pkg/kubelet/dockertools/manager.go +++ b/pkg/kubelet/dockertools/manager.go @@ -1179,16 +1179,39 @@ func (dm *DockerManager) PortForward(pod *kubecontainer.Pod, port uint16, stream } containerPid := container.State.Pid - // TODO what if the host doesn't have it??? - _, lookupErr := exec.LookPath("socat") + socatPath, lookupErr := exec.LookPath("socat") if lookupErr != nil { - return fmt.Errorf("Unable to do port forwarding: socat not found.") + return fmt.Errorf("unable to do port forwarding: socat not found.") } - args := []string{"-t", fmt.Sprintf("%d", containerPid), "-n", "socat", "-", fmt.Sprintf("TCP4:localhost:%d", port)} - // TODO use exec.LookPath - command := exec.Command("nsenter", args...) - command.Stdin = stream + + args := []string{"-t", fmt.Sprintf("%d", containerPid), "-n", socatPath, "-", fmt.Sprintf("TCP4:localhost:%d", port)} + + nsenterPath, lookupErr := exec.LookPath("nsenter") + if lookupErr != nil { + return fmt.Errorf("unable to do port forwarding: nsenter not found.") + } + + command := exec.Command(nsenterPath, args...) command.Stdout = stream + + // If we use Stdin, command.Run() won't return until the goroutine that's copying + // from stream finishes. Unfortunately, if you have a client like telnet connected + // via port forwarding, as long as the user's telnet client is connected to the user's + // local listener that port forwarding sets up, the telnet session never exits. This + // means that even if socat has finished running, command.Run() won't ever return + // (because the client still has the connection and stream open). + // + // The work around is to use StdinPipe(), as Wait() (called by Run()) closes the pipe + // when the command (socat) exits. + inPipe, err := command.StdinPipe() + if err != nil { + return fmt.Errorf("unable to do port forwarding: error creating stdin pipe: %v", err) + } + go func() { + io.Copy(inPipe, stream) + inPipe.Close() + }() + return command.Run() } diff --git a/pkg/kubelet/dockertools/manager_test.go b/pkg/kubelet/dockertools/manager_test.go index 93c5e7aaa0..ed92253ff4 100644 --- a/pkg/kubelet/dockertools/manager_test.go +++ b/pkg/kubelet/dockertools/manager_test.go @@ -1762,6 +1762,12 @@ func TestSyncPodEventHandlerFails(t *testing.T) { } } +type fakeReadWriteCloser struct{} + +func (*fakeReadWriteCloser) Read([]byte) (int, error) { return 0, nil } +func (*fakeReadWriteCloser) Write([]byte) (int, error) { return 0, nil } +func (*fakeReadWriteCloser) Close() error { return nil } + func TestPortForwardNoSuchContainer(t *testing.T) { dm, _ := newTestDockerManager() @@ -1774,7 +1780,8 @@ func TestPortForwardNoSuchContainer(t *testing.T) { Containers: nil, }, 5000, - nil, + // need a valid io.ReadWriteCloser here + &fakeReadWriteCloser{}, ) if err == nil { t.Fatal("unexpected non-error") diff --git a/pkg/kubelet/rkt/rkt.go b/pkg/kubelet/rkt/rkt.go index df2755222a..00475d8310 100644 --- a/pkg/kubelet/rkt/rkt.go +++ b/pkg/kubelet/rkt/rkt.go @@ -1211,19 +1211,39 @@ func (r *runtime) PortForward(pod *kubecontainer.Pod, port uint16, stream io.Rea return err } - _, lookupErr := exec.LookPath("socat") + socatPath, lookupErr := exec.LookPath("socat") if lookupErr != nil { return fmt.Errorf("unable to do port forwarding: socat not found.") } - args := []string{"-t", fmt.Sprintf("%d", info.pid), "-n", "socat", "-", fmt.Sprintf("TCP4:localhost:%d", port)} - _, lookupErr = exec.LookPath("nsenter") + args := []string{"-t", fmt.Sprintf("%d", info.pid), "-n", socatPath, "-", fmt.Sprintf("TCP4:localhost:%d", port)} + + nsenterPath, lookupErr := exec.LookPath("nsenter") if lookupErr != nil { return fmt.Errorf("unable to do port forwarding: nsenter not found.") } - command := exec.Command("nsenter", args...) - command.Stdin = stream + + command := exec.Command(nsenterPath, args...) command.Stdout = stream + + // If we use Stdin, command.Run() won't return until the goroutine that's copying + // from stream finishes. Unfortunately, if you have a client like telnet connected + // via port forwarding, as long as the user's telnet client is connected to the user's + // local listener that port forwarding sets up, the telnet session never exits. This + // means that even if socat has finished running, command.Run() won't ever return + // (because the client still has the connection and stream open). + // + // The work around is to use StdinPipe(), as Wait() (called by Run()) closes the pipe + // when the command (socat) exits. + inPipe, err := command.StdinPipe() + if err != nil { + return fmt.Errorf("unable to do port forwarding: error creating stdin pipe: %v", err) + } + go func() { + io.Copy(inPipe, stream) + inPipe.Close() + }() + return command.Run() } diff --git a/pkg/kubelet/server.go b/pkg/kubelet/server.go index b91fc97238..32c6cb3dbb 100644 --- a/pkg/kubelet/server.go +++ b/pkg/kubelet/server.go @@ -45,6 +45,7 @@ import ( "k8s.io/kubernetes/pkg/httplog" kubecontainer "k8s.io/kubernetes/pkg/kubelet/container" "k8s.io/kubernetes/pkg/types" + "k8s.io/kubernetes/pkg/util" "k8s.io/kubernetes/pkg/util/flushwriter" "k8s.io/kubernetes/pkg/util/httpstream" "k8s.io/kubernetes/pkg/util/httpstream/spdy" @@ -458,7 +459,7 @@ func getContainerCoordinates(request *restful.Request) (namespace, pod string, u return } -const streamCreationTimeout = 30 * time.Second +const defaultStreamCreationTimeout = 30 * time.Second func (s *Server) getAttach(request *restful.Request, response *restful.Response) { podNamespace, podID, uid, container := getContainerCoordinates(request) @@ -564,7 +565,7 @@ func (s *Server) createStreams(request *restful.Request, response *restful.Respo conn.SetIdleTimeout(s.host.StreamingConnectionIdleTimeout()) // TODO make it configurable? - expired := time.NewTimer(streamCreationTimeout) + expired := time.NewTimer(defaultStreamCreationTimeout) var errorStream, stdinStream, stdoutStream, stderrStream httpstream.Stream receivedStreams := 0 @@ -612,6 +613,15 @@ func getPodCoordinates(request *restful.Request) (namespace, pod string, uid typ return } +// PortForwarder knows how to forward content from a data stream to/from a port +// in a pod. +type PortForwarder interface { + // PortForwarder copies data between a data stream and a port in a pod. + PortForward(name string, uid types.UID, port uint16, stream io.ReadWriteCloser) error +} + +// getPortForward handles a new restful port forward request. It determines the +// pod name and uid and then calls ServePortForward. func (s *Server) getPortForward(request *restful.Request, response *restful.Response) { podNamespace, podID, uid := getPodCoordinates(request) pod, ok := s.host.GetPodByName(podNamespace, podID) @@ -620,80 +630,280 @@ func (s *Server) getPortForward(request *restful.Request, response *restful.Resp return } + podName := kubecontainer.GetPodFullName(pod) + + ServePortForward(response.ResponseWriter, request.Request, s.host, podName, uid, s.host.StreamingConnectionIdleTimeout(), defaultStreamCreationTimeout) +} + +// ServePortForward handles a port forwarding request. A single request is +// kept alive as long as the client is still alive and the connection has not +// been timed out due to idleness. This function handles multiple forwarded +// connections; i.e., multiple `curl http://localhost:8888/` requests will be +// handled by a single invocation of ServePortForward. +func ServePortForward(w http.ResponseWriter, req *http.Request, portForwarder PortForwarder, podName string, uid types.UID, idleTimeout time.Duration, streamCreationTimeout time.Duration) { streamChan := make(chan httpstream.Stream, 1) + + glog.V(5).Infof("Upgrading port forward response") upgrader := spdy.NewResponseUpgrader() - conn := upgrader.UpgradeResponse(response.ResponseWriter, request.Request, func(stream httpstream.Stream) error { - portString := stream.Headers().Get(api.PortHeader) - port, err := strconv.ParseUint(portString, 10, 16) - if err != nil { - return fmt.Errorf("Unable to parse '%s' as a port: %v", portString, err) - } - if port < 1 { - return fmt.Errorf("Port '%d' must be greater than 0", port) - } - streamChan <- stream - return nil - }) + conn := upgrader.UpgradeResponse(w, req, portForwardStreamReceived(streamChan)) if conn == nil { return } defer conn.Close() - conn.SetIdleTimeout(s.host.StreamingConnectionIdleTimeout()) - var dataStreamLock sync.Mutex - dataStreamChans := make(map[string]chan httpstream.Stream) + glog.V(5).Infof("(conn=%p) setting port forwarding streaming connection idle timeout to %v", conn, idleTimeout) + conn.SetIdleTimeout(idleTimeout) + h := &portForwardStreamHandler{ + conn: conn, + streamChan: streamChan, + streamPairs: make(map[string]*portForwardStreamPair), + streamCreationTimeout: streamCreationTimeout, + pod: podName, + uid: uid, + forwarder: portForwarder, + } + h.run() +} + +// portForwardStreamReceived is the httpstream.NewStreamHandler for port +// forward streams. It checks each stream's port and stream type headers, +// rejecting any streams that with missing or invalid values. Each valid +// stream is sent to the streams channel. +func portForwardStreamReceived(streams chan httpstream.Stream) func(httpstream.Stream) error { + return func(stream httpstream.Stream) error { + // make sure it has a valid port header + portString := stream.Headers().Get(api.PortHeader) + if len(portString) == 0 { + return fmt.Errorf("%q header is required", api.PortHeader) + } + port, err := strconv.ParseUint(portString, 10, 16) + if err != nil { + return fmt.Errorf("unable to parse %q as a port: %v", portString, err) + } + if port < 1 { + return fmt.Errorf("port %q must be > 0", portString) + } + + // make sure it has a valid stream type header + streamType := stream.Headers().Get(api.StreamType) + if len(streamType) == 0 { + return fmt.Errorf("%q header is required", api.StreamType) + } + if streamType != api.StreamTypeError && streamType != api.StreamTypeData { + return fmt.Errorf("invalid stream type %q", streamType) + } + + streams <- stream + return nil + } +} + +// portForwardStreamHandler is capable of processing multiple port forward +// requests over a single httpstream.Connection. +type portForwardStreamHandler struct { + conn httpstream.Connection + streamChan chan httpstream.Stream + streamPairsLock sync.RWMutex + streamPairs map[string]*portForwardStreamPair + streamCreationTimeout time.Duration + pod string + uid types.UID + forwarder PortForwarder +} + +// getStreamPair returns a portForwardStreamPair for requestID. This creates a +// new pair if one does not yet exist for the requestID. The returned bool is +// true if the pair was created. +func (h *portForwardStreamHandler) getStreamPair(requestID string) (*portForwardStreamPair, bool) { + h.streamPairsLock.Lock() + defer h.streamPairsLock.Unlock() + + if p, ok := h.streamPairs[requestID]; ok { + glog.V(5).Infof("(conn=%p, request=%s) found existing stream pair", h.conn, requestID) + return p, false + } + + glog.V(5).Infof("(conn=%p, request=%s) creating new stream pair", h.conn, requestID) + + p := newPortForwardPair(requestID) + h.streamPairs[requestID] = p + + return p, true +} + +// monitorStreamPair waits for the pair to receive both its error and data +// streams, or for the timeout to expire (whichever happens first), and then +// removes the pair. +func (h *portForwardStreamHandler) monitorStreamPair(p *portForwardStreamPair, timeout <-chan time.Time) { + select { + case <-timeout: + err := fmt.Errorf("(conn=%p, request=%s) timed out waiting for streams", h.conn, p.requestID) + util.HandleError(err) + p.printError(err.Error()) + case <-p.complete: + glog.V(5).Infof("(conn=%p, request=%s) successfully received error and data streams", h.conn, p.requestID) + } + h.removeStreamPair(p.requestID) +} + +// hasStreamPair returns a bool indicating if a stream pair for requestID +// exists. +func (h *portForwardStreamHandler) hasStreamPair(requestID string) bool { + h.streamPairsLock.RLock() + defer h.streamPairsLock.RUnlock() + + _, ok := h.streamPairs[requestID] + return ok +} + +// removeStreamPair removes the stream pair identified by requestID from streamPairs. +func (h *portForwardStreamHandler) removeStreamPair(requestID string) { + h.streamPairsLock.Lock() + defer h.streamPairsLock.Unlock() + + delete(h.streamPairs, requestID) +} + +// requestID returns the request id for stream. +func (h *portForwardStreamHandler) requestID(stream httpstream.Stream) string { + requestID := stream.Headers().Get(api.PortForwardRequestIDHeader) + if len(requestID) == 0 { + glog.V(5).Infof("(conn=%p) stream received without %s header", h.conn, api.PortForwardRequestIDHeader) + // If we get here, it's because the connection came from an older client + // that isn't generating the request id header + // (https://github.com/kubernetes/kubernetes/blob/843134885e7e0b360eb5441e85b1410a8b1a7a0c/pkg/client/unversioned/portforward/portforward.go#L258-L287) + // + // This is a best-effort attempt at supporting older clients. + // + // When there aren't concurrent new forwarded connections, each connection + // will have a pair of streams (data, error), and the stream IDs will be + // consecutive odd numbers, e.g. 1 and 3 for the first connection. Convert + // the stream ID into a pseudo-request id by taking the stream type and + // using id = stream.Identifier() when the stream type is error, + // and id = stream.Identifier() - 2 when it's data. + // + // NOTE: this only works when there are not concurrent new streams from + // multiple forwarded connections; it's a best-effort attempt at supporting + // old clients that don't generate request ids. If there are concurrent + // new connections, it's possible that 1 connection gets streams whose IDs + // are not consecutive (e.g. 5 and 9 instead of 5 and 7). + streamType := stream.Headers().Get(api.StreamType) + switch streamType { + case api.StreamTypeError: + requestID = strconv.Itoa(int(stream.Identifier())) + case api.StreamTypeData: + requestID = strconv.Itoa(int(stream.Identifier()) - 2) + } + + glog.V(5).Infof("(conn=%p) automatically assigning request ID=%q from stream type=%s, stream ID=%d", h.conn, requestID, streamType, stream.Identifier()) + } + return requestID +} + +// run is the main loop for the portForwardStreamHandler. It processes new +// streams, invoking portForward for each complete stream pair. The loop exits +// when the httpstream.Connection is closed. +func (h *portForwardStreamHandler) run() { + glog.V(5).Infof("(conn=%p) waiting for port forward streams", h.conn) Loop: for { select { - case <-conn.CloseChan(): + case <-h.conn.CloseChan(): + glog.V(5).Infof("(conn=%p) upgraded connection closed", h.conn) break Loop - case stream := <-streamChan: + case stream := <-h.streamChan: + requestID := h.requestID(stream) streamType := stream.Headers().Get(api.StreamType) - port := stream.Headers().Get(api.PortHeader) - dataStreamLock.Lock() - switch streamType { - case "error": - ch := make(chan httpstream.Stream) - dataStreamChans[port] = ch - go waitForPortForwardDataStreamAndRun(kubecontainer.GetPodFullName(pod), uid, stream, ch, s.host) - case "data": - ch, ok := dataStreamChans[port] - if ok { - ch <- stream - delete(dataStreamChans, port) - } else { - glog.Errorf("Unable to locate data stream channel for port %s", port) - } - default: - glog.Errorf("streamType header must be 'error' or 'data', got: '%s'", streamType) - stream.Reset() + glog.V(5).Infof("(conn=%p, request=%s) received new stream of type %s", h.conn, requestID, streamType) + + p, created := h.getStreamPair(requestID) + if created { + go h.monitorStreamPair(p, time.After(h.streamCreationTimeout)) + } + if complete, err := p.add(stream); err != nil { + msg := fmt.Sprintf("error processing stream for request %s: %v", requestID, err) + util.HandleError(errors.New(msg)) + p.printError(msg) + } else if complete { + go h.portForward(p) } - dataStreamLock.Unlock() } } } -func waitForPortForwardDataStreamAndRun(pod string, uid types.UID, errorStream httpstream.Stream, dataStreamChan chan httpstream.Stream, host HostInterface) { - defer errorStream.Reset() +// portForward invokes the portForwardStreamHandler's forwarder.PortForward +// function for the given stream pair. +func (h *portForwardStreamHandler) portForward(p *portForwardStreamPair) { + defer p.dataStream.Close() + defer p.errorStream.Close() - var dataStream httpstream.Stream + portString := p.dataStream.Headers().Get(api.PortHeader) + port, _ := strconv.ParseUint(portString, 10, 16) - select { - case dataStream = <-dataStreamChan: - case <-time.After(streamCreationTimeout): - errorStream.Write([]byte("Timed out waiting for data stream")) - //TODO delete from dataStreamChans[port] - return + glog.V(5).Infof("(conn=%p, request=%s) invoking forwarder.PortForward for port %s", h.conn, p.requestID, portString) + err := h.forwarder.PortForward(h.pod, h.uid, uint16(port), p.dataStream) + glog.V(5).Infof("(conn=%p, request=%s) done invoking forwarder.PortForward for port %s", h.conn, p.requestID, portString) + + if err != nil { + msg := fmt.Errorf("error forwarding port %d to pod %s, uid %v: %v", port, h.pod, h.uid, err) + util.HandleError(msg) + fmt.Fprint(p.errorStream, msg.Error()) + } +} + +// portForwardStreamPair represents the error and data streams for a port +// forwarding request. +type portForwardStreamPair struct { + lock sync.RWMutex + requestID string + dataStream httpstream.Stream + errorStream httpstream.Stream + complete chan struct{} +} + +// newPortForwardPair creates a new portForwardStreamPair. +func newPortForwardPair(requestID string) *portForwardStreamPair { + return &portForwardStreamPair{ + requestID: requestID, + complete: make(chan struct{}), + } +} + +// add adds the stream to the portForwardStreamPair. If the pair already +// contains a stream for the new stream's type, an error is returned. add +// returns true if both the data and error streams for this pair have been +// received. +func (p *portForwardStreamPair) add(stream httpstream.Stream) (bool, error) { + p.lock.Lock() + defer p.lock.Unlock() + + switch stream.Headers().Get(api.StreamType) { + case api.StreamTypeError: + if p.errorStream != nil { + return false, errors.New("error stream already assigned") + } + p.errorStream = stream + case api.StreamTypeData: + if p.dataStream != nil { + return false, errors.New("data stream already assigned") + } + p.dataStream = stream } - portString := dataStream.Headers().Get(api.PortHeader) - port, _ := strconv.ParseUint(portString, 10, 16) - err := host.PortForward(pod, uid, uint16(port), dataStream) - if err != nil { - msg := fmt.Errorf("Error forwarding port %d to pod %s, uid %v: %v", port, pod, uid, err) - glog.Error(msg) - errorStream.Write([]byte(msg.Error())) + complete := p.errorStream != nil && p.dataStream != nil + if complete { + close(p.complete) + } + return complete, nil +} + +// printError writes s to p.errorStream if p.errorStream has been set. +func (p *portForwardStreamPair) printError(s string) { + p.lock.RLock() + defer p.lock.RUnlock() + if p.errorStream != nil { + fmt.Fprint(p.errorStream, s) } } diff --git a/pkg/kubelet/server_test.go b/pkg/kubelet/server_test.go index b1301d87ad..664dd18e81 100644 --- a/pkg/kubelet/server_test.go +++ b/pkg/kubelet/server_test.go @@ -1426,3 +1426,221 @@ func TestServePortForward(t *testing.T) { <-portForwardFuncDone } } + +type fakeHttpStream struct { + headers http.Header + id uint32 +} + +func newFakeHttpStream() *fakeHttpStream { + return &fakeHttpStream{ + headers: make(http.Header), + } +} + +var _ httpstream.Stream = &fakeHttpStream{} + +func (s *fakeHttpStream) Read(data []byte) (int, error) { + return 0, nil +} + +func (s *fakeHttpStream) Write(data []byte) (int, error) { + return 0, nil +} + +func (s *fakeHttpStream) Close() error { + return nil +} + +func (s *fakeHttpStream) Reset() error { + return nil +} + +func (s *fakeHttpStream) Headers() http.Header { + return s.headers +} + +func (s *fakeHttpStream) Identifier() uint32 { + return s.id +} + +func TestPortForwardStreamReceived(t *testing.T) { + tests := map[string]struct { + port string + streamType string + expectedError string + }{ + "missing port": { + expectedError: `"port" header is required`, + }, + "unable to parse port": { + port: "abc", + expectedError: `unable to parse "abc" as a port: strconv.ParseUint: parsing "abc": invalid syntax`, + }, + "negative port": { + port: "-1", + expectedError: `unable to parse "-1" as a port: strconv.ParseUint: parsing "-1": invalid syntax`, + }, + "missing stream type": { + port: "80", + expectedError: `"streamType" header is required`, + }, + "valid port with error stream": { + port: "80", + streamType: "error", + }, + "valid port with data stream": { + port: "80", + streamType: "data", + }, + "invalid stream type": { + port: "80", + streamType: "foo", + expectedError: `invalid stream type "foo"`, + }, + } + for name, test := range tests { + streams := make(chan httpstream.Stream, 1) + f := portForwardStreamReceived(streams) + stream := newFakeHttpStream() + if len(test.port) > 0 { + stream.headers.Set("port", test.port) + } + if len(test.streamType) > 0 { + stream.headers.Set("streamType", test.streamType) + } + err := f(stream) + if len(test.expectedError) > 0 { + if err == nil { + t.Errorf("%s: expected err=%q, but it was nil", name, test.expectedError) + } + if e, a := test.expectedError, err.Error(); e != a { + t.Errorf("%s: expected err=%q, got %q", name, e, a) + } + continue + } + if err != nil { + t.Errorf("%s: unexpected error %v", name, err) + continue + } + if s := <-streams; s != stream { + t.Errorf("%s: expected stream %#v, got %#v", name, stream, s) + } + } +} + +func TestGetStreamPair(t *testing.T) { + timeout := make(chan time.Time) + + h := &portForwardStreamHandler{ + streamPairs: make(map[string]*portForwardStreamPair), + } + + // test adding a new entry + p, created := h.getStreamPair("1") + if p == nil { + t.Fatalf("unexpected nil pair") + } + if !created { + t.Fatal("expected created=true") + } + if p.dataStream != nil { + t.Errorf("unexpected non-nil data stream") + } + if p.errorStream != nil { + t.Errorf("unexpected non-nil error stream") + } + + // start the monitor for this pair + monitorDone := make(chan struct{}) + go func() { + h.monitorStreamPair(p, timeout) + close(monitorDone) + }() + + if !h.hasStreamPair("1") { + t.Fatal("This should still be true") + } + + // make sure we can retrieve an existing entry + p2, created := h.getStreamPair("1") + if created { + t.Fatal("expected created=false") + } + if p != p2 { + t.Fatalf("retrieving an existing pair: expected %#v, got %#v", p, p2) + } + + // removed via complete + dataStream := newFakeHttpStream() + dataStream.headers.Set(api.StreamType, api.StreamTypeData) + complete, err := p.add(dataStream) + if err != nil { + t.Fatalf("unexpected error adding data stream to pair: %v", err) + } + if complete { + t.Fatalf("unexpected complete") + } + + errorStream := newFakeHttpStream() + errorStream.headers.Set(api.StreamType, api.StreamTypeError) + complete, err = p.add(errorStream) + if err != nil { + t.Fatalf("unexpected error adding error stream to pair: %v", err) + } + if !complete { + t.Fatal("unexpected incomplete") + } + + // make sure monitorStreamPair completed + <-monitorDone + + // make sure the pair was removed + if h.hasStreamPair("1") { + t.Fatal("expected removal of pair after both data and error streams received") + } + + // removed via timeout + p, created = h.getStreamPair("2") + if !created { + t.Fatal("expected created=true") + } + if p == nil { + t.Fatal("expected p not to be nil") + } + monitorDone = make(chan struct{}) + go func() { + h.monitorStreamPair(p, timeout) + close(monitorDone) + }() + // cause the timeout + close(timeout) + // make sure monitorStreamPair completed + <-monitorDone + if h.hasStreamPair("2") { + t.Fatal("expected stream pair to be removed") + } +} + +func TestRequestID(t *testing.T) { + h := &portForwardStreamHandler{} + + s := newFakeHttpStream() + s.headers.Set(api.StreamType, api.StreamTypeError) + s.id = 1 + if e, a := "1", h.requestID(s); e != a { + t.Errorf("expected %q, got %q", e, a) + } + + s.headers.Set(api.StreamType, api.StreamTypeData) + s.id = 3 + if e, a := "1", h.requestID(s); e != a { + t.Errorf("expected %q, got %q", e, a) + } + + s.id = 7 + s.headers.Set(api.PortForwardRequestIDHeader, "2") + if e, a := "2", h.requestID(s); e != a { + t.Errorf("expected %q, got %q", e, a) + } +} diff --git a/pkg/util/httpstream/httpstream.go b/pkg/util/httpstream/httpstream.go index a1c92dde7f..b61af6224d 100644 --- a/pkg/util/httpstream/httpstream.go +++ b/pkg/util/httpstream/httpstream.go @@ -78,6 +78,8 @@ type Stream interface { Reset() error // Headers returns the headers used to create the stream. Headers() http.Header + // Identifier returns the stream's ID. + Identifier() uint32 } // IsUpgradeRequest returns true if the given request is a connection upgrade request diff --git a/test/e2e/kubectl.go b/test/e2e/kubectl.go index 447ee1984e..46b0cf2413 100644 --- a/test/e2e/kubectl.go +++ b/test/e2e/kubectl.go @@ -60,10 +60,7 @@ const ( simplePodPort = 80 ) -var ( - portForwardRegexp = regexp.MustCompile("Forwarding from 127.0.0.1:([0-9]+) -> 80") - proxyRegexp = regexp.MustCompile("Starting to serve on 127.0.0.1:([0-9]+)") -) +var proxyRegexp = regexp.MustCompile("Starting to serve on 127.0.0.1:([0-9]+)") var _ = Describe("Kubectl client", func() { defer GinkgoRecover() @@ -200,32 +197,11 @@ var _ = Describe("Kubectl client", func() { It("should support port-forward", func() { By("forwarding the container port to a local port") - cmd := kubectlCmd("port-forward", fmt.Sprintf("--namespace=%v", ns), simplePodName, fmt.Sprintf(":%d", simplePodPort)) + cmd, listenPort := runPortForward(ns, simplePodName, simplePodPort) defer tryKill(cmd) - // This is somewhat ugly but is the only way to retrieve the port that was picked - // by the port-forward command. We don't want to hard code the port as we have no - // way of guaranteeing we can pick one that isn't in use, particularly on Jenkins. - Logf("starting port-forward command and streaming output") - stdout, stderr, err := startCmdAndStreamOutput(cmd) - if err != nil { - Failf("Failed to start port-forward command: %v", err) - } - defer stdout.Close() - defer stderr.Close() - buf := make([]byte, 128) - var n int - Logf("reading from `kubectl port-forward` command's stderr") - if n, err = stderr.Read(buf); err != nil { - Failf("Failed to read from kubectl port-forward stderr: %v", err) - } - portForwardOutput := string(buf[:n]) - match := portForwardRegexp.FindStringSubmatch(portForwardOutput) - if len(match) != 2 { - Failf("Failed to parse kubectl port-forward output: %s", portForwardOutput) - } By("curling local port output") - localAddr := fmt.Sprintf("http://localhost:%s", match[1]) + localAddr := fmt.Sprintf("http://localhost:%d", listenPort) body, err := curl(localAddr) Logf("got: %s", body) if err != nil { diff --git a/test/e2e/portforward.go b/test/e2e/portforward.go new file mode 100644 index 0000000000..2b5a2ca5ba --- /dev/null +++ b/test/e2e/portforward.go @@ -0,0 +1,234 @@ +/* +Copyright 2015 The Kubernetes Authors All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package e2e + +import ( + "fmt" + "io/ioutil" + "net" + "os/exec" + "regexp" + "strconv" + "strings" + + "k8s.io/kubernetes/pkg/api" + + . "github.com/onsi/ginkgo" +) + +const ( + podName = "pfpod" +) + +// TODO support other ports besides 80 +var portForwardRegexp = regexp.MustCompile("Forwarding from 127.0.0.1:([0-9]+) -> 80") + +func pfPod(expectedClientData, chunks, chunkSize, chunkIntervalMillis string) *api.Pod { + return &api.Pod{ + ObjectMeta: api.ObjectMeta{ + Name: podName, + Labels: map[string]string{"name": podName}, + }, + Spec: api.PodSpec{ + Containers: []api.Container{ + { + Name: "portforwardtester", + Image: "gcr.io/google_containers/portforwardtester:1.0", + Env: []api.EnvVar{ + { + Name: "BIND_PORT", + Value: "80", + }, + { + Name: "EXPECTED_CLIENT_DATA", + Value: expectedClientData, + }, + { + Name: "CHUNKS", + Value: chunks, + }, + { + Name: "CHUNK_SIZE", + Value: chunkSize, + }, + { + Name: "CHUNK_INTERVAL", + Value: chunkIntervalMillis, + }, + }, + }, + }, + RestartPolicy: api.RestartPolicyNever, + }, + } +} + +func runPortForward(ns, podName string, port int) (*exec.Cmd, int) { + cmd := kubectlCmd("port-forward", fmt.Sprintf("--namespace=%v", ns), podName, fmt.Sprintf(":%d", port)) + // This is somewhat ugly but is the only way to retrieve the port that was picked + // by the port-forward command. We don't want to hard code the port as we have no + // way of guaranteeing we can pick one that isn't in use, particularly on Jenkins. + Logf("starting port-forward command and streaming output") + stdout, stderr, err := startCmdAndStreamOutput(cmd) + if err != nil { + Failf("Failed to start port-forward command: %v", err) + } + defer stdout.Close() + defer stderr.Close() + + buf := make([]byte, 128) + var n int + Logf("reading from `kubectl port-forward` command's stderr") + if n, err = stderr.Read(buf); err != nil { + Failf("Failed to read from kubectl port-forward stderr: %v", err) + } + portForwardOutput := string(buf[:n]) + match := portForwardRegexp.FindStringSubmatch(portForwardOutput) + if len(match) != 2 { + Failf("Failed to parse kubectl port-forward output: %s", portForwardOutput) + } + + listenPort, err := strconv.Atoi(match[1]) + if err != nil { + Failf("Error converting %s to an int: %v", match[1], err) + } + + return cmd, listenPort +} + +var _ = Describe("Port forwarding", func() { + framework := NewFramework("port-forwarding") + + Describe("With a server that expects a client request", func() { + It("should support a client that connects, sends no data, and disconnects", func() { + By("creating the target pod") + pod := pfPod("abc", "1", "1", "1") + framework.Client.Pods(framework.Namespace.Name).Create(pod) + framework.WaitForPodRunning(pod.Name) + + By("Running 'kubectl port-forward'") + cmd, listenPort := runPortForward(framework.Namespace.Name, pod.Name, 80) + defer tryKill(cmd) + + By("Dialing the local port") + conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", listenPort)) + if err != nil { + Failf("Couldn't connect to port %d: %v", listenPort, err) + } + + By("Closing the connection to the local port") + conn.Close() + + logOutput := runKubectl("logs", fmt.Sprintf("--namespace=%v", framework.Namespace.Name), "-f", podName) + verifyLogMessage(logOutput, "Accepted client connection") + verifyLogMessage(logOutput, "Expected to read 3 bytes from client, but got 0 instead") + }) + + It("should support a client that connects, sends data, and disconnects", func() { + By("creating the target pod") + pod := pfPod("abc", "10", "10", "100") + framework.Client.Pods(framework.Namespace.Name).Create(pod) + framework.WaitForPodRunning(pod.Name) + + By("Running 'kubectl port-forward'") + cmd, listenPort := runPortForward(framework.Namespace.Name, pod.Name, 80) + defer tryKill(cmd) + + By("Dialing the local port") + addr, err := net.ResolveTCPAddr("tcp", fmt.Sprintf("127.0.0.1:%d", listenPort)) + if err != nil { + Failf("Error resolving tcp addr: %v", err) + } + conn, err := net.DialTCP("tcp", nil, addr) + if err != nil { + Failf("Couldn't connect to port %d: %v", listenPort, err) + } + defer func() { + By("Closing the connection to the local port") + conn.Close() + }() + + By("Sending the expected data to the local port") + fmt.Fprint(conn, "abc") + + By("Closing the write half of the client's connection") + conn.CloseWrite() + + By("Reading data from the local port") + fromServer, err := ioutil.ReadAll(conn) + if err != nil { + Failf("Unexpected error reading data from the server: %v", err) + } + + if e, a := strings.Repeat("x", 100), string(fromServer); e != a { + Failf("Expected %q from server, got %q", e, a) + } + + logOutput := runKubectl("logs", fmt.Sprintf("--namespace=%v", framework.Namespace.Name), "-f", podName) + verifyLogMessage(logOutput, "^Accepted client connection$") + verifyLogMessage(logOutput, "^Received expected client data$") + verifyLogMessage(logOutput, "^Done$") + }) + }) + Describe("With a server that expects no client request", func() { + It("should support a client that connects, sends no data, and disconnects", func() { + By("creating the target pod") + pod := pfPod("", "10", "10", "100") + framework.Client.Pods(framework.Namespace.Name).Create(pod) + framework.WaitForPodRunning(pod.Name) + + By("Running 'kubectl port-forward'") + cmd, listenPort := runPortForward(framework.Namespace.Name, pod.Name, 80) + defer tryKill(cmd) + + By("Dialing the local port") + conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", listenPort)) + if err != nil { + Failf("Couldn't connect to port %d: %v", listenPort, err) + } + defer func() { + By("Closing the connection to the local port") + conn.Close() + }() + + By("Reading data from the local port") + fromServer, err := ioutil.ReadAll(conn) + if err != nil { + Failf("Unexpected error reading data from the server: %v", err) + } + + if e, a := strings.Repeat("x", 100), string(fromServer); e != a { + Failf("Expected %q from server, got %q", e, a) + } + + logOutput := runKubectl("logs", fmt.Sprintf("--namespace=%v", framework.Namespace.Name), "-f", podName) + verifyLogMessage(logOutput, "Accepted client connection") + verifyLogMessage(logOutput, "Done") + }) + }) +}) + +func verifyLogMessage(log, expected string) { + re := regexp.MustCompile(expected) + lines := strings.Split(log, "\n") + for i := range lines { + if re.MatchString(lines[i]) { + return + } + } + Failf("Missing %q from log: %s", expected, log) +}