Port forwarding fixes

Correct port-forward data copying logic so that the server closes its
half of the data stream when socat exits, and the client closes its half
of the data stream when it finishes writing.

Modify the client to wait for both copies (client->server,
server->client) to finish before it unblocks.

Fix race condition in the Kubelet's handling of incoming port forward
streams. Have the client generate a connectionID header to be used to
associate the error and data streams for a single connection, instead of
assuming that streams n and n+1 go together. Attempt to generate a
pseudo connectionID in the server in the event the connectionID header
isn't present (older clients); this is a best-effort approach that only
really works with 1 connection at a time, whereas multiple concurrent
connections will only work reliably with a newer client that is
generating connectionID.
pull/6/head
Andy Goldstein 2015-09-22 16:29:51 -04:00
parent 7f900daa3e
commit ed021fed4c
11 changed files with 992 additions and 303 deletions

View File

@ -1951,14 +1951,24 @@ const (
// Command to run for remote command execution
ExecCommandParamm = "command"
StreamType = "streamType"
StreamTypeStdin = "stdin"
// Name of header that specifies stream type
StreamType = "streamType"
// Value for streamType header for stdin stream
StreamTypeStdin = "stdin"
// Value for streamType header for stdout stream
StreamTypeStdout = "stdout"
// Value for streamType header for stderr stream
StreamTypeStderr = "stderr"
StreamTypeData = "data"
StreamTypeError = "error"
// Value for streamType header for data stream
StreamTypeData = "data"
// Value for streamType header for error stream
StreamTypeError = "error"
// Name of header that specifies the port being forwarded
PortHeader = "port"
// Name of header that specifies a request ID used to associate the error
// and data streams for a single forwarded connection
PortForwardRequestIDHeader = "requestID"
)
// Similarly to above, these are constants to support HTTP PATCH utilized by

View File

@ -25,10 +25,12 @@ import (
"net/http"
"strconv"
"strings"
"sync"
"github.com/golang/glog"
"k8s.io/kubernetes/pkg/api"
client "k8s.io/kubernetes/pkg/client/unversioned"
"k8s.io/kubernetes/pkg/util"
"k8s.io/kubernetes/pkg/util/httpstream"
"k8s.io/kubernetes/pkg/util/httpstream/spdy"
)
@ -51,10 +53,12 @@ type PortForwarder struct {
ports []ForwardedPort
stopChan <-chan struct{}
streamConn httpstream.Connection
listeners []io.Closer
upgrader upgrader
Ready chan struct{}
streamConn httpstream.Connection
listeners []io.Closer
upgrader upgrader
Ready chan struct{}
requestIDLock sync.Mutex
requestID int
}
// ForwardedPort contains a Local:Remote port pairing.
@ -145,7 +149,7 @@ func (pf *PortForwarder) ForwardPorts() error {
var err error
pf.streamConn, err = pf.upgrader.upgrade(pf.req, pf.config)
if err != nil {
return fmt.Errorf("Error upgrading connection: %s", err)
return fmt.Errorf("error upgrading connection: %s", err)
}
defer pf.streamConn.Close()
@ -179,7 +183,7 @@ func (pf *PortForwarder) forward() error {
select {
case <-pf.stopChan:
case <-pf.streamConn.CloseChan():
glog.Errorf("Lost connection to pod")
util.HandleError(errors.New("lost connection to pod"))
}
return nil
@ -213,7 +217,7 @@ func (pf *PortForwarder) listenOnPortAndAddress(port *ForwardedPort, protocol st
func (pf *PortForwarder) getListener(protocol string, hostname string, port *ForwardedPort) (net.Listener, error) {
listener, err := net.Listen(protocol, fmt.Sprintf("%s:%d", hostname, port.Local))
if err != nil {
glog.Errorf("Unable to create listener: Error %s", err)
util.HandleError(fmt.Errorf("Unable to create listener: Error %s", err))
return nil, err
}
listenerAddress := listener.Addr().String()
@ -237,7 +241,7 @@ func (pf *PortForwarder) waitForConnection(listener net.Listener, port Forwarded
if err != nil {
// TODO consider using something like https://github.com/hydrogen18/stoppableListener?
if !strings.Contains(strings.ToLower(err.Error()), "use of closed network connection") {
glog.Errorf("Error accepting connection on port %d: %v", port.Local, err)
util.HandleError(fmt.Errorf("Error accepting connection on port %d: %v", port.Local, err))
}
return
}
@ -245,6 +249,14 @@ func (pf *PortForwarder) waitForConnection(listener net.Listener, port Forwarded
}
}
func (pf *PortForwarder) nextRequestID() int {
pf.requestIDLock.Lock()
defer pf.requestIDLock.Unlock()
id := pf.requestID
pf.requestID++
return id
}
// handleConnection copies data between the local connection and the stream to
// the remote server.
func (pf *PortForwarder) handleConnection(conn net.Conn, port ForwardedPort) {
@ -252,65 +264,76 @@ func (pf *PortForwarder) handleConnection(conn net.Conn, port ForwardedPort) {
glog.Infof("Handling connection for %d", port.Local)
errorChan := make(chan error)
doneChan := make(chan struct{}, 2)
requestID := pf.nextRequestID()
// create error stream
headers := http.Header{}
headers.Set(api.StreamType, api.StreamTypeError)
headers.Set(api.PortHeader, fmt.Sprintf("%d", port.Remote))
headers.Set(api.PortForwardRequestIDHeader, strconv.Itoa(requestID))
errorStream, err := pf.streamConn.CreateStream(headers)
if err != nil {
glog.Errorf("Error creating error stream for port %d -> %d: %v", port.Local, port.Remote, err)
util.HandleError(fmt.Errorf("error creating error stream for port %d -> %d: %v", port.Local, port.Remote, err))
return
}
defer errorStream.Reset()
// we're not writing to this stream
errorStream.Close()
errorChan := make(chan error)
go func() {
message, err := ioutil.ReadAll(errorStream)
if err != nil && err != io.EOF {
errorChan <- fmt.Errorf("Error reading from error stream for port %d -> %d: %v", port.Local, port.Remote, err)
}
if len(message) > 0 {
errorChan <- fmt.Errorf("An error occurred forwarding %d -> %d: %v", port.Local, port.Remote, string(message))
switch {
case err != nil:
errorChan <- fmt.Errorf("error reading from error stream for port %d -> %d: %v", port.Local, port.Remote, err)
case len(message) > 0:
errorChan <- fmt.Errorf("an error occurred forwarding %d -> %d: %v", port.Local, port.Remote, string(message))
}
close(errorChan)
}()
// create data stream
headers.Set(api.StreamType, api.StreamTypeData)
dataStream, err := pf.streamConn.CreateStream(headers)
if err != nil {
glog.Errorf("Error creating forwarding stream for port %d -> %d: %v", port.Local, port.Remote, err)
util.HandleError(fmt.Errorf("error creating forwarding stream for port %d -> %d: %v", port.Local, port.Remote, err))
return
}
// Send a Reset when this function exits to completely tear down the stream here
// and in the remote server.
defer dataStream.Reset()
localError := make(chan struct{})
remoteDone := make(chan struct{})
go func() {
// Copy from the remote side to the local port. We won't get an EOF from
// the server as it has no way of knowing when to close the stream. We'll
// take care of closing both ends of the stream with the call to
// stream.Reset() when this function exits.
if _, err := io.Copy(conn, dataStream); err != nil && err != io.EOF && !strings.Contains(err.Error(), "use of closed network connection") {
glog.Errorf("Error copying from remote stream to local connection: %v", err)
// Copy from the remote side to the local port.
if _, err := io.Copy(conn, dataStream); err != nil && !strings.Contains(err.Error(), "use of closed network connection") {
util.HandleError(fmt.Errorf("error copying from remote stream to local connection: %v", err))
}
doneChan <- struct{}{}
// inform the select below that the remote copy is done
close(remoteDone)
}()
go func() {
// Copy from the local port to the remote side. Here we will be able to know
// when the Copy gets an EOF from conn, as that will happen as soon as conn is
// closed (i.e. client disconnected).
if _, err := io.Copy(dataStream, conn); err != nil && err != io.EOF && !strings.Contains(err.Error(), "use of closed network connection") {
glog.Errorf("Error copying from local connection to remote stream: %v", err)
// inform server we're not sending any more data after copy unblocks
defer dataStream.Close()
// Copy from the local port to the remote side.
if _, err := io.Copy(dataStream, conn); err != nil && !strings.Contains(err.Error(), "use of closed network connection") {
util.HandleError(fmt.Errorf("error copying from local connection to remote stream: %v", err))
// break out of the select below without waiting for the other copy to finish
close(localError)
}
doneChan <- struct{}{}
}()
// wait for either a local->remote error or for copying from remote->local to finish
select {
case err := <-errorChan:
glog.Error(err)
case <-doneChan:
case <-remoteDone:
case <-localError:
}
// always expect something on errorChan (it may be nil)
err = <-errorChan
if err != nil {
util.HandleError(err)
}
}
@ -318,7 +341,7 @@ func (pf *PortForwarder) Close() {
// stop all listeners
for _, l := range pf.listeners {
if err := l.Close(); err != nil {
glog.Errorf("Error closing listener: %v", err)
util.HandleError(fmt.Errorf("error closing listener: %v", err))
}
}
}

View File

@ -18,20 +18,21 @@ package portforward
import (
"bytes"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/http/httptest"
"net/url"
"reflect"
"strings"
"sync"
"testing"
"time"
"k8s.io/kubernetes/pkg/api"
client "k8s.io/kubernetes/pkg/client/unversioned"
"k8s.io/kubernetes/pkg/util/httpstream"
"k8s.io/kubernetes/pkg/kubelet"
"k8s.io/kubernetes/pkg/types"
)
func TestParsePortsAndNew(t *testing.T) {
@ -110,109 +111,6 @@ func TestParsePortsAndNew(t *testing.T) {
}
}
type fakeUpgrader struct {
conn *fakeUpgradeConnection
err error
}
func (u *fakeUpgrader) upgrade(req *client.Request, config *client.Config) (httpstream.Connection, error) {
return u.conn, u.err
}
type fakeUpgradeConnection struct {
closeCalled bool
lock sync.Mutex
streams map[string]*fakeUpgradeStream
portData map[string]string
}
func newFakeUpgradeConnection() *fakeUpgradeConnection {
return &fakeUpgradeConnection{
streams: make(map[string]*fakeUpgradeStream),
portData: make(map[string]string),
}
}
func (c *fakeUpgradeConnection) CreateStream(headers http.Header) (httpstream.Stream, error) {
c.lock.Lock()
defer c.lock.Unlock()
stream := &fakeUpgradeStream{}
c.streams[headers.Get(api.PortHeader)] = stream
// only simulate data on the data stream for now, not the error stream
if headers.Get(api.StreamType) == api.StreamTypeData {
stream.data = c.portData[headers.Get(api.PortHeader)]
}
return stream, nil
}
func (c *fakeUpgradeConnection) Close() error {
c.lock.Lock()
defer c.lock.Unlock()
c.closeCalled = true
return nil
}
func (c *fakeUpgradeConnection) CloseChan() <-chan bool {
return make(chan bool)
}
func (c *fakeUpgradeConnection) SetIdleTimeout(timeout time.Duration) {
}
type fakeUpgradeStream struct {
readCalled bool
writeCalled bool
dataWritten []byte
closeCalled bool
resetCalled bool
data string
lock sync.Mutex
}
func (s *fakeUpgradeStream) Read(p []byte) (int, error) {
s.lock.Lock()
defer s.lock.Unlock()
s.readCalled = true
b := []byte(s.data)
n := copy(p, b)
// Indicate we returned all the data, and have no more data (EOF)
// Returning an EOF here will cause the port forwarder to immediately terminate, which is correct when we have no more data to send
return n, io.EOF
}
func (s *fakeUpgradeStream) Write(p []byte) (int, error) {
s.lock.Lock()
defer s.lock.Unlock()
s.writeCalled = true
s.dataWritten = append(s.dataWritten, p...)
// Indicate the stream accepted all the data, and can accept more (no err)
// Returning an EOF here will cause the port forwarder to immediately terminate, which is incorrect, in case someone writes more data
return len(p), nil
}
func (s *fakeUpgradeStream) Close() error {
s.lock.Lock()
defer s.lock.Unlock()
s.closeCalled = true
return nil
}
func (s *fakeUpgradeStream) Reset() error {
s.lock.Lock()
defer s.lock.Unlock()
s.resetCalled = true
return nil
}
func (s *fakeUpgradeStream) Headers() http.Header {
s.lock.Lock()
defer s.lock.Unlock()
return http.Header{}
}
type GetListenerTestCase struct {
Hostname string
Protocol string
@ -295,55 +193,119 @@ func TestGetListener(t *testing.T) {
}
}
// fakePortForwarder simulates port forwarding for testing. It implements
// kubelet.PortForwarder.
type fakePortForwarder struct {
lock sync.Mutex
// stores data received from the stream per port
received map[uint16]string
// data to be sent to the stream per port
send map[uint16]string
}
var _ kubelet.PortForwarder = &fakePortForwarder{}
func (pf *fakePortForwarder) PortForward(name string, uid types.UID, port uint16, stream io.ReadWriteCloser) error {
defer stream.Close()
var wg sync.WaitGroup
// client -> server
wg.Add(1)
go func() {
defer wg.Done()
// copy from stream into a buffer
received := new(bytes.Buffer)
io.Copy(received, stream)
// store the received content
pf.lock.Lock()
pf.received[port] = received.String()
pf.lock.Unlock()
}()
// server -> client
wg.Add(1)
go func() {
defer wg.Done()
// send the hardcoded data to the stream
io.Copy(stream, strings.NewReader(pf.send[port]))
}()
wg.Wait()
return nil
}
// fakePortForwardServer creates an HTTP server that can handle port forwarding
// requests.
func fakePortForwardServer(t *testing.T, testName string, serverSends, expectedFromClient map[uint16]string) http.HandlerFunc {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
pf := &fakePortForwarder{
received: make(map[uint16]string),
send: serverSends,
}
kubelet.ServePortForward(w, req, pf, "pod", "uid", 0, 10*time.Second)
for port, expected := range expectedFromClient {
actual, ok := pf.received[port]
if !ok {
t.Errorf("%s: server didn't receive any data for port %d", testName, port)
continue
}
if expected != actual {
t.Errorf("%s: server expected to receive %q, got %q for port %d", testName, expected, actual, port)
}
}
for port, actual := range pf.received {
if _, ok := expectedFromClient[port]; !ok {
t.Errorf("%s: server unexpectedly received %q for port %d", testName, actual, port)
}
}
})
}
func TestForwardPorts(t *testing.T) {
testCases := []struct {
Upgrader *fakeUpgrader
Ports []string
Send map[uint16]string
Receive map[uint16]string
Err bool
tests := map[string]struct {
ports []string
clientSends map[uint16]string
serverSends map[uint16]string
}{
{
Upgrader: &fakeUpgrader{err: errors.New("bail")},
Err: true,
"forward 1 port with no data either direction": {
ports: []string{"5000"},
},
{
Upgrader: &fakeUpgrader{conn: newFakeUpgradeConnection()},
Ports: []string{"5000"},
},
{
Upgrader: &fakeUpgrader{conn: newFakeUpgradeConnection()},
Ports: []string{"5001", "6000"},
Send: map[uint16]string{
"forward 2 ports with bidirectional data": {
ports: []string{"5001", "6000"},
clientSends: map[uint16]string{
5001: "abcd",
6000: "ghij",
},
Receive: map[uint16]string{
serverSends: map[uint16]string{
5001: "1234",
6000: "5678",
},
},
}
for i, testCase := range testCases {
for testName, test := range tests {
server := httptest.NewServer(fakePortForwardServer(t, testName, test.serverSends, test.clientSends))
url, _ := url.ParseRequestURI(server.URL)
c := client.NewRESTClient(url, "x", nil, -1, -1)
req := c.Post().Resource("testing")
conf := &client.Config{
Host: server.URL,
}
stopChan := make(chan struct{}, 1)
pf, err := New(&client.Request{}, &client.Config{}, testCase.Ports, stopChan)
hasErr := err != nil
if hasErr != testCase.Err {
t.Fatalf("%d: New: expected %t, got %t: %v", i, testCase.Err, hasErr, err)
}
if pf == nil {
continue
}
pf.upgrader = testCase.Upgrader
if testCase.Upgrader.err != nil {
err := pf.ForwardPorts()
hasErr := err != nil
if hasErr != testCase.Err {
t.Fatalf("%d: ForwardPorts: expected %t, got %t: %v", i, testCase.Err, hasErr, err)
}
continue
pf, err := New(req, conf, test.ports, stopChan)
if err != nil {
t.Fatalf("%s: unexpected error calling New: %v", testName, err)
}
doneChan := make(chan error)
@ -352,65 +314,70 @@ func TestForwardPorts(t *testing.T) {
}()
<-pf.Ready
conn := testCase.Upgrader.conn
for port, data := range testCase.Send {
conn.lock.Lock()
conn.portData[fmt.Sprintf("%d", port)] = testCase.Receive[port]
conn.lock.Unlock()
for port, data := range test.clientSends {
clientConn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
if err != nil {
t.Fatalf("%d: error dialing %d: %s", i, port, err)
t.Errorf("%s: error dialing %d: %s", testName, port, err)
server.Close()
continue
}
defer clientConn.Close()
n, err := clientConn.Write([]byte(data))
if err != nil && err != io.EOF {
t.Fatalf("%d: Error sending data '%s': %s", i, data, err)
t.Errorf("%s: Error sending data '%s': %s", testName, data, err)
server.Close()
continue
}
if n == 0 {
t.Fatalf("%d: unexpected write of 0 bytes", i)
t.Errorf("%s: unexpected write of 0 bytes", testName)
server.Close()
continue
}
b := make([]byte, 4)
n, err = clientConn.Read(b)
if err != nil && err != io.EOF {
t.Fatalf("%d: Error reading data: %s", i, err)
t.Errorf("%s: Error reading data: %s", testName, err)
server.Close()
continue
}
if !bytes.Equal([]byte(testCase.Receive[port]), b) {
t.Fatalf("%d: expected to read '%s', got '%s'", i, testCase.Receive[port], b)
if !bytes.Equal([]byte(test.serverSends[port]), b) {
t.Errorf("%s: expected to read '%s', got '%s'", testName, test.serverSends[port], b)
server.Close()
continue
}
}
// tell r.ForwardPorts to stop
close(stopChan)
// wait for r.ForwardPorts to actually return
err = <-doneChan
if err != nil {
t.Fatalf("%d: unexpected error: %s", i, err)
}
if e, a := len(testCase.Send), len(conn.streams); e != a {
t.Fatalf("%d: expected %d streams to be created, got %d", i, e, a)
}
if !conn.closeCalled {
t.Fatalf("%d: expected conn closure", i)
t.Errorf("%s: unexpected error: %s", testName, err)
}
server.Close()
}
}
func TestForwardPortsReturnsErrorWhenAllBindsFailed(t *testing.T) {
server := httptest.NewServer(fakePortForwardServer(t, "allBindsFailed", nil, nil))
defer server.Close()
url, _ := url.ParseRequestURI(server.URL)
c := client.NewRESTClient(url, "x", nil, -1, -1)
req := c.Post().Resource("testing")
conf := &client.Config{
Host: server.URL,
}
stopChan1 := make(chan struct{}, 1)
defer close(stopChan1)
pf1, err := New(&client.Request{}, &client.Config{}, []string{"5555"}, stopChan1)
pf1, err := New(req, conf, []string{"5555"}, stopChan1)
if err != nil {
t.Fatalf("error creating pf1: %v", err)
}
pf1.upgrader = &fakeUpgrader{conn: newFakeUpgradeConnection()}
go pf1.ForwardPorts()
<-pf1.Ready
@ -419,7 +386,6 @@ func TestForwardPortsReturnsErrorWhenAllBindsFailed(t *testing.T) {
if err != nil {
t.Fatalf("error creating pf2: %v", err)
}
pf2.upgrader = &fakeUpgrader{conn: newFakeUpgradeConnection()}
if err := pf2.ForwardPorts(); err == nil {
t.Fatal("expected non-nil error for pf2.ForwardPorts")
}

View File

@ -1179,16 +1179,39 @@ func (dm *DockerManager) PortForward(pod *kubecontainer.Pod, port uint16, stream
}
containerPid := container.State.Pid
// TODO what if the host doesn't have it???
_, lookupErr := exec.LookPath("socat")
socatPath, lookupErr := exec.LookPath("socat")
if lookupErr != nil {
return fmt.Errorf("Unable to do port forwarding: socat not found.")
return fmt.Errorf("unable to do port forwarding: socat not found.")
}
args := []string{"-t", fmt.Sprintf("%d", containerPid), "-n", "socat", "-", fmt.Sprintf("TCP4:localhost:%d", port)}
// TODO use exec.LookPath
command := exec.Command("nsenter", args...)
command.Stdin = stream
args := []string{"-t", fmt.Sprintf("%d", containerPid), "-n", socatPath, "-", fmt.Sprintf("TCP4:localhost:%d", port)}
nsenterPath, lookupErr := exec.LookPath("nsenter")
if lookupErr != nil {
return fmt.Errorf("unable to do port forwarding: nsenter not found.")
}
command := exec.Command(nsenterPath, args...)
command.Stdout = stream
// If we use Stdin, command.Run() won't return until the goroutine that's copying
// from stream finishes. Unfortunately, if you have a client like telnet connected
// via port forwarding, as long as the user's telnet client is connected to the user's
// local listener that port forwarding sets up, the telnet session never exits. This
// means that even if socat has finished running, command.Run() won't ever return
// (because the client still has the connection and stream open).
//
// The work around is to use StdinPipe(), as Wait() (called by Run()) closes the pipe
// when the command (socat) exits.
inPipe, err := command.StdinPipe()
if err != nil {
return fmt.Errorf("unable to do port forwarding: error creating stdin pipe: %v", err)
}
go func() {
io.Copy(inPipe, stream)
inPipe.Close()
}()
return command.Run()
}

View File

@ -1762,6 +1762,12 @@ func TestSyncPodEventHandlerFails(t *testing.T) {
}
}
type fakeReadWriteCloser struct{}
func (*fakeReadWriteCloser) Read([]byte) (int, error) { return 0, nil }
func (*fakeReadWriteCloser) Write([]byte) (int, error) { return 0, nil }
func (*fakeReadWriteCloser) Close() error { return nil }
func TestPortForwardNoSuchContainer(t *testing.T) {
dm, _ := newTestDockerManager()
@ -1774,7 +1780,8 @@ func TestPortForwardNoSuchContainer(t *testing.T) {
Containers: nil,
},
5000,
nil,
// need a valid io.ReadWriteCloser here
&fakeReadWriteCloser{},
)
if err == nil {
t.Fatal("unexpected non-error")

View File

@ -1211,19 +1211,39 @@ func (r *runtime) PortForward(pod *kubecontainer.Pod, port uint16, stream io.Rea
return err
}
_, lookupErr := exec.LookPath("socat")
socatPath, lookupErr := exec.LookPath("socat")
if lookupErr != nil {
return fmt.Errorf("unable to do port forwarding: socat not found.")
}
args := []string{"-t", fmt.Sprintf("%d", info.pid), "-n", "socat", "-", fmt.Sprintf("TCP4:localhost:%d", port)}
_, lookupErr = exec.LookPath("nsenter")
args := []string{"-t", fmt.Sprintf("%d", info.pid), "-n", socatPath, "-", fmt.Sprintf("TCP4:localhost:%d", port)}
nsenterPath, lookupErr := exec.LookPath("nsenter")
if lookupErr != nil {
return fmt.Errorf("unable to do port forwarding: nsenter not found.")
}
command := exec.Command("nsenter", args...)
command.Stdin = stream
command := exec.Command(nsenterPath, args...)
command.Stdout = stream
// If we use Stdin, command.Run() won't return until the goroutine that's copying
// from stream finishes. Unfortunately, if you have a client like telnet connected
// via port forwarding, as long as the user's telnet client is connected to the user's
// local listener that port forwarding sets up, the telnet session never exits. This
// means that even if socat has finished running, command.Run() won't ever return
// (because the client still has the connection and stream open).
//
// The work around is to use StdinPipe(), as Wait() (called by Run()) closes the pipe
// when the command (socat) exits.
inPipe, err := command.StdinPipe()
if err != nil {
return fmt.Errorf("unable to do port forwarding: error creating stdin pipe: %v", err)
}
go func() {
io.Copy(inPipe, stream)
inPipe.Close()
}()
return command.Run()
}

View File

@ -45,6 +45,7 @@ import (
"k8s.io/kubernetes/pkg/httplog"
kubecontainer "k8s.io/kubernetes/pkg/kubelet/container"
"k8s.io/kubernetes/pkg/types"
"k8s.io/kubernetes/pkg/util"
"k8s.io/kubernetes/pkg/util/flushwriter"
"k8s.io/kubernetes/pkg/util/httpstream"
"k8s.io/kubernetes/pkg/util/httpstream/spdy"
@ -458,7 +459,7 @@ func getContainerCoordinates(request *restful.Request) (namespace, pod string, u
return
}
const streamCreationTimeout = 30 * time.Second
const defaultStreamCreationTimeout = 30 * time.Second
func (s *Server) getAttach(request *restful.Request, response *restful.Response) {
podNamespace, podID, uid, container := getContainerCoordinates(request)
@ -564,7 +565,7 @@ func (s *Server) createStreams(request *restful.Request, response *restful.Respo
conn.SetIdleTimeout(s.host.StreamingConnectionIdleTimeout())
// TODO make it configurable?
expired := time.NewTimer(streamCreationTimeout)
expired := time.NewTimer(defaultStreamCreationTimeout)
var errorStream, stdinStream, stdoutStream, stderrStream httpstream.Stream
receivedStreams := 0
@ -612,6 +613,15 @@ func getPodCoordinates(request *restful.Request) (namespace, pod string, uid typ
return
}
// PortForwarder knows how to forward content from a data stream to/from a port
// in a pod.
type PortForwarder interface {
// PortForwarder copies data between a data stream and a port in a pod.
PortForward(name string, uid types.UID, port uint16, stream io.ReadWriteCloser) error
}
// getPortForward handles a new restful port forward request. It determines the
// pod name and uid and then calls ServePortForward.
func (s *Server) getPortForward(request *restful.Request, response *restful.Response) {
podNamespace, podID, uid := getPodCoordinates(request)
pod, ok := s.host.GetPodByName(podNamespace, podID)
@ -620,80 +630,280 @@ func (s *Server) getPortForward(request *restful.Request, response *restful.Resp
return
}
podName := kubecontainer.GetPodFullName(pod)
ServePortForward(response.ResponseWriter, request.Request, s.host, podName, uid, s.host.StreamingConnectionIdleTimeout(), defaultStreamCreationTimeout)
}
// ServePortForward handles a port forwarding request. A single request is
// kept alive as long as the client is still alive and the connection has not
// been timed out due to idleness. This function handles multiple forwarded
// connections; i.e., multiple `curl http://localhost:8888/` requests will be
// handled by a single invocation of ServePortForward.
func ServePortForward(w http.ResponseWriter, req *http.Request, portForwarder PortForwarder, podName string, uid types.UID, idleTimeout time.Duration, streamCreationTimeout time.Duration) {
streamChan := make(chan httpstream.Stream, 1)
glog.V(5).Infof("Upgrading port forward response")
upgrader := spdy.NewResponseUpgrader()
conn := upgrader.UpgradeResponse(response.ResponseWriter, request.Request, func(stream httpstream.Stream) error {
portString := stream.Headers().Get(api.PortHeader)
port, err := strconv.ParseUint(portString, 10, 16)
if err != nil {
return fmt.Errorf("Unable to parse '%s' as a port: %v", portString, err)
}
if port < 1 {
return fmt.Errorf("Port '%d' must be greater than 0", port)
}
streamChan <- stream
return nil
})
conn := upgrader.UpgradeResponse(w, req, portForwardStreamReceived(streamChan))
if conn == nil {
return
}
defer conn.Close()
conn.SetIdleTimeout(s.host.StreamingConnectionIdleTimeout())
var dataStreamLock sync.Mutex
dataStreamChans := make(map[string]chan httpstream.Stream)
glog.V(5).Infof("(conn=%p) setting port forwarding streaming connection idle timeout to %v", conn, idleTimeout)
conn.SetIdleTimeout(idleTimeout)
h := &portForwardStreamHandler{
conn: conn,
streamChan: streamChan,
streamPairs: make(map[string]*portForwardStreamPair),
streamCreationTimeout: streamCreationTimeout,
pod: podName,
uid: uid,
forwarder: portForwarder,
}
h.run()
}
// portForwardStreamReceived is the httpstream.NewStreamHandler for port
// forward streams. It checks each stream's port and stream type headers,
// rejecting any streams that with missing or invalid values. Each valid
// stream is sent to the streams channel.
func portForwardStreamReceived(streams chan httpstream.Stream) func(httpstream.Stream) error {
return func(stream httpstream.Stream) error {
// make sure it has a valid port header
portString := stream.Headers().Get(api.PortHeader)
if len(portString) == 0 {
return fmt.Errorf("%q header is required", api.PortHeader)
}
port, err := strconv.ParseUint(portString, 10, 16)
if err != nil {
return fmt.Errorf("unable to parse %q as a port: %v", portString, err)
}
if port < 1 {
return fmt.Errorf("port %q must be > 0", portString)
}
// make sure it has a valid stream type header
streamType := stream.Headers().Get(api.StreamType)
if len(streamType) == 0 {
return fmt.Errorf("%q header is required", api.StreamType)
}
if streamType != api.StreamTypeError && streamType != api.StreamTypeData {
return fmt.Errorf("invalid stream type %q", streamType)
}
streams <- stream
return nil
}
}
// portForwardStreamHandler is capable of processing multiple port forward
// requests over a single httpstream.Connection.
type portForwardStreamHandler struct {
conn httpstream.Connection
streamChan chan httpstream.Stream
streamPairsLock sync.RWMutex
streamPairs map[string]*portForwardStreamPair
streamCreationTimeout time.Duration
pod string
uid types.UID
forwarder PortForwarder
}
// getStreamPair returns a portForwardStreamPair for requestID. This creates a
// new pair if one does not yet exist for the requestID. The returned bool is
// true if the pair was created.
func (h *portForwardStreamHandler) getStreamPair(requestID string) (*portForwardStreamPair, bool) {
h.streamPairsLock.Lock()
defer h.streamPairsLock.Unlock()
if p, ok := h.streamPairs[requestID]; ok {
glog.V(5).Infof("(conn=%p, request=%s) found existing stream pair", h.conn, requestID)
return p, false
}
glog.V(5).Infof("(conn=%p, request=%s) creating new stream pair", h.conn, requestID)
p := newPortForwardPair(requestID)
h.streamPairs[requestID] = p
return p, true
}
// monitorStreamPair waits for the pair to receive both its error and data
// streams, or for the timeout to expire (whichever happens first), and then
// removes the pair.
func (h *portForwardStreamHandler) monitorStreamPair(p *portForwardStreamPair, timeout <-chan time.Time) {
select {
case <-timeout:
err := fmt.Errorf("(conn=%p, request=%s) timed out waiting for streams", h.conn, p.requestID)
util.HandleError(err)
p.printError(err.Error())
case <-p.complete:
glog.V(5).Infof("(conn=%p, request=%s) successfully received error and data streams", h.conn, p.requestID)
}
h.removeStreamPair(p.requestID)
}
// hasStreamPair returns a bool indicating if a stream pair for requestID
// exists.
func (h *portForwardStreamHandler) hasStreamPair(requestID string) bool {
h.streamPairsLock.RLock()
defer h.streamPairsLock.RUnlock()
_, ok := h.streamPairs[requestID]
return ok
}
// removeStreamPair removes the stream pair identified by requestID from streamPairs.
func (h *portForwardStreamHandler) removeStreamPair(requestID string) {
h.streamPairsLock.Lock()
defer h.streamPairsLock.Unlock()
delete(h.streamPairs, requestID)
}
// requestID returns the request id for stream.
func (h *portForwardStreamHandler) requestID(stream httpstream.Stream) string {
requestID := stream.Headers().Get(api.PortForwardRequestIDHeader)
if len(requestID) == 0 {
glog.V(5).Infof("(conn=%p) stream received without %s header", h.conn, api.PortForwardRequestIDHeader)
// If we get here, it's because the connection came from an older client
// that isn't generating the request id header
// (https://github.com/kubernetes/kubernetes/blob/843134885e7e0b360eb5441e85b1410a8b1a7a0c/pkg/client/unversioned/portforward/portforward.go#L258-L287)
//
// This is a best-effort attempt at supporting older clients.
//
// When there aren't concurrent new forwarded connections, each connection
// will have a pair of streams (data, error), and the stream IDs will be
// consecutive odd numbers, e.g. 1 and 3 for the first connection. Convert
// the stream ID into a pseudo-request id by taking the stream type and
// using id = stream.Identifier() when the stream type is error,
// and id = stream.Identifier() - 2 when it's data.
//
// NOTE: this only works when there are not concurrent new streams from
// multiple forwarded connections; it's a best-effort attempt at supporting
// old clients that don't generate request ids. If there are concurrent
// new connections, it's possible that 1 connection gets streams whose IDs
// are not consecutive (e.g. 5 and 9 instead of 5 and 7).
streamType := stream.Headers().Get(api.StreamType)
switch streamType {
case api.StreamTypeError:
requestID = strconv.Itoa(int(stream.Identifier()))
case api.StreamTypeData:
requestID = strconv.Itoa(int(stream.Identifier()) - 2)
}
glog.V(5).Infof("(conn=%p) automatically assigning request ID=%q from stream type=%s, stream ID=%d", h.conn, requestID, streamType, stream.Identifier())
}
return requestID
}
// run is the main loop for the portForwardStreamHandler. It processes new
// streams, invoking portForward for each complete stream pair. The loop exits
// when the httpstream.Connection is closed.
func (h *portForwardStreamHandler) run() {
glog.V(5).Infof("(conn=%p) waiting for port forward streams", h.conn)
Loop:
for {
select {
case <-conn.CloseChan():
case <-h.conn.CloseChan():
glog.V(5).Infof("(conn=%p) upgraded connection closed", h.conn)
break Loop
case stream := <-streamChan:
case stream := <-h.streamChan:
requestID := h.requestID(stream)
streamType := stream.Headers().Get(api.StreamType)
port := stream.Headers().Get(api.PortHeader)
dataStreamLock.Lock()
switch streamType {
case "error":
ch := make(chan httpstream.Stream)
dataStreamChans[port] = ch
go waitForPortForwardDataStreamAndRun(kubecontainer.GetPodFullName(pod), uid, stream, ch, s.host)
case "data":
ch, ok := dataStreamChans[port]
if ok {
ch <- stream
delete(dataStreamChans, port)
} else {
glog.Errorf("Unable to locate data stream channel for port %s", port)
}
default:
glog.Errorf("streamType header must be 'error' or 'data', got: '%s'", streamType)
stream.Reset()
glog.V(5).Infof("(conn=%p, request=%s) received new stream of type %s", h.conn, requestID, streamType)
p, created := h.getStreamPair(requestID)
if created {
go h.monitorStreamPair(p, time.After(h.streamCreationTimeout))
}
if complete, err := p.add(stream); err != nil {
msg := fmt.Sprintf("error processing stream for request %s: %v", requestID, err)
util.HandleError(errors.New(msg))
p.printError(msg)
} else if complete {
go h.portForward(p)
}
dataStreamLock.Unlock()
}
}
}
func waitForPortForwardDataStreamAndRun(pod string, uid types.UID, errorStream httpstream.Stream, dataStreamChan chan httpstream.Stream, host HostInterface) {
defer errorStream.Reset()
// portForward invokes the portForwardStreamHandler's forwarder.PortForward
// function for the given stream pair.
func (h *portForwardStreamHandler) portForward(p *portForwardStreamPair) {
defer p.dataStream.Close()
defer p.errorStream.Close()
var dataStream httpstream.Stream
portString := p.dataStream.Headers().Get(api.PortHeader)
port, _ := strconv.ParseUint(portString, 10, 16)
select {
case dataStream = <-dataStreamChan:
case <-time.After(streamCreationTimeout):
errorStream.Write([]byte("Timed out waiting for data stream"))
//TODO delete from dataStreamChans[port]
return
glog.V(5).Infof("(conn=%p, request=%s) invoking forwarder.PortForward for port %s", h.conn, p.requestID, portString)
err := h.forwarder.PortForward(h.pod, h.uid, uint16(port), p.dataStream)
glog.V(5).Infof("(conn=%p, request=%s) done invoking forwarder.PortForward for port %s", h.conn, p.requestID, portString)
if err != nil {
msg := fmt.Errorf("error forwarding port %d to pod %s, uid %v: %v", port, h.pod, h.uid, err)
util.HandleError(msg)
fmt.Fprint(p.errorStream, msg.Error())
}
}
// portForwardStreamPair represents the error and data streams for a port
// forwarding request.
type portForwardStreamPair struct {
lock sync.RWMutex
requestID string
dataStream httpstream.Stream
errorStream httpstream.Stream
complete chan struct{}
}
// newPortForwardPair creates a new portForwardStreamPair.
func newPortForwardPair(requestID string) *portForwardStreamPair {
return &portForwardStreamPair{
requestID: requestID,
complete: make(chan struct{}),
}
}
// add adds the stream to the portForwardStreamPair. If the pair already
// contains a stream for the new stream's type, an error is returned. add
// returns true if both the data and error streams for this pair have been
// received.
func (p *portForwardStreamPair) add(stream httpstream.Stream) (bool, error) {
p.lock.Lock()
defer p.lock.Unlock()
switch stream.Headers().Get(api.StreamType) {
case api.StreamTypeError:
if p.errorStream != nil {
return false, errors.New("error stream already assigned")
}
p.errorStream = stream
case api.StreamTypeData:
if p.dataStream != nil {
return false, errors.New("data stream already assigned")
}
p.dataStream = stream
}
portString := dataStream.Headers().Get(api.PortHeader)
port, _ := strconv.ParseUint(portString, 10, 16)
err := host.PortForward(pod, uid, uint16(port), dataStream)
if err != nil {
msg := fmt.Errorf("Error forwarding port %d to pod %s, uid %v: %v", port, pod, uid, err)
glog.Error(msg)
errorStream.Write([]byte(msg.Error()))
complete := p.errorStream != nil && p.dataStream != nil
if complete {
close(p.complete)
}
return complete, nil
}
// printError writes s to p.errorStream if p.errorStream has been set.
func (p *portForwardStreamPair) printError(s string) {
p.lock.RLock()
defer p.lock.RUnlock()
if p.errorStream != nil {
fmt.Fprint(p.errorStream, s)
}
}

View File

@ -1426,3 +1426,221 @@ func TestServePortForward(t *testing.T) {
<-portForwardFuncDone
}
}
type fakeHttpStream struct {
headers http.Header
id uint32
}
func newFakeHttpStream() *fakeHttpStream {
return &fakeHttpStream{
headers: make(http.Header),
}
}
var _ httpstream.Stream = &fakeHttpStream{}
func (s *fakeHttpStream) Read(data []byte) (int, error) {
return 0, nil
}
func (s *fakeHttpStream) Write(data []byte) (int, error) {
return 0, nil
}
func (s *fakeHttpStream) Close() error {
return nil
}
func (s *fakeHttpStream) Reset() error {
return nil
}
func (s *fakeHttpStream) Headers() http.Header {
return s.headers
}
func (s *fakeHttpStream) Identifier() uint32 {
return s.id
}
func TestPortForwardStreamReceived(t *testing.T) {
tests := map[string]struct {
port string
streamType string
expectedError string
}{
"missing port": {
expectedError: `"port" header is required`,
},
"unable to parse port": {
port: "abc",
expectedError: `unable to parse "abc" as a port: strconv.ParseUint: parsing "abc": invalid syntax`,
},
"negative port": {
port: "-1",
expectedError: `unable to parse "-1" as a port: strconv.ParseUint: parsing "-1": invalid syntax`,
},
"missing stream type": {
port: "80",
expectedError: `"streamType" header is required`,
},
"valid port with error stream": {
port: "80",
streamType: "error",
},
"valid port with data stream": {
port: "80",
streamType: "data",
},
"invalid stream type": {
port: "80",
streamType: "foo",
expectedError: `invalid stream type "foo"`,
},
}
for name, test := range tests {
streams := make(chan httpstream.Stream, 1)
f := portForwardStreamReceived(streams)
stream := newFakeHttpStream()
if len(test.port) > 0 {
stream.headers.Set("port", test.port)
}
if len(test.streamType) > 0 {
stream.headers.Set("streamType", test.streamType)
}
err := f(stream)
if len(test.expectedError) > 0 {
if err == nil {
t.Errorf("%s: expected err=%q, but it was nil", name, test.expectedError)
}
if e, a := test.expectedError, err.Error(); e != a {
t.Errorf("%s: expected err=%q, got %q", name, e, a)
}
continue
}
if err != nil {
t.Errorf("%s: unexpected error %v", name, err)
continue
}
if s := <-streams; s != stream {
t.Errorf("%s: expected stream %#v, got %#v", name, stream, s)
}
}
}
func TestGetStreamPair(t *testing.T) {
timeout := make(chan time.Time)
h := &portForwardStreamHandler{
streamPairs: make(map[string]*portForwardStreamPair),
}
// test adding a new entry
p, created := h.getStreamPair("1")
if p == nil {
t.Fatalf("unexpected nil pair")
}
if !created {
t.Fatal("expected created=true")
}
if p.dataStream != nil {
t.Errorf("unexpected non-nil data stream")
}
if p.errorStream != nil {
t.Errorf("unexpected non-nil error stream")
}
// start the monitor for this pair
monitorDone := make(chan struct{})
go func() {
h.monitorStreamPair(p, timeout)
close(monitorDone)
}()
if !h.hasStreamPair("1") {
t.Fatal("This should still be true")
}
// make sure we can retrieve an existing entry
p2, created := h.getStreamPair("1")
if created {
t.Fatal("expected created=false")
}
if p != p2 {
t.Fatalf("retrieving an existing pair: expected %#v, got %#v", p, p2)
}
// removed via complete
dataStream := newFakeHttpStream()
dataStream.headers.Set(api.StreamType, api.StreamTypeData)
complete, err := p.add(dataStream)
if err != nil {
t.Fatalf("unexpected error adding data stream to pair: %v", err)
}
if complete {
t.Fatalf("unexpected complete")
}
errorStream := newFakeHttpStream()
errorStream.headers.Set(api.StreamType, api.StreamTypeError)
complete, err = p.add(errorStream)
if err != nil {
t.Fatalf("unexpected error adding error stream to pair: %v", err)
}
if !complete {
t.Fatal("unexpected incomplete")
}
// make sure monitorStreamPair completed
<-monitorDone
// make sure the pair was removed
if h.hasStreamPair("1") {
t.Fatal("expected removal of pair after both data and error streams received")
}
// removed via timeout
p, created = h.getStreamPair("2")
if !created {
t.Fatal("expected created=true")
}
if p == nil {
t.Fatal("expected p not to be nil")
}
monitorDone = make(chan struct{})
go func() {
h.monitorStreamPair(p, timeout)
close(monitorDone)
}()
// cause the timeout
close(timeout)
// make sure monitorStreamPair completed
<-monitorDone
if h.hasStreamPair("2") {
t.Fatal("expected stream pair to be removed")
}
}
func TestRequestID(t *testing.T) {
h := &portForwardStreamHandler{}
s := newFakeHttpStream()
s.headers.Set(api.StreamType, api.StreamTypeError)
s.id = 1
if e, a := "1", h.requestID(s); e != a {
t.Errorf("expected %q, got %q", e, a)
}
s.headers.Set(api.StreamType, api.StreamTypeData)
s.id = 3
if e, a := "1", h.requestID(s); e != a {
t.Errorf("expected %q, got %q", e, a)
}
s.id = 7
s.headers.Set(api.PortForwardRequestIDHeader, "2")
if e, a := "2", h.requestID(s); e != a {
t.Errorf("expected %q, got %q", e, a)
}
}

View File

@ -78,6 +78,8 @@ type Stream interface {
Reset() error
// Headers returns the headers used to create the stream.
Headers() http.Header
// Identifier returns the stream's ID.
Identifier() uint32
}
// IsUpgradeRequest returns true if the given request is a connection upgrade request

View File

@ -60,10 +60,7 @@ const (
simplePodPort = 80
)
var (
portForwardRegexp = regexp.MustCompile("Forwarding from 127.0.0.1:([0-9]+) -> 80")
proxyRegexp = regexp.MustCompile("Starting to serve on 127.0.0.1:([0-9]+)")
)
var proxyRegexp = regexp.MustCompile("Starting to serve on 127.0.0.1:([0-9]+)")
var _ = Describe("Kubectl client", func() {
defer GinkgoRecover()
@ -200,32 +197,11 @@ var _ = Describe("Kubectl client", func() {
It("should support port-forward", func() {
By("forwarding the container port to a local port")
cmd := kubectlCmd("port-forward", fmt.Sprintf("--namespace=%v", ns), simplePodName, fmt.Sprintf(":%d", simplePodPort))
cmd, listenPort := runPortForward(ns, simplePodName, simplePodPort)
defer tryKill(cmd)
// This is somewhat ugly but is the only way to retrieve the port that was picked
// by the port-forward command. We don't want to hard code the port as we have no
// way of guaranteeing we can pick one that isn't in use, particularly on Jenkins.
Logf("starting port-forward command and streaming output")
stdout, stderr, err := startCmdAndStreamOutput(cmd)
if err != nil {
Failf("Failed to start port-forward command: %v", err)
}
defer stdout.Close()
defer stderr.Close()
buf := make([]byte, 128)
var n int
Logf("reading from `kubectl port-forward` command's stderr")
if n, err = stderr.Read(buf); err != nil {
Failf("Failed to read from kubectl port-forward stderr: %v", err)
}
portForwardOutput := string(buf[:n])
match := portForwardRegexp.FindStringSubmatch(portForwardOutput)
if len(match) != 2 {
Failf("Failed to parse kubectl port-forward output: %s", portForwardOutput)
}
By("curling local port output")
localAddr := fmt.Sprintf("http://localhost:%s", match[1])
localAddr := fmt.Sprintf("http://localhost:%d", listenPort)
body, err := curl(localAddr)
Logf("got: %s", body)
if err != nil {

234
test/e2e/portforward.go Normal file
View File

@ -0,0 +1,234 @@
/*
Copyright 2015 The Kubernetes Authors All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package e2e
import (
"fmt"
"io/ioutil"
"net"
"os/exec"
"regexp"
"strconv"
"strings"
"k8s.io/kubernetes/pkg/api"
. "github.com/onsi/ginkgo"
)
const (
podName = "pfpod"
)
// TODO support other ports besides 80
var portForwardRegexp = regexp.MustCompile("Forwarding from 127.0.0.1:([0-9]+) -> 80")
func pfPod(expectedClientData, chunks, chunkSize, chunkIntervalMillis string) *api.Pod {
return &api.Pod{
ObjectMeta: api.ObjectMeta{
Name: podName,
Labels: map[string]string{"name": podName},
},
Spec: api.PodSpec{
Containers: []api.Container{
{
Name: "portforwardtester",
Image: "gcr.io/google_containers/portforwardtester:1.0",
Env: []api.EnvVar{
{
Name: "BIND_PORT",
Value: "80",
},
{
Name: "EXPECTED_CLIENT_DATA",
Value: expectedClientData,
},
{
Name: "CHUNKS",
Value: chunks,
},
{
Name: "CHUNK_SIZE",
Value: chunkSize,
},
{
Name: "CHUNK_INTERVAL",
Value: chunkIntervalMillis,
},
},
},
},
RestartPolicy: api.RestartPolicyNever,
},
}
}
func runPortForward(ns, podName string, port int) (*exec.Cmd, int) {
cmd := kubectlCmd("port-forward", fmt.Sprintf("--namespace=%v", ns), podName, fmt.Sprintf(":%d", port))
// This is somewhat ugly but is the only way to retrieve the port that was picked
// by the port-forward command. We don't want to hard code the port as we have no
// way of guaranteeing we can pick one that isn't in use, particularly on Jenkins.
Logf("starting port-forward command and streaming output")
stdout, stderr, err := startCmdAndStreamOutput(cmd)
if err != nil {
Failf("Failed to start port-forward command: %v", err)
}
defer stdout.Close()
defer stderr.Close()
buf := make([]byte, 128)
var n int
Logf("reading from `kubectl port-forward` command's stderr")
if n, err = stderr.Read(buf); err != nil {
Failf("Failed to read from kubectl port-forward stderr: %v", err)
}
portForwardOutput := string(buf[:n])
match := portForwardRegexp.FindStringSubmatch(portForwardOutput)
if len(match) != 2 {
Failf("Failed to parse kubectl port-forward output: %s", portForwardOutput)
}
listenPort, err := strconv.Atoi(match[1])
if err != nil {
Failf("Error converting %s to an int: %v", match[1], err)
}
return cmd, listenPort
}
var _ = Describe("Port forwarding", func() {
framework := NewFramework("port-forwarding")
Describe("With a server that expects a client request", func() {
It("should support a client that connects, sends no data, and disconnects", func() {
By("creating the target pod")
pod := pfPod("abc", "1", "1", "1")
framework.Client.Pods(framework.Namespace.Name).Create(pod)
framework.WaitForPodRunning(pod.Name)
By("Running 'kubectl port-forward'")
cmd, listenPort := runPortForward(framework.Namespace.Name, pod.Name, 80)
defer tryKill(cmd)
By("Dialing the local port")
conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", listenPort))
if err != nil {
Failf("Couldn't connect to port %d: %v", listenPort, err)
}
By("Closing the connection to the local port")
conn.Close()
logOutput := runKubectl("logs", fmt.Sprintf("--namespace=%v", framework.Namespace.Name), "-f", podName)
verifyLogMessage(logOutput, "Accepted client connection")
verifyLogMessage(logOutput, "Expected to read 3 bytes from client, but got 0 instead")
})
It("should support a client that connects, sends data, and disconnects", func() {
By("creating the target pod")
pod := pfPod("abc", "10", "10", "100")
framework.Client.Pods(framework.Namespace.Name).Create(pod)
framework.WaitForPodRunning(pod.Name)
By("Running 'kubectl port-forward'")
cmd, listenPort := runPortForward(framework.Namespace.Name, pod.Name, 80)
defer tryKill(cmd)
By("Dialing the local port")
addr, err := net.ResolveTCPAddr("tcp", fmt.Sprintf("127.0.0.1:%d", listenPort))
if err != nil {
Failf("Error resolving tcp addr: %v", err)
}
conn, err := net.DialTCP("tcp", nil, addr)
if err != nil {
Failf("Couldn't connect to port %d: %v", listenPort, err)
}
defer func() {
By("Closing the connection to the local port")
conn.Close()
}()
By("Sending the expected data to the local port")
fmt.Fprint(conn, "abc")
By("Closing the write half of the client's connection")
conn.CloseWrite()
By("Reading data from the local port")
fromServer, err := ioutil.ReadAll(conn)
if err != nil {
Failf("Unexpected error reading data from the server: %v", err)
}
if e, a := strings.Repeat("x", 100), string(fromServer); e != a {
Failf("Expected %q from server, got %q", e, a)
}
logOutput := runKubectl("logs", fmt.Sprintf("--namespace=%v", framework.Namespace.Name), "-f", podName)
verifyLogMessage(logOutput, "^Accepted client connection$")
verifyLogMessage(logOutput, "^Received expected client data$")
verifyLogMessage(logOutput, "^Done$")
})
})
Describe("With a server that expects no client request", func() {
It("should support a client that connects, sends no data, and disconnects", func() {
By("creating the target pod")
pod := pfPod("", "10", "10", "100")
framework.Client.Pods(framework.Namespace.Name).Create(pod)
framework.WaitForPodRunning(pod.Name)
By("Running 'kubectl port-forward'")
cmd, listenPort := runPortForward(framework.Namespace.Name, pod.Name, 80)
defer tryKill(cmd)
By("Dialing the local port")
conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", listenPort))
if err != nil {
Failf("Couldn't connect to port %d: %v", listenPort, err)
}
defer func() {
By("Closing the connection to the local port")
conn.Close()
}()
By("Reading data from the local port")
fromServer, err := ioutil.ReadAll(conn)
if err != nil {
Failf("Unexpected error reading data from the server: %v", err)
}
if e, a := strings.Repeat("x", 100), string(fromServer); e != a {
Failf("Expected %q from server, got %q", e, a)
}
logOutput := runKubectl("logs", fmt.Sprintf("--namespace=%v", framework.Namespace.Name), "-f", podName)
verifyLogMessage(logOutput, "Accepted client connection")
verifyLogMessage(logOutput, "Done")
})
})
})
func verifyLogMessage(log, expected string) {
re := regexp.MustCompile(expected)
lines := strings.Split(log, "\n")
for i := range lines {
if re.MatchString(lines[i]) {
return
}
}
Failf("Missing %q from log: %s", expected, log)
}