From ed021fed4c38b5c82696fea84b3173e98bc16922 Mon Sep 17 00:00:00 2001 From: Andy Goldstein Date: Tue, 22 Sep 2015 16:29:51 -0400 Subject: [PATCH] Port forwarding fixes Correct port-forward data copying logic so that the server closes its half of the data stream when socat exits, and the client closes its half of the data stream when it finishes writing. Modify the client to wait for both copies (client->server, server->client) to finish before it unblocks. Fix race condition in the Kubelet's handling of incoming port forward streams. Have the client generate a connectionID header to be used to associate the error and data streams for a single connection, instead of assuming that streams n and n+1 go together. Attempt to generate a pseudo connectionID in the server in the event the connectionID header isn't present (older clients); this is a best-effort approach that only really works with 1 connection at a time, whereas multiple concurrent connections will only work reliably with a newer client that is generating connectionID. --- pkg/api/types.go | 18 +- .../unversioned/portforward/portforward.go | 99 +++--- .../portforward/portforward_test.go | 300 ++++++++--------- pkg/kubelet/dockertools/manager.go | 37 +- pkg/kubelet/dockertools/manager_test.go | 9 +- pkg/kubelet/rkt/rkt.go | 30 +- pkg/kubelet/server.go | 318 +++++++++++++++--- pkg/kubelet/server_test.go | 218 ++++++++++++ pkg/util/httpstream/httpstream.go | 2 + test/e2e/kubectl.go | 30 +- test/e2e/portforward.go | 234 +++++++++++++ 11 files changed, 992 insertions(+), 303 deletions(-) create mode 100644 test/e2e/portforward.go diff --git a/pkg/api/types.go b/pkg/api/types.go index 163e511522..858feb2d50 100644 --- a/pkg/api/types.go +++ b/pkg/api/types.go @@ -1951,14 +1951,24 @@ const ( // Command to run for remote command execution ExecCommandParamm = "command" - StreamType = "streamType" - StreamTypeStdin = "stdin" + // Name of header that specifies stream type + StreamType = "streamType" + // Value for streamType header for stdin stream + StreamTypeStdin = "stdin" + // Value for streamType header for stdout stream StreamTypeStdout = "stdout" + // Value for streamType header for stderr stream StreamTypeStderr = "stderr" - StreamTypeData = "data" - StreamTypeError = "error" + // Value for streamType header for data stream + StreamTypeData = "data" + // Value for streamType header for error stream + StreamTypeError = "error" + // Name of header that specifies the port being forwarded PortHeader = "port" + // Name of header that specifies a request ID used to associate the error + // and data streams for a single forwarded connection + PortForwardRequestIDHeader = "requestID" ) // Similarly to above, these are constants to support HTTP PATCH utilized by diff --git a/pkg/client/unversioned/portforward/portforward.go b/pkg/client/unversioned/portforward/portforward.go index 5efc21565d..925c5de402 100644 --- a/pkg/client/unversioned/portforward/portforward.go +++ b/pkg/client/unversioned/portforward/portforward.go @@ -25,10 +25,12 @@ import ( "net/http" "strconv" "strings" + "sync" "github.com/golang/glog" "k8s.io/kubernetes/pkg/api" client "k8s.io/kubernetes/pkg/client/unversioned" + "k8s.io/kubernetes/pkg/util" "k8s.io/kubernetes/pkg/util/httpstream" "k8s.io/kubernetes/pkg/util/httpstream/spdy" ) @@ -51,10 +53,12 @@ type PortForwarder struct { ports []ForwardedPort stopChan <-chan struct{} - streamConn httpstream.Connection - listeners []io.Closer - upgrader upgrader - Ready chan struct{} + streamConn httpstream.Connection + listeners []io.Closer + upgrader upgrader + Ready chan struct{} + requestIDLock sync.Mutex + requestID int } // ForwardedPort contains a Local:Remote port pairing. @@ -145,7 +149,7 @@ func (pf *PortForwarder) ForwardPorts() error { var err error pf.streamConn, err = pf.upgrader.upgrade(pf.req, pf.config) if err != nil { - return fmt.Errorf("Error upgrading connection: %s", err) + return fmt.Errorf("error upgrading connection: %s", err) } defer pf.streamConn.Close() @@ -179,7 +183,7 @@ func (pf *PortForwarder) forward() error { select { case <-pf.stopChan: case <-pf.streamConn.CloseChan(): - glog.Errorf("Lost connection to pod") + util.HandleError(errors.New("lost connection to pod")) } return nil @@ -213,7 +217,7 @@ func (pf *PortForwarder) listenOnPortAndAddress(port *ForwardedPort, protocol st func (pf *PortForwarder) getListener(protocol string, hostname string, port *ForwardedPort) (net.Listener, error) { listener, err := net.Listen(protocol, fmt.Sprintf("%s:%d", hostname, port.Local)) if err != nil { - glog.Errorf("Unable to create listener: Error %s", err) + util.HandleError(fmt.Errorf("Unable to create listener: Error %s", err)) return nil, err } listenerAddress := listener.Addr().String() @@ -237,7 +241,7 @@ func (pf *PortForwarder) waitForConnection(listener net.Listener, port Forwarded if err != nil { // TODO consider using something like https://github.com/hydrogen18/stoppableListener? if !strings.Contains(strings.ToLower(err.Error()), "use of closed network connection") { - glog.Errorf("Error accepting connection on port %d: %v", port.Local, err) + util.HandleError(fmt.Errorf("Error accepting connection on port %d: %v", port.Local, err)) } return } @@ -245,6 +249,14 @@ func (pf *PortForwarder) waitForConnection(listener net.Listener, port Forwarded } } +func (pf *PortForwarder) nextRequestID() int { + pf.requestIDLock.Lock() + defer pf.requestIDLock.Unlock() + id := pf.requestID + pf.requestID++ + return id +} + // handleConnection copies data between the local connection and the stream to // the remote server. func (pf *PortForwarder) handleConnection(conn net.Conn, port ForwardedPort) { @@ -252,65 +264,76 @@ func (pf *PortForwarder) handleConnection(conn net.Conn, port ForwardedPort) { glog.Infof("Handling connection for %d", port.Local) - errorChan := make(chan error) - doneChan := make(chan struct{}, 2) + requestID := pf.nextRequestID() // create error stream headers := http.Header{} headers.Set(api.StreamType, api.StreamTypeError) headers.Set(api.PortHeader, fmt.Sprintf("%d", port.Remote)) + headers.Set(api.PortForwardRequestIDHeader, strconv.Itoa(requestID)) errorStream, err := pf.streamConn.CreateStream(headers) if err != nil { - glog.Errorf("Error creating error stream for port %d -> %d: %v", port.Local, port.Remote, err) + util.HandleError(fmt.Errorf("error creating error stream for port %d -> %d: %v", port.Local, port.Remote, err)) return } - defer errorStream.Reset() + // we're not writing to this stream + errorStream.Close() + + errorChan := make(chan error) go func() { message, err := ioutil.ReadAll(errorStream) - if err != nil && err != io.EOF { - errorChan <- fmt.Errorf("Error reading from error stream for port %d -> %d: %v", port.Local, port.Remote, err) - } - if len(message) > 0 { - errorChan <- fmt.Errorf("An error occurred forwarding %d -> %d: %v", port.Local, port.Remote, string(message)) + switch { + case err != nil: + errorChan <- fmt.Errorf("error reading from error stream for port %d -> %d: %v", port.Local, port.Remote, err) + case len(message) > 0: + errorChan <- fmt.Errorf("an error occurred forwarding %d -> %d: %v", port.Local, port.Remote, string(message)) } + close(errorChan) }() // create data stream headers.Set(api.StreamType, api.StreamTypeData) dataStream, err := pf.streamConn.CreateStream(headers) if err != nil { - glog.Errorf("Error creating forwarding stream for port %d -> %d: %v", port.Local, port.Remote, err) + util.HandleError(fmt.Errorf("error creating forwarding stream for port %d -> %d: %v", port.Local, port.Remote, err)) return } - // Send a Reset when this function exits to completely tear down the stream here - // and in the remote server. - defer dataStream.Reset() + + localError := make(chan struct{}) + remoteDone := make(chan struct{}) go func() { - // Copy from the remote side to the local port. We won't get an EOF from - // the server as it has no way of knowing when to close the stream. We'll - // take care of closing both ends of the stream with the call to - // stream.Reset() when this function exits. - if _, err := io.Copy(conn, dataStream); err != nil && err != io.EOF && !strings.Contains(err.Error(), "use of closed network connection") { - glog.Errorf("Error copying from remote stream to local connection: %v", err) + // Copy from the remote side to the local port. + if _, err := io.Copy(conn, dataStream); err != nil && !strings.Contains(err.Error(), "use of closed network connection") { + util.HandleError(fmt.Errorf("error copying from remote stream to local connection: %v", err)) } - doneChan <- struct{}{} + + // inform the select below that the remote copy is done + close(remoteDone) }() go func() { - // Copy from the local port to the remote side. Here we will be able to know - // when the Copy gets an EOF from conn, as that will happen as soon as conn is - // closed (i.e. client disconnected). - if _, err := io.Copy(dataStream, conn); err != nil && err != io.EOF && !strings.Contains(err.Error(), "use of closed network connection") { - glog.Errorf("Error copying from local connection to remote stream: %v", err) + // inform server we're not sending any more data after copy unblocks + defer dataStream.Close() + + // Copy from the local port to the remote side. + if _, err := io.Copy(dataStream, conn); err != nil && !strings.Contains(err.Error(), "use of closed network connection") { + util.HandleError(fmt.Errorf("error copying from local connection to remote stream: %v", err)) + // break out of the select below without waiting for the other copy to finish + close(localError) } - doneChan <- struct{}{} }() + // wait for either a local->remote error or for copying from remote->local to finish select { - case err := <-errorChan: - glog.Error(err) - case <-doneChan: + case <-remoteDone: + case <-localError: + } + + // always expect something on errorChan (it may be nil) + err = <-errorChan + if err != nil { + util.HandleError(err) } } @@ -318,7 +341,7 @@ func (pf *PortForwarder) Close() { // stop all listeners for _, l := range pf.listeners { if err := l.Close(); err != nil { - glog.Errorf("Error closing listener: %v", err) + util.HandleError(fmt.Errorf("error closing listener: %v", err)) } } } diff --git a/pkg/client/unversioned/portforward/portforward_test.go b/pkg/client/unversioned/portforward/portforward_test.go index f5e6c639fa..0bd2ed88ad 100644 --- a/pkg/client/unversioned/portforward/portforward_test.go +++ b/pkg/client/unversioned/portforward/portforward_test.go @@ -18,20 +18,21 @@ package portforward import ( "bytes" - "errors" "fmt" "io" "net" "net/http" + "net/http/httptest" + "net/url" "reflect" "strings" "sync" "testing" "time" - "k8s.io/kubernetes/pkg/api" client "k8s.io/kubernetes/pkg/client/unversioned" - "k8s.io/kubernetes/pkg/util/httpstream" + "k8s.io/kubernetes/pkg/kubelet" + "k8s.io/kubernetes/pkg/types" ) func TestParsePortsAndNew(t *testing.T) { @@ -110,109 +111,6 @@ func TestParsePortsAndNew(t *testing.T) { } } -type fakeUpgrader struct { - conn *fakeUpgradeConnection - err error -} - -func (u *fakeUpgrader) upgrade(req *client.Request, config *client.Config) (httpstream.Connection, error) { - return u.conn, u.err -} - -type fakeUpgradeConnection struct { - closeCalled bool - lock sync.Mutex - streams map[string]*fakeUpgradeStream - portData map[string]string -} - -func newFakeUpgradeConnection() *fakeUpgradeConnection { - return &fakeUpgradeConnection{ - streams: make(map[string]*fakeUpgradeStream), - portData: make(map[string]string), - } -} - -func (c *fakeUpgradeConnection) CreateStream(headers http.Header) (httpstream.Stream, error) { - c.lock.Lock() - defer c.lock.Unlock() - - stream := &fakeUpgradeStream{} - c.streams[headers.Get(api.PortHeader)] = stream - // only simulate data on the data stream for now, not the error stream - if headers.Get(api.StreamType) == api.StreamTypeData { - stream.data = c.portData[headers.Get(api.PortHeader)] - } - - return stream, nil -} - -func (c *fakeUpgradeConnection) Close() error { - c.lock.Lock() - defer c.lock.Unlock() - - c.closeCalled = true - return nil -} - -func (c *fakeUpgradeConnection) CloseChan() <-chan bool { - return make(chan bool) -} - -func (c *fakeUpgradeConnection) SetIdleTimeout(timeout time.Duration) { -} - -type fakeUpgradeStream struct { - readCalled bool - writeCalled bool - dataWritten []byte - closeCalled bool - resetCalled bool - data string - lock sync.Mutex -} - -func (s *fakeUpgradeStream) Read(p []byte) (int, error) { - s.lock.Lock() - defer s.lock.Unlock() - s.readCalled = true - b := []byte(s.data) - n := copy(p, b) - // Indicate we returned all the data, and have no more data (EOF) - // Returning an EOF here will cause the port forwarder to immediately terminate, which is correct when we have no more data to send - return n, io.EOF -} - -func (s *fakeUpgradeStream) Write(p []byte) (int, error) { - s.lock.Lock() - defer s.lock.Unlock() - s.writeCalled = true - s.dataWritten = append(s.dataWritten, p...) - // Indicate the stream accepted all the data, and can accept more (no err) - // Returning an EOF here will cause the port forwarder to immediately terminate, which is incorrect, in case someone writes more data - return len(p), nil -} - -func (s *fakeUpgradeStream) Close() error { - s.lock.Lock() - defer s.lock.Unlock() - s.closeCalled = true - return nil -} - -func (s *fakeUpgradeStream) Reset() error { - s.lock.Lock() - defer s.lock.Unlock() - s.resetCalled = true - return nil -} - -func (s *fakeUpgradeStream) Headers() http.Header { - s.lock.Lock() - defer s.lock.Unlock() - return http.Header{} -} - type GetListenerTestCase struct { Hostname string Protocol string @@ -295,55 +193,119 @@ func TestGetListener(t *testing.T) { } } +// fakePortForwarder simulates port forwarding for testing. It implements +// kubelet.PortForwarder. +type fakePortForwarder struct { + lock sync.Mutex + // stores data received from the stream per port + received map[uint16]string + // data to be sent to the stream per port + send map[uint16]string +} + +var _ kubelet.PortForwarder = &fakePortForwarder{} + +func (pf *fakePortForwarder) PortForward(name string, uid types.UID, port uint16, stream io.ReadWriteCloser) error { + defer stream.Close() + + var wg sync.WaitGroup + + // client -> server + wg.Add(1) + go func() { + defer wg.Done() + + // copy from stream into a buffer + received := new(bytes.Buffer) + io.Copy(received, stream) + + // store the received content + pf.lock.Lock() + pf.received[port] = received.String() + pf.lock.Unlock() + }() + + // server -> client + wg.Add(1) + go func() { + defer wg.Done() + + // send the hardcoded data to the stream + io.Copy(stream, strings.NewReader(pf.send[port])) + }() + + wg.Wait() + + return nil +} + +// fakePortForwardServer creates an HTTP server that can handle port forwarding +// requests. +func fakePortForwardServer(t *testing.T, testName string, serverSends, expectedFromClient map[uint16]string) http.HandlerFunc { + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + pf := &fakePortForwarder{ + received: make(map[uint16]string), + send: serverSends, + } + kubelet.ServePortForward(w, req, pf, "pod", "uid", 0, 10*time.Second) + + for port, expected := range expectedFromClient { + actual, ok := pf.received[port] + if !ok { + t.Errorf("%s: server didn't receive any data for port %d", testName, port) + continue + } + + if expected != actual { + t.Errorf("%s: server expected to receive %q, got %q for port %d", testName, expected, actual, port) + } + } + + for port, actual := range pf.received { + if _, ok := expectedFromClient[port]; !ok { + t.Errorf("%s: server unexpectedly received %q for port %d", testName, actual, port) + } + } + }) +} + func TestForwardPorts(t *testing.T) { - testCases := []struct { - Upgrader *fakeUpgrader - Ports []string - Send map[uint16]string - Receive map[uint16]string - Err bool + tests := map[string]struct { + ports []string + clientSends map[uint16]string + serverSends map[uint16]string }{ - { - Upgrader: &fakeUpgrader{err: errors.New("bail")}, - Err: true, + "forward 1 port with no data either direction": { + ports: []string{"5000"}, }, - { - Upgrader: &fakeUpgrader{conn: newFakeUpgradeConnection()}, - Ports: []string{"5000"}, - }, - { - Upgrader: &fakeUpgrader{conn: newFakeUpgradeConnection()}, - Ports: []string{"5001", "6000"}, - Send: map[uint16]string{ + "forward 2 ports with bidirectional data": { + ports: []string{"5001", "6000"}, + clientSends: map[uint16]string{ 5001: "abcd", 6000: "ghij", }, - Receive: map[uint16]string{ + serverSends: map[uint16]string{ 5001: "1234", 6000: "5678", }, }, } - for i, testCase := range testCases { + for testName, test := range tests { + server := httptest.NewServer(fakePortForwardServer(t, testName, test.serverSends, test.clientSends)) + url, _ := url.ParseRequestURI(server.URL) + c := client.NewRESTClient(url, "x", nil, -1, -1) + req := c.Post().Resource("testing") + + conf := &client.Config{ + Host: server.URL, + } + stopChan := make(chan struct{}, 1) - pf, err := New(&client.Request{}, &client.Config{}, testCase.Ports, stopChan) - hasErr := err != nil - if hasErr != testCase.Err { - t.Fatalf("%d: New: expected %t, got %t: %v", i, testCase.Err, hasErr, err) - } - if pf == nil { - continue - } - pf.upgrader = testCase.Upgrader - if testCase.Upgrader.err != nil { - err := pf.ForwardPorts() - hasErr := err != nil - if hasErr != testCase.Err { - t.Fatalf("%d: ForwardPorts: expected %t, got %t: %v", i, testCase.Err, hasErr, err) - } - continue + pf, err := New(req, conf, test.ports, stopChan) + if err != nil { + t.Fatalf("%s: unexpected error calling New: %v", testName, err) } doneChan := make(chan error) @@ -352,65 +314,70 @@ func TestForwardPorts(t *testing.T) { }() <-pf.Ready - conn := testCase.Upgrader.conn - - for port, data := range testCase.Send { - conn.lock.Lock() - conn.portData[fmt.Sprintf("%d", port)] = testCase.Receive[port] - conn.lock.Unlock() - + for port, data := range test.clientSends { clientConn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) if err != nil { - t.Fatalf("%d: error dialing %d: %s", i, port, err) + t.Errorf("%s: error dialing %d: %s", testName, port, err) + server.Close() + continue } defer clientConn.Close() n, err := clientConn.Write([]byte(data)) if err != nil && err != io.EOF { - t.Fatalf("%d: Error sending data '%s': %s", i, data, err) + t.Errorf("%s: Error sending data '%s': %s", testName, data, err) + server.Close() + continue } if n == 0 { - t.Fatalf("%d: unexpected write of 0 bytes", i) + t.Errorf("%s: unexpected write of 0 bytes", testName) + server.Close() + continue } b := make([]byte, 4) n, err = clientConn.Read(b) if err != nil && err != io.EOF { - t.Fatalf("%d: Error reading data: %s", i, err) + t.Errorf("%s: Error reading data: %s", testName, err) + server.Close() + continue } - if !bytes.Equal([]byte(testCase.Receive[port]), b) { - t.Fatalf("%d: expected to read '%s', got '%s'", i, testCase.Receive[port], b) + if !bytes.Equal([]byte(test.serverSends[port]), b) { + t.Errorf("%s: expected to read '%s', got '%s'", testName, test.serverSends[port], b) + server.Close() + continue } } - // tell r.ForwardPorts to stop close(stopChan) // wait for r.ForwardPorts to actually return err = <-doneChan if err != nil { - t.Fatalf("%d: unexpected error: %s", i, err) - } - - if e, a := len(testCase.Send), len(conn.streams); e != a { - t.Fatalf("%d: expected %d streams to be created, got %d", i, e, a) - } - - if !conn.closeCalled { - t.Fatalf("%d: expected conn closure", i) + t.Errorf("%s: unexpected error: %s", testName, err) } + server.Close() } } func TestForwardPortsReturnsErrorWhenAllBindsFailed(t *testing.T) { + server := httptest.NewServer(fakePortForwardServer(t, "allBindsFailed", nil, nil)) + defer server.Close() + url, _ := url.ParseRequestURI(server.URL) + c := client.NewRESTClient(url, "x", nil, -1, -1) + req := c.Post().Resource("testing") + + conf := &client.Config{ + Host: server.URL, + } + stopChan1 := make(chan struct{}, 1) defer close(stopChan1) - pf1, err := New(&client.Request{}, &client.Config{}, []string{"5555"}, stopChan1) + pf1, err := New(req, conf, []string{"5555"}, stopChan1) if err != nil { t.Fatalf("error creating pf1: %v", err) } - pf1.upgrader = &fakeUpgrader{conn: newFakeUpgradeConnection()} go pf1.ForwardPorts() <-pf1.Ready @@ -419,7 +386,6 @@ func TestForwardPortsReturnsErrorWhenAllBindsFailed(t *testing.T) { if err != nil { t.Fatalf("error creating pf2: %v", err) } - pf2.upgrader = &fakeUpgrader{conn: newFakeUpgradeConnection()} if err := pf2.ForwardPorts(); err == nil { t.Fatal("expected non-nil error for pf2.ForwardPorts") } diff --git a/pkg/kubelet/dockertools/manager.go b/pkg/kubelet/dockertools/manager.go index 74631ec452..7edf98c8b2 100644 --- a/pkg/kubelet/dockertools/manager.go +++ b/pkg/kubelet/dockertools/manager.go @@ -1179,16 +1179,39 @@ func (dm *DockerManager) PortForward(pod *kubecontainer.Pod, port uint16, stream } containerPid := container.State.Pid - // TODO what if the host doesn't have it??? - _, lookupErr := exec.LookPath("socat") + socatPath, lookupErr := exec.LookPath("socat") if lookupErr != nil { - return fmt.Errorf("Unable to do port forwarding: socat not found.") + return fmt.Errorf("unable to do port forwarding: socat not found.") } - args := []string{"-t", fmt.Sprintf("%d", containerPid), "-n", "socat", "-", fmt.Sprintf("TCP4:localhost:%d", port)} - // TODO use exec.LookPath - command := exec.Command("nsenter", args...) - command.Stdin = stream + + args := []string{"-t", fmt.Sprintf("%d", containerPid), "-n", socatPath, "-", fmt.Sprintf("TCP4:localhost:%d", port)} + + nsenterPath, lookupErr := exec.LookPath("nsenter") + if lookupErr != nil { + return fmt.Errorf("unable to do port forwarding: nsenter not found.") + } + + command := exec.Command(nsenterPath, args...) command.Stdout = stream + + // If we use Stdin, command.Run() won't return until the goroutine that's copying + // from stream finishes. Unfortunately, if you have a client like telnet connected + // via port forwarding, as long as the user's telnet client is connected to the user's + // local listener that port forwarding sets up, the telnet session never exits. This + // means that even if socat has finished running, command.Run() won't ever return + // (because the client still has the connection and stream open). + // + // The work around is to use StdinPipe(), as Wait() (called by Run()) closes the pipe + // when the command (socat) exits. + inPipe, err := command.StdinPipe() + if err != nil { + return fmt.Errorf("unable to do port forwarding: error creating stdin pipe: %v", err) + } + go func() { + io.Copy(inPipe, stream) + inPipe.Close() + }() + return command.Run() } diff --git a/pkg/kubelet/dockertools/manager_test.go b/pkg/kubelet/dockertools/manager_test.go index 93c5e7aaa0..ed92253ff4 100644 --- a/pkg/kubelet/dockertools/manager_test.go +++ b/pkg/kubelet/dockertools/manager_test.go @@ -1762,6 +1762,12 @@ func TestSyncPodEventHandlerFails(t *testing.T) { } } +type fakeReadWriteCloser struct{} + +func (*fakeReadWriteCloser) Read([]byte) (int, error) { return 0, nil } +func (*fakeReadWriteCloser) Write([]byte) (int, error) { return 0, nil } +func (*fakeReadWriteCloser) Close() error { return nil } + func TestPortForwardNoSuchContainer(t *testing.T) { dm, _ := newTestDockerManager() @@ -1774,7 +1780,8 @@ func TestPortForwardNoSuchContainer(t *testing.T) { Containers: nil, }, 5000, - nil, + // need a valid io.ReadWriteCloser here + &fakeReadWriteCloser{}, ) if err == nil { t.Fatal("unexpected non-error") diff --git a/pkg/kubelet/rkt/rkt.go b/pkg/kubelet/rkt/rkt.go index df2755222a..00475d8310 100644 --- a/pkg/kubelet/rkt/rkt.go +++ b/pkg/kubelet/rkt/rkt.go @@ -1211,19 +1211,39 @@ func (r *runtime) PortForward(pod *kubecontainer.Pod, port uint16, stream io.Rea return err } - _, lookupErr := exec.LookPath("socat") + socatPath, lookupErr := exec.LookPath("socat") if lookupErr != nil { return fmt.Errorf("unable to do port forwarding: socat not found.") } - args := []string{"-t", fmt.Sprintf("%d", info.pid), "-n", "socat", "-", fmt.Sprintf("TCP4:localhost:%d", port)} - _, lookupErr = exec.LookPath("nsenter") + args := []string{"-t", fmt.Sprintf("%d", info.pid), "-n", socatPath, "-", fmt.Sprintf("TCP4:localhost:%d", port)} + + nsenterPath, lookupErr := exec.LookPath("nsenter") if lookupErr != nil { return fmt.Errorf("unable to do port forwarding: nsenter not found.") } - command := exec.Command("nsenter", args...) - command.Stdin = stream + + command := exec.Command(nsenterPath, args...) command.Stdout = stream + + // If we use Stdin, command.Run() won't return until the goroutine that's copying + // from stream finishes. Unfortunately, if you have a client like telnet connected + // via port forwarding, as long as the user's telnet client is connected to the user's + // local listener that port forwarding sets up, the telnet session never exits. This + // means that even if socat has finished running, command.Run() won't ever return + // (because the client still has the connection and stream open). + // + // The work around is to use StdinPipe(), as Wait() (called by Run()) closes the pipe + // when the command (socat) exits. + inPipe, err := command.StdinPipe() + if err != nil { + return fmt.Errorf("unable to do port forwarding: error creating stdin pipe: %v", err) + } + go func() { + io.Copy(inPipe, stream) + inPipe.Close() + }() + return command.Run() } diff --git a/pkg/kubelet/server.go b/pkg/kubelet/server.go index b91fc97238..32c6cb3dbb 100644 --- a/pkg/kubelet/server.go +++ b/pkg/kubelet/server.go @@ -45,6 +45,7 @@ import ( "k8s.io/kubernetes/pkg/httplog" kubecontainer "k8s.io/kubernetes/pkg/kubelet/container" "k8s.io/kubernetes/pkg/types" + "k8s.io/kubernetes/pkg/util" "k8s.io/kubernetes/pkg/util/flushwriter" "k8s.io/kubernetes/pkg/util/httpstream" "k8s.io/kubernetes/pkg/util/httpstream/spdy" @@ -458,7 +459,7 @@ func getContainerCoordinates(request *restful.Request) (namespace, pod string, u return } -const streamCreationTimeout = 30 * time.Second +const defaultStreamCreationTimeout = 30 * time.Second func (s *Server) getAttach(request *restful.Request, response *restful.Response) { podNamespace, podID, uid, container := getContainerCoordinates(request) @@ -564,7 +565,7 @@ func (s *Server) createStreams(request *restful.Request, response *restful.Respo conn.SetIdleTimeout(s.host.StreamingConnectionIdleTimeout()) // TODO make it configurable? - expired := time.NewTimer(streamCreationTimeout) + expired := time.NewTimer(defaultStreamCreationTimeout) var errorStream, stdinStream, stdoutStream, stderrStream httpstream.Stream receivedStreams := 0 @@ -612,6 +613,15 @@ func getPodCoordinates(request *restful.Request) (namespace, pod string, uid typ return } +// PortForwarder knows how to forward content from a data stream to/from a port +// in a pod. +type PortForwarder interface { + // PortForwarder copies data between a data stream and a port in a pod. + PortForward(name string, uid types.UID, port uint16, stream io.ReadWriteCloser) error +} + +// getPortForward handles a new restful port forward request. It determines the +// pod name and uid and then calls ServePortForward. func (s *Server) getPortForward(request *restful.Request, response *restful.Response) { podNamespace, podID, uid := getPodCoordinates(request) pod, ok := s.host.GetPodByName(podNamespace, podID) @@ -620,80 +630,280 @@ func (s *Server) getPortForward(request *restful.Request, response *restful.Resp return } + podName := kubecontainer.GetPodFullName(pod) + + ServePortForward(response.ResponseWriter, request.Request, s.host, podName, uid, s.host.StreamingConnectionIdleTimeout(), defaultStreamCreationTimeout) +} + +// ServePortForward handles a port forwarding request. A single request is +// kept alive as long as the client is still alive and the connection has not +// been timed out due to idleness. This function handles multiple forwarded +// connections; i.e., multiple `curl http://localhost:8888/` requests will be +// handled by a single invocation of ServePortForward. +func ServePortForward(w http.ResponseWriter, req *http.Request, portForwarder PortForwarder, podName string, uid types.UID, idleTimeout time.Duration, streamCreationTimeout time.Duration) { streamChan := make(chan httpstream.Stream, 1) + + glog.V(5).Infof("Upgrading port forward response") upgrader := spdy.NewResponseUpgrader() - conn := upgrader.UpgradeResponse(response.ResponseWriter, request.Request, func(stream httpstream.Stream) error { - portString := stream.Headers().Get(api.PortHeader) - port, err := strconv.ParseUint(portString, 10, 16) - if err != nil { - return fmt.Errorf("Unable to parse '%s' as a port: %v", portString, err) - } - if port < 1 { - return fmt.Errorf("Port '%d' must be greater than 0", port) - } - streamChan <- stream - return nil - }) + conn := upgrader.UpgradeResponse(w, req, portForwardStreamReceived(streamChan)) if conn == nil { return } defer conn.Close() - conn.SetIdleTimeout(s.host.StreamingConnectionIdleTimeout()) - var dataStreamLock sync.Mutex - dataStreamChans := make(map[string]chan httpstream.Stream) + glog.V(5).Infof("(conn=%p) setting port forwarding streaming connection idle timeout to %v", conn, idleTimeout) + conn.SetIdleTimeout(idleTimeout) + h := &portForwardStreamHandler{ + conn: conn, + streamChan: streamChan, + streamPairs: make(map[string]*portForwardStreamPair), + streamCreationTimeout: streamCreationTimeout, + pod: podName, + uid: uid, + forwarder: portForwarder, + } + h.run() +} + +// portForwardStreamReceived is the httpstream.NewStreamHandler for port +// forward streams. It checks each stream's port and stream type headers, +// rejecting any streams that with missing or invalid values. Each valid +// stream is sent to the streams channel. +func portForwardStreamReceived(streams chan httpstream.Stream) func(httpstream.Stream) error { + return func(stream httpstream.Stream) error { + // make sure it has a valid port header + portString := stream.Headers().Get(api.PortHeader) + if len(portString) == 0 { + return fmt.Errorf("%q header is required", api.PortHeader) + } + port, err := strconv.ParseUint(portString, 10, 16) + if err != nil { + return fmt.Errorf("unable to parse %q as a port: %v", portString, err) + } + if port < 1 { + return fmt.Errorf("port %q must be > 0", portString) + } + + // make sure it has a valid stream type header + streamType := stream.Headers().Get(api.StreamType) + if len(streamType) == 0 { + return fmt.Errorf("%q header is required", api.StreamType) + } + if streamType != api.StreamTypeError && streamType != api.StreamTypeData { + return fmt.Errorf("invalid stream type %q", streamType) + } + + streams <- stream + return nil + } +} + +// portForwardStreamHandler is capable of processing multiple port forward +// requests over a single httpstream.Connection. +type portForwardStreamHandler struct { + conn httpstream.Connection + streamChan chan httpstream.Stream + streamPairsLock sync.RWMutex + streamPairs map[string]*portForwardStreamPair + streamCreationTimeout time.Duration + pod string + uid types.UID + forwarder PortForwarder +} + +// getStreamPair returns a portForwardStreamPair for requestID. This creates a +// new pair if one does not yet exist for the requestID. The returned bool is +// true if the pair was created. +func (h *portForwardStreamHandler) getStreamPair(requestID string) (*portForwardStreamPair, bool) { + h.streamPairsLock.Lock() + defer h.streamPairsLock.Unlock() + + if p, ok := h.streamPairs[requestID]; ok { + glog.V(5).Infof("(conn=%p, request=%s) found existing stream pair", h.conn, requestID) + return p, false + } + + glog.V(5).Infof("(conn=%p, request=%s) creating new stream pair", h.conn, requestID) + + p := newPortForwardPair(requestID) + h.streamPairs[requestID] = p + + return p, true +} + +// monitorStreamPair waits for the pair to receive both its error and data +// streams, or for the timeout to expire (whichever happens first), and then +// removes the pair. +func (h *portForwardStreamHandler) monitorStreamPair(p *portForwardStreamPair, timeout <-chan time.Time) { + select { + case <-timeout: + err := fmt.Errorf("(conn=%p, request=%s) timed out waiting for streams", h.conn, p.requestID) + util.HandleError(err) + p.printError(err.Error()) + case <-p.complete: + glog.V(5).Infof("(conn=%p, request=%s) successfully received error and data streams", h.conn, p.requestID) + } + h.removeStreamPair(p.requestID) +} + +// hasStreamPair returns a bool indicating if a stream pair for requestID +// exists. +func (h *portForwardStreamHandler) hasStreamPair(requestID string) bool { + h.streamPairsLock.RLock() + defer h.streamPairsLock.RUnlock() + + _, ok := h.streamPairs[requestID] + return ok +} + +// removeStreamPair removes the stream pair identified by requestID from streamPairs. +func (h *portForwardStreamHandler) removeStreamPair(requestID string) { + h.streamPairsLock.Lock() + defer h.streamPairsLock.Unlock() + + delete(h.streamPairs, requestID) +} + +// requestID returns the request id for stream. +func (h *portForwardStreamHandler) requestID(stream httpstream.Stream) string { + requestID := stream.Headers().Get(api.PortForwardRequestIDHeader) + if len(requestID) == 0 { + glog.V(5).Infof("(conn=%p) stream received without %s header", h.conn, api.PortForwardRequestIDHeader) + // If we get here, it's because the connection came from an older client + // that isn't generating the request id header + // (https://github.com/kubernetes/kubernetes/blob/843134885e7e0b360eb5441e85b1410a8b1a7a0c/pkg/client/unversioned/portforward/portforward.go#L258-L287) + // + // This is a best-effort attempt at supporting older clients. + // + // When there aren't concurrent new forwarded connections, each connection + // will have a pair of streams (data, error), and the stream IDs will be + // consecutive odd numbers, e.g. 1 and 3 for the first connection. Convert + // the stream ID into a pseudo-request id by taking the stream type and + // using id = stream.Identifier() when the stream type is error, + // and id = stream.Identifier() - 2 when it's data. + // + // NOTE: this only works when there are not concurrent new streams from + // multiple forwarded connections; it's a best-effort attempt at supporting + // old clients that don't generate request ids. If there are concurrent + // new connections, it's possible that 1 connection gets streams whose IDs + // are not consecutive (e.g. 5 and 9 instead of 5 and 7). + streamType := stream.Headers().Get(api.StreamType) + switch streamType { + case api.StreamTypeError: + requestID = strconv.Itoa(int(stream.Identifier())) + case api.StreamTypeData: + requestID = strconv.Itoa(int(stream.Identifier()) - 2) + } + + glog.V(5).Infof("(conn=%p) automatically assigning request ID=%q from stream type=%s, stream ID=%d", h.conn, requestID, streamType, stream.Identifier()) + } + return requestID +} + +// run is the main loop for the portForwardStreamHandler. It processes new +// streams, invoking portForward for each complete stream pair. The loop exits +// when the httpstream.Connection is closed. +func (h *portForwardStreamHandler) run() { + glog.V(5).Infof("(conn=%p) waiting for port forward streams", h.conn) Loop: for { select { - case <-conn.CloseChan(): + case <-h.conn.CloseChan(): + glog.V(5).Infof("(conn=%p) upgraded connection closed", h.conn) break Loop - case stream := <-streamChan: + case stream := <-h.streamChan: + requestID := h.requestID(stream) streamType := stream.Headers().Get(api.StreamType) - port := stream.Headers().Get(api.PortHeader) - dataStreamLock.Lock() - switch streamType { - case "error": - ch := make(chan httpstream.Stream) - dataStreamChans[port] = ch - go waitForPortForwardDataStreamAndRun(kubecontainer.GetPodFullName(pod), uid, stream, ch, s.host) - case "data": - ch, ok := dataStreamChans[port] - if ok { - ch <- stream - delete(dataStreamChans, port) - } else { - glog.Errorf("Unable to locate data stream channel for port %s", port) - } - default: - glog.Errorf("streamType header must be 'error' or 'data', got: '%s'", streamType) - stream.Reset() + glog.V(5).Infof("(conn=%p, request=%s) received new stream of type %s", h.conn, requestID, streamType) + + p, created := h.getStreamPair(requestID) + if created { + go h.monitorStreamPair(p, time.After(h.streamCreationTimeout)) + } + if complete, err := p.add(stream); err != nil { + msg := fmt.Sprintf("error processing stream for request %s: %v", requestID, err) + util.HandleError(errors.New(msg)) + p.printError(msg) + } else if complete { + go h.portForward(p) } - dataStreamLock.Unlock() } } } -func waitForPortForwardDataStreamAndRun(pod string, uid types.UID, errorStream httpstream.Stream, dataStreamChan chan httpstream.Stream, host HostInterface) { - defer errorStream.Reset() +// portForward invokes the portForwardStreamHandler's forwarder.PortForward +// function for the given stream pair. +func (h *portForwardStreamHandler) portForward(p *portForwardStreamPair) { + defer p.dataStream.Close() + defer p.errorStream.Close() - var dataStream httpstream.Stream + portString := p.dataStream.Headers().Get(api.PortHeader) + port, _ := strconv.ParseUint(portString, 10, 16) - select { - case dataStream = <-dataStreamChan: - case <-time.After(streamCreationTimeout): - errorStream.Write([]byte("Timed out waiting for data stream")) - //TODO delete from dataStreamChans[port] - return + glog.V(5).Infof("(conn=%p, request=%s) invoking forwarder.PortForward for port %s", h.conn, p.requestID, portString) + err := h.forwarder.PortForward(h.pod, h.uid, uint16(port), p.dataStream) + glog.V(5).Infof("(conn=%p, request=%s) done invoking forwarder.PortForward for port %s", h.conn, p.requestID, portString) + + if err != nil { + msg := fmt.Errorf("error forwarding port %d to pod %s, uid %v: %v", port, h.pod, h.uid, err) + util.HandleError(msg) + fmt.Fprint(p.errorStream, msg.Error()) + } +} + +// portForwardStreamPair represents the error and data streams for a port +// forwarding request. +type portForwardStreamPair struct { + lock sync.RWMutex + requestID string + dataStream httpstream.Stream + errorStream httpstream.Stream + complete chan struct{} +} + +// newPortForwardPair creates a new portForwardStreamPair. +func newPortForwardPair(requestID string) *portForwardStreamPair { + return &portForwardStreamPair{ + requestID: requestID, + complete: make(chan struct{}), + } +} + +// add adds the stream to the portForwardStreamPair. If the pair already +// contains a stream for the new stream's type, an error is returned. add +// returns true if both the data and error streams for this pair have been +// received. +func (p *portForwardStreamPair) add(stream httpstream.Stream) (bool, error) { + p.lock.Lock() + defer p.lock.Unlock() + + switch stream.Headers().Get(api.StreamType) { + case api.StreamTypeError: + if p.errorStream != nil { + return false, errors.New("error stream already assigned") + } + p.errorStream = stream + case api.StreamTypeData: + if p.dataStream != nil { + return false, errors.New("data stream already assigned") + } + p.dataStream = stream } - portString := dataStream.Headers().Get(api.PortHeader) - port, _ := strconv.ParseUint(portString, 10, 16) - err := host.PortForward(pod, uid, uint16(port), dataStream) - if err != nil { - msg := fmt.Errorf("Error forwarding port %d to pod %s, uid %v: %v", port, pod, uid, err) - glog.Error(msg) - errorStream.Write([]byte(msg.Error())) + complete := p.errorStream != nil && p.dataStream != nil + if complete { + close(p.complete) + } + return complete, nil +} + +// printError writes s to p.errorStream if p.errorStream has been set. +func (p *portForwardStreamPair) printError(s string) { + p.lock.RLock() + defer p.lock.RUnlock() + if p.errorStream != nil { + fmt.Fprint(p.errorStream, s) } } diff --git a/pkg/kubelet/server_test.go b/pkg/kubelet/server_test.go index b1301d87ad..664dd18e81 100644 --- a/pkg/kubelet/server_test.go +++ b/pkg/kubelet/server_test.go @@ -1426,3 +1426,221 @@ func TestServePortForward(t *testing.T) { <-portForwardFuncDone } } + +type fakeHttpStream struct { + headers http.Header + id uint32 +} + +func newFakeHttpStream() *fakeHttpStream { + return &fakeHttpStream{ + headers: make(http.Header), + } +} + +var _ httpstream.Stream = &fakeHttpStream{} + +func (s *fakeHttpStream) Read(data []byte) (int, error) { + return 0, nil +} + +func (s *fakeHttpStream) Write(data []byte) (int, error) { + return 0, nil +} + +func (s *fakeHttpStream) Close() error { + return nil +} + +func (s *fakeHttpStream) Reset() error { + return nil +} + +func (s *fakeHttpStream) Headers() http.Header { + return s.headers +} + +func (s *fakeHttpStream) Identifier() uint32 { + return s.id +} + +func TestPortForwardStreamReceived(t *testing.T) { + tests := map[string]struct { + port string + streamType string + expectedError string + }{ + "missing port": { + expectedError: `"port" header is required`, + }, + "unable to parse port": { + port: "abc", + expectedError: `unable to parse "abc" as a port: strconv.ParseUint: parsing "abc": invalid syntax`, + }, + "negative port": { + port: "-1", + expectedError: `unable to parse "-1" as a port: strconv.ParseUint: parsing "-1": invalid syntax`, + }, + "missing stream type": { + port: "80", + expectedError: `"streamType" header is required`, + }, + "valid port with error stream": { + port: "80", + streamType: "error", + }, + "valid port with data stream": { + port: "80", + streamType: "data", + }, + "invalid stream type": { + port: "80", + streamType: "foo", + expectedError: `invalid stream type "foo"`, + }, + } + for name, test := range tests { + streams := make(chan httpstream.Stream, 1) + f := portForwardStreamReceived(streams) + stream := newFakeHttpStream() + if len(test.port) > 0 { + stream.headers.Set("port", test.port) + } + if len(test.streamType) > 0 { + stream.headers.Set("streamType", test.streamType) + } + err := f(stream) + if len(test.expectedError) > 0 { + if err == nil { + t.Errorf("%s: expected err=%q, but it was nil", name, test.expectedError) + } + if e, a := test.expectedError, err.Error(); e != a { + t.Errorf("%s: expected err=%q, got %q", name, e, a) + } + continue + } + if err != nil { + t.Errorf("%s: unexpected error %v", name, err) + continue + } + if s := <-streams; s != stream { + t.Errorf("%s: expected stream %#v, got %#v", name, stream, s) + } + } +} + +func TestGetStreamPair(t *testing.T) { + timeout := make(chan time.Time) + + h := &portForwardStreamHandler{ + streamPairs: make(map[string]*portForwardStreamPair), + } + + // test adding a new entry + p, created := h.getStreamPair("1") + if p == nil { + t.Fatalf("unexpected nil pair") + } + if !created { + t.Fatal("expected created=true") + } + if p.dataStream != nil { + t.Errorf("unexpected non-nil data stream") + } + if p.errorStream != nil { + t.Errorf("unexpected non-nil error stream") + } + + // start the monitor for this pair + monitorDone := make(chan struct{}) + go func() { + h.monitorStreamPair(p, timeout) + close(monitorDone) + }() + + if !h.hasStreamPair("1") { + t.Fatal("This should still be true") + } + + // make sure we can retrieve an existing entry + p2, created := h.getStreamPair("1") + if created { + t.Fatal("expected created=false") + } + if p != p2 { + t.Fatalf("retrieving an existing pair: expected %#v, got %#v", p, p2) + } + + // removed via complete + dataStream := newFakeHttpStream() + dataStream.headers.Set(api.StreamType, api.StreamTypeData) + complete, err := p.add(dataStream) + if err != nil { + t.Fatalf("unexpected error adding data stream to pair: %v", err) + } + if complete { + t.Fatalf("unexpected complete") + } + + errorStream := newFakeHttpStream() + errorStream.headers.Set(api.StreamType, api.StreamTypeError) + complete, err = p.add(errorStream) + if err != nil { + t.Fatalf("unexpected error adding error stream to pair: %v", err) + } + if !complete { + t.Fatal("unexpected incomplete") + } + + // make sure monitorStreamPair completed + <-monitorDone + + // make sure the pair was removed + if h.hasStreamPair("1") { + t.Fatal("expected removal of pair after both data and error streams received") + } + + // removed via timeout + p, created = h.getStreamPair("2") + if !created { + t.Fatal("expected created=true") + } + if p == nil { + t.Fatal("expected p not to be nil") + } + monitorDone = make(chan struct{}) + go func() { + h.monitorStreamPair(p, timeout) + close(monitorDone) + }() + // cause the timeout + close(timeout) + // make sure monitorStreamPair completed + <-monitorDone + if h.hasStreamPair("2") { + t.Fatal("expected stream pair to be removed") + } +} + +func TestRequestID(t *testing.T) { + h := &portForwardStreamHandler{} + + s := newFakeHttpStream() + s.headers.Set(api.StreamType, api.StreamTypeError) + s.id = 1 + if e, a := "1", h.requestID(s); e != a { + t.Errorf("expected %q, got %q", e, a) + } + + s.headers.Set(api.StreamType, api.StreamTypeData) + s.id = 3 + if e, a := "1", h.requestID(s); e != a { + t.Errorf("expected %q, got %q", e, a) + } + + s.id = 7 + s.headers.Set(api.PortForwardRequestIDHeader, "2") + if e, a := "2", h.requestID(s); e != a { + t.Errorf("expected %q, got %q", e, a) + } +} diff --git a/pkg/util/httpstream/httpstream.go b/pkg/util/httpstream/httpstream.go index a1c92dde7f..b61af6224d 100644 --- a/pkg/util/httpstream/httpstream.go +++ b/pkg/util/httpstream/httpstream.go @@ -78,6 +78,8 @@ type Stream interface { Reset() error // Headers returns the headers used to create the stream. Headers() http.Header + // Identifier returns the stream's ID. + Identifier() uint32 } // IsUpgradeRequest returns true if the given request is a connection upgrade request diff --git a/test/e2e/kubectl.go b/test/e2e/kubectl.go index 447ee1984e..46b0cf2413 100644 --- a/test/e2e/kubectl.go +++ b/test/e2e/kubectl.go @@ -60,10 +60,7 @@ const ( simplePodPort = 80 ) -var ( - portForwardRegexp = regexp.MustCompile("Forwarding from 127.0.0.1:([0-9]+) -> 80") - proxyRegexp = regexp.MustCompile("Starting to serve on 127.0.0.1:([0-9]+)") -) +var proxyRegexp = regexp.MustCompile("Starting to serve on 127.0.0.1:([0-9]+)") var _ = Describe("Kubectl client", func() { defer GinkgoRecover() @@ -200,32 +197,11 @@ var _ = Describe("Kubectl client", func() { It("should support port-forward", func() { By("forwarding the container port to a local port") - cmd := kubectlCmd("port-forward", fmt.Sprintf("--namespace=%v", ns), simplePodName, fmt.Sprintf(":%d", simplePodPort)) + cmd, listenPort := runPortForward(ns, simplePodName, simplePodPort) defer tryKill(cmd) - // This is somewhat ugly but is the only way to retrieve the port that was picked - // by the port-forward command. We don't want to hard code the port as we have no - // way of guaranteeing we can pick one that isn't in use, particularly on Jenkins. - Logf("starting port-forward command and streaming output") - stdout, stderr, err := startCmdAndStreamOutput(cmd) - if err != nil { - Failf("Failed to start port-forward command: %v", err) - } - defer stdout.Close() - defer stderr.Close() - buf := make([]byte, 128) - var n int - Logf("reading from `kubectl port-forward` command's stderr") - if n, err = stderr.Read(buf); err != nil { - Failf("Failed to read from kubectl port-forward stderr: %v", err) - } - portForwardOutput := string(buf[:n]) - match := portForwardRegexp.FindStringSubmatch(portForwardOutput) - if len(match) != 2 { - Failf("Failed to parse kubectl port-forward output: %s", portForwardOutput) - } By("curling local port output") - localAddr := fmt.Sprintf("http://localhost:%s", match[1]) + localAddr := fmt.Sprintf("http://localhost:%d", listenPort) body, err := curl(localAddr) Logf("got: %s", body) if err != nil { diff --git a/test/e2e/portforward.go b/test/e2e/portforward.go new file mode 100644 index 0000000000..2b5a2ca5ba --- /dev/null +++ b/test/e2e/portforward.go @@ -0,0 +1,234 @@ +/* +Copyright 2015 The Kubernetes Authors All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package e2e + +import ( + "fmt" + "io/ioutil" + "net" + "os/exec" + "regexp" + "strconv" + "strings" + + "k8s.io/kubernetes/pkg/api" + + . "github.com/onsi/ginkgo" +) + +const ( + podName = "pfpod" +) + +// TODO support other ports besides 80 +var portForwardRegexp = regexp.MustCompile("Forwarding from 127.0.0.1:([0-9]+) -> 80") + +func pfPod(expectedClientData, chunks, chunkSize, chunkIntervalMillis string) *api.Pod { + return &api.Pod{ + ObjectMeta: api.ObjectMeta{ + Name: podName, + Labels: map[string]string{"name": podName}, + }, + Spec: api.PodSpec{ + Containers: []api.Container{ + { + Name: "portforwardtester", + Image: "gcr.io/google_containers/portforwardtester:1.0", + Env: []api.EnvVar{ + { + Name: "BIND_PORT", + Value: "80", + }, + { + Name: "EXPECTED_CLIENT_DATA", + Value: expectedClientData, + }, + { + Name: "CHUNKS", + Value: chunks, + }, + { + Name: "CHUNK_SIZE", + Value: chunkSize, + }, + { + Name: "CHUNK_INTERVAL", + Value: chunkIntervalMillis, + }, + }, + }, + }, + RestartPolicy: api.RestartPolicyNever, + }, + } +} + +func runPortForward(ns, podName string, port int) (*exec.Cmd, int) { + cmd := kubectlCmd("port-forward", fmt.Sprintf("--namespace=%v", ns), podName, fmt.Sprintf(":%d", port)) + // This is somewhat ugly but is the only way to retrieve the port that was picked + // by the port-forward command. We don't want to hard code the port as we have no + // way of guaranteeing we can pick one that isn't in use, particularly on Jenkins. + Logf("starting port-forward command and streaming output") + stdout, stderr, err := startCmdAndStreamOutput(cmd) + if err != nil { + Failf("Failed to start port-forward command: %v", err) + } + defer stdout.Close() + defer stderr.Close() + + buf := make([]byte, 128) + var n int + Logf("reading from `kubectl port-forward` command's stderr") + if n, err = stderr.Read(buf); err != nil { + Failf("Failed to read from kubectl port-forward stderr: %v", err) + } + portForwardOutput := string(buf[:n]) + match := portForwardRegexp.FindStringSubmatch(portForwardOutput) + if len(match) != 2 { + Failf("Failed to parse kubectl port-forward output: %s", portForwardOutput) + } + + listenPort, err := strconv.Atoi(match[1]) + if err != nil { + Failf("Error converting %s to an int: %v", match[1], err) + } + + return cmd, listenPort +} + +var _ = Describe("Port forwarding", func() { + framework := NewFramework("port-forwarding") + + Describe("With a server that expects a client request", func() { + It("should support a client that connects, sends no data, and disconnects", func() { + By("creating the target pod") + pod := pfPod("abc", "1", "1", "1") + framework.Client.Pods(framework.Namespace.Name).Create(pod) + framework.WaitForPodRunning(pod.Name) + + By("Running 'kubectl port-forward'") + cmd, listenPort := runPortForward(framework.Namespace.Name, pod.Name, 80) + defer tryKill(cmd) + + By("Dialing the local port") + conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", listenPort)) + if err != nil { + Failf("Couldn't connect to port %d: %v", listenPort, err) + } + + By("Closing the connection to the local port") + conn.Close() + + logOutput := runKubectl("logs", fmt.Sprintf("--namespace=%v", framework.Namespace.Name), "-f", podName) + verifyLogMessage(logOutput, "Accepted client connection") + verifyLogMessage(logOutput, "Expected to read 3 bytes from client, but got 0 instead") + }) + + It("should support a client that connects, sends data, and disconnects", func() { + By("creating the target pod") + pod := pfPod("abc", "10", "10", "100") + framework.Client.Pods(framework.Namespace.Name).Create(pod) + framework.WaitForPodRunning(pod.Name) + + By("Running 'kubectl port-forward'") + cmd, listenPort := runPortForward(framework.Namespace.Name, pod.Name, 80) + defer tryKill(cmd) + + By("Dialing the local port") + addr, err := net.ResolveTCPAddr("tcp", fmt.Sprintf("127.0.0.1:%d", listenPort)) + if err != nil { + Failf("Error resolving tcp addr: %v", err) + } + conn, err := net.DialTCP("tcp", nil, addr) + if err != nil { + Failf("Couldn't connect to port %d: %v", listenPort, err) + } + defer func() { + By("Closing the connection to the local port") + conn.Close() + }() + + By("Sending the expected data to the local port") + fmt.Fprint(conn, "abc") + + By("Closing the write half of the client's connection") + conn.CloseWrite() + + By("Reading data from the local port") + fromServer, err := ioutil.ReadAll(conn) + if err != nil { + Failf("Unexpected error reading data from the server: %v", err) + } + + if e, a := strings.Repeat("x", 100), string(fromServer); e != a { + Failf("Expected %q from server, got %q", e, a) + } + + logOutput := runKubectl("logs", fmt.Sprintf("--namespace=%v", framework.Namespace.Name), "-f", podName) + verifyLogMessage(logOutput, "^Accepted client connection$") + verifyLogMessage(logOutput, "^Received expected client data$") + verifyLogMessage(logOutput, "^Done$") + }) + }) + Describe("With a server that expects no client request", func() { + It("should support a client that connects, sends no data, and disconnects", func() { + By("creating the target pod") + pod := pfPod("", "10", "10", "100") + framework.Client.Pods(framework.Namespace.Name).Create(pod) + framework.WaitForPodRunning(pod.Name) + + By("Running 'kubectl port-forward'") + cmd, listenPort := runPortForward(framework.Namespace.Name, pod.Name, 80) + defer tryKill(cmd) + + By("Dialing the local port") + conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", listenPort)) + if err != nil { + Failf("Couldn't connect to port %d: %v", listenPort, err) + } + defer func() { + By("Closing the connection to the local port") + conn.Close() + }() + + By("Reading data from the local port") + fromServer, err := ioutil.ReadAll(conn) + if err != nil { + Failf("Unexpected error reading data from the server: %v", err) + } + + if e, a := strings.Repeat("x", 100), string(fromServer); e != a { + Failf("Expected %q from server, got %q", e, a) + } + + logOutput := runKubectl("logs", fmt.Sprintf("--namespace=%v", framework.Namespace.Name), "-f", podName) + verifyLogMessage(logOutput, "Accepted client connection") + verifyLogMessage(logOutput, "Done") + }) + }) +}) + +func verifyLogMessage(log, expected string) { + re := regexp.MustCompile(expected) + lines := strings.Split(log, "\n") + for i := range lines { + if re.MatchString(lines[i]) { + return + } + } + Failf("Missing %q from log: %s", expected, log) +}