diff --git a/pkg/daemons/control/proxy/proxy.go b/pkg/daemons/control/proxy/proxy.go index 426d3e81db..455534302d 100644 --- a/pkg/daemons/control/proxy/proxy.go +++ b/pkg/daemons/control/proxy/proxy.go @@ -2,7 +2,6 @@ package proxy import ( "io" - "net" "github.com/pkg/errors" "github.com/sirupsen/logrus" @@ -14,7 +13,7 @@ type proxy struct { errc chan error } -func Proxy(lconn, rconn net.Conn) error { +func Proxy(lconn, rconn io.ReadWriteCloser) error { p := &proxy{ lconn: lconn, rconn: rconn, diff --git a/pkg/daemons/control/tunnel.go b/pkg/daemons/control/tunnel.go index aafb908ded..9b149cdd95 100644 --- a/pkg/daemons/control/tunnel.go +++ b/pkg/daemons/control/tunnel.go @@ -1,8 +1,10 @@ package control import ( + "bufio" "context" "fmt" + "io" "net" "net/http" "strings" @@ -188,7 +190,7 @@ func (t *TunnelServer) serveConnect(resp http.ResponseWriter, req *http.Request) } resp.WriteHeader(http.StatusOK) - rconn, _, err := hijacker.Hijack() + rconn, bufrw, err := hijacker.Hijack() if err != nil { responsewriters.ErrorNegotiated( apierrors.NewInternalError(err), @@ -197,7 +199,7 @@ func (t *TunnelServer) serveConnect(resp http.ResponseWriter, req *http.Request) return } - proxy.Proxy(rconn, bconn) + proxy.Proxy(newConnReadWriteCloser(rconn, bufrw), bconn) } // dialBackend determines where to route the connection request to, and returns @@ -270,3 +272,32 @@ func (t *TunnelServer) dialBackend(ctx context.Context, addr string) (net.Conn, logrus.Debugf("Tunnel server egress proxy dialing %s directly", addr) return defaultDialer.DialContext(ctx, "tcp", addr) } + +// connReadWriteCloser bundles a net.Conn and a wrapping bufio.ReadWriter together into a type that +// meets the ReadWriteCloser interface. The http.Hijacker interface returns such a pair, and reads +// need to go through the buffered reader (because the http handler may have already read from the +// underlying connection), but writes and closes need to hit the connection directly. +type connReadWriteCloser struct { + conn net.Conn + once sync.Once + rw *bufio.ReadWriter +} + +var _ io.ReadWriteCloser = &connReadWriteCloser{} + +func newConnReadWriteCloser(conn net.Conn, rw *bufio.ReadWriter) *connReadWriteCloser { + return &connReadWriteCloser{conn: conn, rw: rw} +} + +func (crw *connReadWriteCloser) Read(p []byte) (n int, err error) { + return crw.rw.Read(p) +} + +func (crw *connReadWriteCloser) Write(b []byte) (n int, err error) { + return crw.conn.Write(b) +} + +func (crw *connReadWriteCloser) Close() (err error) { + crw.once.Do(func() { err = crw.conn.Close() }) + return +}