diff --git a/pkg/client/unversioned/portforward/portforward.go b/pkg/client/unversioned/portforward/portforward.go index 6ecedb6df9..d746125d0f 100644 --- a/pkg/client/unversioned/portforward/portforward.go +++ b/pkg/client/unversioned/portforward/portforward.go @@ -108,7 +108,7 @@ func parsePorts(ports []string) ([]ForwardedPort, error) { } // New creates a new PortForwarder. -func New(dialer httpstream.Dialer, ports []string, stopChan <-chan struct{}, out, errOut io.Writer) (*PortForwarder, error) { +func New(dialer httpstream.Dialer, ports []string, stopChan <-chan struct{}, readyChan chan struct{}, out, errOut io.Writer) (*PortForwarder, error) { if len(ports) == 0 { return nil, errors.New("You must specify at least 1 port") } @@ -120,7 +120,7 @@ func New(dialer httpstream.Dialer, ports []string, stopChan <-chan struct{}, out dialer: dialer, ports: parsedPorts, stopChan: stopChan, - Ready: make(chan struct{}), + Ready: readyChan, out: out, errOut: errOut, }, nil @@ -164,7 +164,9 @@ func (pf *PortForwarder) forward() error { return fmt.Errorf("Unable to listen on any of the requested ports: %v", pf.ports) } - close(pf.Ready) + if pf.Ready != nil { + close(pf.Ready) + } // wait for interrupt or conn closure select { diff --git a/pkg/client/unversioned/portforward/portforward_test.go b/pkg/client/unversioned/portforward/portforward_test.go index 6bc49a902f..4bdd222c08 100644 --- a/pkg/client/unversioned/portforward/portforward_test.go +++ b/pkg/client/unversioned/portforward/portforward_test.go @@ -88,7 +88,8 @@ func TestParsePortsAndNew(t *testing.T) { dialer := &fakeDialer{} expectedStopChan := make(chan struct{}) - pf, err := New(dialer, test.input, expectedStopChan, os.Stdout, os.Stderr) + readyChan := make(chan struct{}) + pf, err := New(dialer, test.input, expectedStopChan, readyChan, os.Stdout, os.Stderr) haveError = err != nil if e, a := test.expectNewError, haveError; e != a { t.Fatalf("%d: New: error expected=%t, got %t: %s", i, e, a, err) @@ -305,8 +306,9 @@ func TestForwardPorts(t *testing.T) { } stopChan := make(chan struct{}, 1) + readyChan := make(chan struct{}) - pf, err := New(exec, test.ports, stopChan, os.Stdout, os.Stderr) + pf, err := New(exec, test.ports, stopChan, readyChan, os.Stdout, os.Stderr) if err != nil { t.Fatalf("%s: unexpected error calling New: %v", testName, err) } @@ -375,8 +377,9 @@ func TestForwardPortsReturnsErrorWhenAllBindsFailed(t *testing.T) { stopChan1 := make(chan struct{}, 1) defer close(stopChan1) + readyChan1 := make(chan struct{}) - pf1, err := New(exec, []string{"5555"}, stopChan1, os.Stdout, os.Stderr) + pf1, err := New(exec, []string{"5555"}, stopChan1, readyChan1, os.Stdout, os.Stderr) if err != nil { t.Fatalf("error creating pf1: %v", err) } @@ -384,7 +387,8 @@ func TestForwardPortsReturnsErrorWhenAllBindsFailed(t *testing.T) { <-pf1.Ready stopChan2 := make(chan struct{}, 1) - pf2, err := New(exec, []string{"5555"}, stopChan2, os.Stdout, os.Stderr) + readyChan2 := make(chan struct{}) + pf2, err := New(exec, []string{"5555"}, stopChan2, readyChan2, os.Stdout, os.Stderr) if err != nil { t.Fatalf("error creating pf2: %v", err) } diff --git a/pkg/kubectl/cmd/portforward.go b/pkg/kubectl/cmd/portforward.go index 2014b62f40..fae3758474 100644 --- a/pkg/kubectl/cmd/portforward.go +++ b/pkg/kubectl/cmd/portforward.go @@ -41,6 +41,8 @@ type PortForwardOptions struct { Client *client.Client Ports []string PortForwarder portForwarder + StopChannel chan struct{} + ReadyChannel chan struct{} } var ( @@ -88,19 +90,19 @@ func NewCmdPortForward(f *cmdutil.Factory, cmdOut, cmdErr io.Writer) *cobra.Comm } type portForwarder interface { - ForwardPorts(method string, url *url.URL, config *restclient.Config, ports []string, stopChan <-chan struct{}) error + ForwardPorts(method string, url *url.URL, opts PortForwardOptions) error } type defaultPortForwarder struct { cmdOut, cmdErr io.Writer } -func (f *defaultPortForwarder) ForwardPorts(method string, url *url.URL, config *restclient.Config, ports []string, stopChan <-chan struct{}) error { - dialer, err := remotecommand.NewExecutor(config, method, url) +func (f *defaultPortForwarder) ForwardPorts(method string, url *url.URL, opts PortForwardOptions) error { + dialer, err := remotecommand.NewExecutor(opts.Config, method, url) if err != nil { return err } - fw, err := portforward.New(dialer, ports, stopChan, f.cmdOut, f.cmdErr) + fw, err := portforward.New(dialer, opts.Ports, opts.StopChannel, opts.ReadyChannel, f.cmdOut, f.cmdErr) if err != nil { return err } @@ -138,6 +140,8 @@ func (o *PortForwardOptions) Complete(f *cmdutil.Factory, cmd *cobra.Command, ar return err } + o.StopChannel = make(chan struct{}, 1) + o.ReadyChannel = make(chan struct{}) return nil } @@ -165,17 +169,18 @@ func (o PortForwardOptions) RunPortForward() error { } if pod.Status.Phase != api.PodRunning { - return fmt.Errorf("Unable to execute command because pod is not running. Current status=%v", pod.Status.Phase) + return fmt.Errorf("unable to forward port because pod is not running. Current status=%v", pod.Status.Phase) } signals := make(chan os.Signal, 1) signal.Notify(signals, os.Interrupt) defer signal.Stop(signals) - stopCh := make(chan struct{}, 1) go func() { <-signals - close(stopCh) + if o.StopChannel != nil { + close(o.StopChannel) + } }() req := o.Client.RESTClient.Post(). @@ -184,5 +189,5 @@ func (o PortForwardOptions) RunPortForward() error { Name(pod.Name). SubResource("portforward") - return o.PortForwarder.ForwardPorts("POST", req.URL(), o.Config, o.Ports, stopCh) + return o.PortForwarder.ForwardPorts("POST", req.URL(), o) } diff --git a/pkg/kubectl/cmd/portforward_test.go b/pkg/kubectl/cmd/portforward_test.go index 888c32b96c..45c4597367 100644 --- a/pkg/kubectl/cmd/portforward_test.go +++ b/pkg/kubectl/cmd/portforward_test.go @@ -38,7 +38,7 @@ type fakePortForwarder struct { pfErr error } -func (f *fakePortForwarder) ForwardPorts(method string, url *url.URL, config *restclient.Config, ports []string, stopChan <-chan struct{}) error { +func (f *fakePortForwarder) ForwardPorts(method string, url *url.URL, opts PortForwardOptions) error { f.method = method f.url = url return f.pfErr @@ -108,7 +108,7 @@ func TestPortForward(t *testing.T) { cmd.Run(cmd, []string{"foo", ":5000", ":1000"}) if test.pfErr && err != ff.pfErr { - t.Errorf("%s: Unexpected exec error: %v", test.name, err) + t.Errorf("%s: Unexpected port-forward error: %v", test.name, err) } if !test.pfErr && err != nil { t.Errorf("%s: Unexpected error: %v", test.name, err) @@ -191,7 +191,7 @@ func TestPortForwardWithPFlag(t *testing.T) { cmd.Run(cmd, []string{":5000", ":1000"}) if test.pfErr && err != ff.pfErr { - t.Errorf("%s: Unexpected exec error: %v", test.name, err) + t.Errorf("%s: Unexpected port-forward error: %v", test.name, err) } if !test.pfErr && err != nil { t.Errorf("%s: Unexpected error: %v", test.name, err)