diff --git a/pkg/client/unversioned/remotecommand/remotecommand.go b/pkg/client/unversioned/remotecommand/remotecommand.go index 277532e676..e63af79aa0 100644 --- a/pkg/client/unversioned/remotecommand/remotecommand.go +++ b/pkg/client/unversioned/remotecommand/remotecommand.go @@ -21,12 +21,13 @@ import ( "io" "io/ioutil" "net/http" + "sync" - "github.com/golang/glog" "k8s.io/kubernetes/pkg/api" client "k8s.io/kubernetes/pkg/client/unversioned" "k8s.io/kubernetes/pkg/conversion/queryparams" "k8s.io/kubernetes/pkg/runtime" + "k8s.io/kubernetes/pkg/util" "k8s.io/kubernetes/pkg/util/httpstream" "k8s.io/kubernetes/pkg/util/httpstream/spdy" ) @@ -155,90 +156,110 @@ func (e *Streamer) doStream() error { } defer conn.Close() - doneChan := make(chan struct{}, 2) - errorChan := make(chan error) - - cp := func(s string, dst io.Writer, src io.Reader) { - glog.V(4).Infof("Copying %s", s) - defer glog.V(4).Infof("Done copying %s", s) - if _, err := io.Copy(dst, src); err != nil && err != io.EOF { - glog.Errorf("Error copying %s: %v", s, err) - } - if s == api.StreamTypeStdout || s == api.StreamTypeStderr { - doneChan <- struct{}{} - } - } - headers := http.Header{} + + // set up error stream + errorChan := make(chan error) headers.Set(api.StreamType, api.StreamTypeError) errorStream, err := conn.CreateStream(headers) if err != nil { return err } + go func() { message, err := ioutil.ReadAll(errorStream) - if err != nil && err != io.EOF { - errorChan <- fmt.Errorf("Error reading from error stream: %s", err) - return - } - if len(message) > 0 { - errorChan <- fmt.Errorf("Error executing remote command: %s", message) - return + switch { + case err != nil && err != io.EOF: + errorChan <- fmt.Errorf("error reading from error stream: %s", err) + case len(message) > 0: + errorChan <- fmt.Errorf("error executing remote command: %s", message) + default: + errorChan <- nil } + close(errorChan) }() - defer errorStream.Reset() + var wg sync.WaitGroup + var once sync.Once + + // set up stdin stream if e.stdin != nil { headers.Set(api.StreamType, api.StreamTypeStdin) remoteStdin, err := conn.CreateStream(headers) if err != nil { return err } - defer remoteStdin.Reset() - // TODO this goroutine will never exit cleanly (the io.Copy never unblocks) - // because stdin is not closed until the process exits. If we try to call - // stdin.Close(), it returns no error but doesn't unblock the copy. It will - // exit when the process exits, instead. - go cp(api.StreamTypeStdin, remoteStdin, e.stdin) + + // copy from client's stdin to container's stdin + go func() { + // if e.stdin is noninteractive, e.g. `echo abc | kubectl exec -i -- cat`, make sure + // we close remoteStdin as soon as the copy from e.stdin to remoteStdin finishes. Otherwise + // the executed command will remain running. + defer once.Do(func() { remoteStdin.Close() }) + + if _, err := io.Copy(remoteStdin, e.stdin); err != nil { + util.HandleError(err) + } + }() + + // read from remoteStdin until the stream is closed. this is essential to + // be able to exit interactive sessions cleanly and not leak goroutines or + // hang the client's terminal. + // + // go-dockerclient's current hijack implementation + // (https://github.com/fsouza/go-dockerclient/blob/89f3d56d93788dfe85f864a44f85d9738fca0670/client.go#L564) + // waits for all three streams (stdin/stdout/stderr) to finish copying + // before returning. When hijack finishes copying stdout/stderr, it calls + // Close() on its side of remoteStdin, which allows this copy to complete. + // When that happens, we must Close() on our side of remoteStdin, to + // allow the copy in hijack to complete, and hijack to return. + go func() { + defer once.Do(func() { remoteStdin.Close() }) + // this "copy" doesn't actually read anything - it's just here to wait for + // the server to close remoteStdin. + if _, err := io.Copy(ioutil.Discard, remoteStdin); err != nil { + util.HandleError(err) + } + }() } - waitCount := 0 - completedStreams := 0 - + // set up stdout stream if e.stdout != nil { - waitCount++ headers.Set(api.StreamType, api.StreamTypeStdout) remoteStdout, err := conn.CreateStream(headers) if err != nil { return err } - defer remoteStdout.Reset() - go cp(api.StreamTypeStdout, e.stdout, remoteStdout) + + wg.Add(1) + go func() { + defer wg.Done() + if _, err := io.Copy(e.stdout, remoteStdout); err != nil { + util.HandleError(err) + } + }() } + // set up stderr stream if e.stderr != nil && !e.tty { - waitCount++ headers.Set(api.StreamType, api.StreamTypeStderr) remoteStderr, err := conn.CreateStream(headers) if err != nil { return err } - defer remoteStderr.Reset() - go cp(api.StreamTypeStderr, e.stderr, remoteStderr) - } -Loop: - for { - select { - case <-doneChan: - completedStreams++ - if completedStreams == waitCount { - break Loop + wg.Add(1) + go func() { + defer wg.Done() + if _, err := io.Copy(e.stderr, remoteStderr); err != nil { + util.HandleError(err) } - case err := <-errorChan: - return err - } + }() } - return nil + // we're waiting for stdout/stderr to finish copying + wg.Wait() + + // waits for errorStream to finish reading with an error or nil + return <-errorChan } diff --git a/pkg/client/unversioned/remotecommand/remotecommand_test.go b/pkg/client/unversioned/remotecommand/remotecommand_test.go index a07a27a29a..9870ee259b 100644 --- a/pkg/client/unversioned/remotecommand/remotecommand_test.go +++ b/pkg/client/unversioned/remotecommand/remotecommand_test.go @@ -19,7 +19,7 @@ package remotecommand import ( "bytes" "fmt" - "io/ioutil" + "io" "net/http" "net/http/httptest" "net/url" @@ -32,7 +32,7 @@ import ( "k8s.io/kubernetes/pkg/util/httpstream/spdy" ) -func fakeExecServer(t *testing.T, i int, stdinData, stdoutData, stderrData, errorData string, tty bool) http.HandlerFunc { +func fakeExecServer(t *testing.T, i int, stdinData, stdoutData, stderrData, errorData string, tty bool, messageCount int) http.HandlerFunc { // error + stdin + stdout expectedStreams := 3 if !tty { @@ -70,7 +70,6 @@ func fakeExecServer(t *testing.T, i int, stdinData, stdoutData, stderrData, erro receivedStreams++ case api.StreamTypeStdin: stdinStream = stream - stdinStream.Close() receivedStreams++ case api.StreamTypeStdout: stdoutStream = stream @@ -82,8 +81,6 @@ func fakeExecServer(t *testing.T, i int, stdinData, stdoutData, stderrData, erro t.Errorf("%d: unexpected stream type: %q", i, streamType) } - defer stream.Reset() - if receivedStreams == expectedStreams { break WaitForStreams } @@ -91,37 +88,67 @@ func fakeExecServer(t *testing.T, i int, stdinData, stdoutData, stderrData, erro } if len(errorData) > 0 { - fmt.Fprint(errorStream, errorData) + n, err := fmt.Fprint(errorStream, errorData) + if err != nil { + t.Errorf("%d: error writing to errorStream: %v", i, err) + } + if e, a := len(errorData), n; e != a { + t.Errorf("%d: expected to write %d bytes to errorStream, but only wrote %d", i, e, a) + } errorStream.Close() } if len(stdoutData) > 0 { - fmt.Fprint(stdoutStream, stdoutData) + for j := 0; j < messageCount; j++ { + n, err := fmt.Fprint(stdoutStream, stdoutData) + if err != nil { + t.Errorf("%d: error writing to stdoutStream: %v", i, err) + } + if e, a := len(stdoutData), n; e != a { + t.Errorf("%d: expected to write %d bytes to stdoutStream, but only wrote %d", i, e, a) + } + } stdoutStream.Close() } if len(stderrData) > 0 { - fmt.Fprint(stderrStream, stderrData) + for j := 0; j < messageCount; j++ { + n, err := fmt.Fprint(stderrStream, stderrData) + if err != nil { + t.Errorf("%d: error writing to stderrStream: %v", i, err) + } + if e, a := len(stderrData), n; e != a { + t.Errorf("%d: expected to write %d bytes to stderrStream, but only wrote %d", i, e, a) + } + } stderrStream.Close() } if len(stdinData) > 0 { - data, err := ioutil.ReadAll(stdinStream) - if err != nil { - t.Errorf("%d: error reading stdin stream: %v", i, err) - } - if e, a := stdinData, string(data); e != a { - t.Errorf("%d: stdin: expected %q, got %q", i, e, a) + data := make([]byte, len(stdinData)) + for j := 0; j < messageCount; j++ { + n, err := io.ReadFull(stdinStream, data) + if err != nil { + t.Errorf("%d: error reading stdin stream: %v", i, err) + } + if e, a := len(stdinData), n; e != a { + t.Errorf("%d: expected to read %d bytes from stdinStream, but only read %d", i, e, a) + } + if e, a := stdinData, string(data); e != a { + t.Errorf("%d: stdin: expected %q, got %q", i, e, a) + } } + stdinStream.Close() } }) } func TestRequestExecuteRemoteCommand(t *testing.T) { testCases := []struct { - Stdin string - Stdout string - Stderr string - Error string - Tty bool + Stdin string + Stdout string + Stderr string + Error string + Tty bool + MessageCount int }{ { Error: "bail", @@ -130,6 +157,15 @@ func TestRequestExecuteRemoteCommand(t *testing.T) { Stdin: "a", Stdout: "b", Stderr: "c", + // TODO bump this to a larger number such as 100 once + // https://github.com/docker/spdystream/issues/55 is fixed and the Godep + // is bumped. Sending multiple messages over stdin/stdout/stderr results + // in more frames being spread across multiple spdystream frame workers. + // This makes it more likely that the spdystream bug will be encountered, + // where streams are closed as soon as a goaway frame is received, and + // any pending frames that haven't been processed yet may not be + // delivered (it's a race). + MessageCount: 1, }, { Stdin: "a", @@ -142,7 +178,7 @@ func TestRequestExecuteRemoteCommand(t *testing.T) { localOut := &bytes.Buffer{} localErr := &bytes.Buffer{} - server := httptest.NewServer(fakeExecServer(t, i, testCase.Stdin, testCase.Stdout, testCase.Stderr, testCase.Error, testCase.Tty)) + server := httptest.NewServer(fakeExecServer(t, i, testCase.Stdin, testCase.Stdout, testCase.Stderr, testCase.Error, testCase.Tty, testCase.MessageCount)) url, _ := url.ParseRequestURI(server.URL) c := client.NewRESTClient(url, "x", nil, -1, -1) @@ -151,8 +187,7 @@ func TestRequestExecuteRemoteCommand(t *testing.T) { conf := &client.Config{ Host: server.URL, } - e := New(req, conf, []string{"ls", "/"}, strings.NewReader(testCase.Stdin), localOut, localErr, testCase.Tty) - //e.upgrader = testCase.Upgrader + e := New(req, conf, []string{"ls", "/"}, strings.NewReader(strings.Repeat(testCase.Stdin, testCase.MessageCount)), localOut, localErr, testCase.Tty) err := e.Execute() hasErr := err != nil @@ -176,13 +211,13 @@ func TestRequestExecuteRemoteCommand(t *testing.T) { } if len(testCase.Stdout) > 0 { - if e, a := testCase.Stdout, localOut; e != a.String() { + if e, a := strings.Repeat(testCase.Stdout, testCase.MessageCount), localOut; e != a.String() { t.Errorf("%d: expected stdout data '%s', got '%s'", i, e, a) } } if testCase.Stderr != "" { - if e, a := testCase.Stderr, localErr; e != a.String() { + if e, a := strings.Repeat(testCase.Stderr, testCase.MessageCount), localErr; e != a.String() { t.Errorf("%d: expected stderr data '%s', got '%s'", i, e, a) } } @@ -219,7 +254,7 @@ func TestRequestAttachRemoteCommand(t *testing.T) { localOut := &bytes.Buffer{} localErr := &bytes.Buffer{} - server := httptest.NewServer(fakeExecServer(t, i, testCase.Stdin, testCase.Stdout, testCase.Stderr, testCase.Error, testCase.Tty)) + server := httptest.NewServer(fakeExecServer(t, i, testCase.Stdin, testCase.Stdout, testCase.Stderr, testCase.Error, testCase.Tty, 1)) url, _ := url.ParseRequestURI(server.URL) c := client.NewRESTClient(url, "x", nil, -1, -1) @@ -229,7 +264,6 @@ func TestRequestAttachRemoteCommand(t *testing.T) { Host: server.URL, } e := NewAttach(req, conf, strings.NewReader(testCase.Stdin), localOut, localErr, testCase.Tty) - //e.upgrader = testCase.Upgrader err := e.Execute() hasErr := err != nil diff --git a/pkg/kubelet/server.go b/pkg/kubelet/server.go index d60376caeb..41f7dd90ab 100644 --- a/pkg/kubelet/server.go +++ b/pkg/kubelet/server.go @@ -543,7 +543,6 @@ WaitForStreams: switch streamType { case api.StreamTypeError: errorStream = stream - defer errorStream.Reset() receivedStreams++ case api.StreamTypeStdin: stdinStream = stream @@ -568,11 +567,6 @@ WaitForStreams: } } - if stdinStream != nil { - // close our half of the input stream, since we won't be writing to it - stdinStream.Close() - } - return stdinStream, stdoutStream, stderrStream, errorStream, conn, tty, true } diff --git a/pkg/util/httpstream/spdy/connection.go b/pkg/util/httpstream/spdy/connection.go index 7c2227917d..6d4855d195 100644 --- a/pkg/util/httpstream/spdy/connection.go +++ b/pkg/util/httpstream/spdy/connection.go @@ -78,7 +78,7 @@ const createStreamResponseTimeout = 30 * time.Second func (c *connection) Close() error { c.streamLock.Lock() for _, s := range c.streams { - s.Reset() + s.Close() } c.streams = make([]httpstream.Stream, 0) c.streamLock.Unlock() diff --git a/test/e2e/kubectl.go b/test/e2e/kubectl.go index 82c649e7d3..1bd0b04ea6 100644 --- a/test/e2e/kubectl.go +++ b/test/e2e/kubectl.go @@ -20,6 +20,7 @@ import ( "encoding/json" "errors" "fmt" + "io" "io/ioutil" "net" "net/http" @@ -158,11 +159,35 @@ var _ = Describe("Kubectl client", func() { It("should support exec", func() { By("executing a command in the container") execOutput := runKubectl("exec", fmt.Sprintf("--namespace=%v", ns), simplePodName, "echo", "running", "in", "container") - expectedExecOutput := "running in container" - if execOutput != expectedExecOutput { - Failf("Unexpected kubectl exec output. Wanted '%s', got '%s'", execOutput, expectedExecOutput) + if e, a := "running in container", execOutput; e != a { + Failf("Unexpected kubectl exec output. Wanted %q, got %q", e, a) + } + + By("executing a command in the container with noninteractive stdin") + execOutput = newKubectlCommand("exec", fmt.Sprintf("--namespace=%v", ns), "-i", simplePodName, "cat"). + withStdinData("abcd1234"). + exec() + if e, a := "abcd1234", execOutput; e != a { + Failf("Unexpected kubectl exec output. Wanted %q, got %q", e, a) + } + + // pretend that we're a user in an interactive shell + r, c, err := newBlockingReader("echo hi\nexit\n") + if err != nil { + Failf("Error creating blocking reader: %v", err) + } + // NOTE this is solely for test cleanup! + defer c.Close() + + By("executing a command in the container with pseudo-interactive stdin") + execOutput = newKubectlCommand("exec", fmt.Sprintf("--namespace=%v", ns), "-i", simplePodName, "bash"). + withStdinReader(r). + exec() + if e, a := "hi", execOutput; e != a { + Failf("Unexpected kubectl exec output. Wanted %q, got %q", e, a) } }) + It("should support port-forward", func() { By("forwarding the container port to a local port") cmd := kubectlCmd("port-forward", fmt.Sprintf("--namespace=%v", ns), simplePodName, fmt.Sprintf(":%d", simplePodPort)) @@ -791,3 +816,20 @@ func getUDData(jpgExpected string, ns string) func(*client.Client, string) error } } } + +// newBlockingReader returns a reader that allows reading the given string, +// then blocks until Close() is called on the returned closer. +// +// We're explicitly returning the reader and closer separately, because +// the closer needs to be the *os.File we get from os.Pipe(). This is required +// so the exec of kubectl can pass the underlying file descriptor to the exec +// syscall, instead of creating another os.Pipe and blocking on the io.Copy +// between the source (e.g. stdin) and the write half of the pipe. +func newBlockingReader(s string) (io.Reader, io.Closer, error) { + r, w, err := os.Pipe() + if err != nil { + return nil, nil, err + } + w.Write([]byte(s)) + return r, w, nil +} diff --git a/test/e2e/util.go b/test/e2e/util.go index c92dfe0190..02ffdb4c51 100644 --- a/test/e2e/util.go +++ b/test/e2e/util.go @@ -977,6 +977,11 @@ func (b kubectlBuilder) withStdinData(data string) *kubectlBuilder { return &b } +func (b kubectlBuilder) withStdinReader(reader io.Reader) *kubectlBuilder { + b.cmd.Stdin = reader + return &b +} + func (b kubectlBuilder) exec() string { var stdout, stderr bytes.Buffer cmd := b.cmd