diff --git a/api/http/handler/websocket/attach.go b/api/http/handler/websocket/attach.go index b023f7bb5..17c94a29f 100644 --- a/api/http/handler/websocket/attach.go +++ b/api/http/handler/websocket/attach.go @@ -1,13 +1,13 @@ package websocket import ( - "github.com/portainer/portainer/api/http/security" - "github.com/rs/zerolog/log" "net" "net/http" - "net/http/httputil" "time" + "github.com/portainer/portainer/api/http/security" + "github.com/rs/zerolog/log" + portainer "github.com/portainer/portainer/api" httperror "github.com/portainer/portainer/pkg/libhttp/error" "github.com/portainer/portainer/pkg/libhttp/request" @@ -107,7 +107,7 @@ func hijackAttachStartOperation( attachID string, token string, ) error { - dial, err := initDial(endpoint) + conn, err := initDial(endpoint) if err != nil { return err } @@ -117,24 +117,20 @@ func hijackAttachStartOperation( // network setups may cause ECONNTIMEOUT, leaving the client in an unknown // state. Setting TCP KeepAlive on the socket connection will prohibit // ECONNTIMEOUT unless the socket connection truly is broken - if tcpConn, ok := dial.(*net.TCPConn); ok { + if tcpConn, ok := conn.(*net.TCPConn); ok { tcpConn.SetKeepAlive(true) tcpConn.SetKeepAlivePeriod(30 * time.Second) } - httpConn := httputil.NewClientConn(dial, nil) - defer httpConn.Close() - attachStartRequest, err := createAttachStartRequest(attachID) if err != nil { return err } - return hijackRequest(websocketConn, httpConn, attachStartRequest, token) + return hijackRequest(websocketConn, conn, attachStartRequest, token) } func createAttachStartRequest(attachID string) (*http.Request, error) { - request, err := http.NewRequest("POST", "/containers/"+attachID+"/attach?stdin=1&stdout=1&stderr=1&stream=1", nil) if err != nil { return nil, err diff --git a/api/http/handler/websocket/exec.go b/api/http/handler/websocket/exec.go index be66e255b..1c2528dba 100644 --- a/api/http/handler/websocket/exec.go +++ b/api/http/handler/websocket/exec.go @@ -2,10 +2,7 @@ package websocket import ( "bytes" - "net" "net/http" - "net/http/httputil" - "time" portainer "github.com/portainer/portainer/api" "github.com/portainer/portainer/api/http/security" @@ -102,6 +99,7 @@ func (handler *Handler) handleExecRequest(w http.ResponseWriter, r *http.Request if err != nil { return err } + defer websocketConn.Close() return hijackExecStartOperation(websocketConn, params.endpoint, params.ID, tokenData.Token) @@ -113,30 +111,17 @@ func hijackExecStartOperation( execID string, token string, ) error { - dial, err := initDial(endpoint) + conn, err := initDial(endpoint) if err != nil { return err } - // When we set up a TCP connection for hijack, there could be long periods - // of inactivity (a long running command with no output) that in certain - // network setups may cause ECONNTIMEOUT, leaving the client in an unknown - // state. Setting TCP KeepAlive on the socket connection will prohibit - // ECONNTIMEOUT unless the socket connection truly is broken - if tcpConn, ok := dial.(*net.TCPConn); ok { - tcpConn.SetKeepAlive(true) - tcpConn.SetKeepAlivePeriod(30 * time.Second) - } - - httpConn := httputil.NewClientConn(dial, nil) - defer httpConn.Close() - execStartRequest, err := createExecStartRequest(execID) if err != nil { return err } - return hijackRequest(websocketConn, httpConn, execStartRequest, token) + return hijackRequest(websocketConn, conn, execStartRequest, token) } func createExecStartRequest(execID string) (*http.Request, error) { diff --git a/api/http/handler/websocket/hijack.go b/api/http/handler/websocket/hijack.go index ccda8ac8f..a41eb638a 100644 --- a/api/http/handler/websocket/hijack.go +++ b/api/http/handler/websocket/hijack.go @@ -1,50 +1,163 @@ package websocket import ( + "bufio" "errors" "fmt" + "io" + "net" "net/http" - "net/http/httputil" + "sync" + "time" "github.com/gorilla/websocket" - "github.com/portainer/portainer/api/internal/logoutcontext" + "github.com/rs/zerolog/log" +) + +const ( + // Time allowed to write a message to the peer + writeWait = 10 * time.Second + + // Send pings to peer with this period + pingPeriod = 50 * time.Second ) func hijackRequest( websocketConn *websocket.Conn, - httpConn *httputil.ClientConn, + conn net.Conn, request *http.Request, token string, ) error { - // Server hijacks the connection, error 'connection closed' expected - resp, err := httpConn.Do(request) - if !errors.Is(err, httputil.ErrPersistEOF) { - if err != nil { - return err - } - if resp.StatusCode != http.StatusSwitchingProtocols { - resp.Body.Close() - return fmt.Errorf("unable to upgrade to tcp, received %d", resp.StatusCode) - } + resp, err := sendHTTPRequest(conn, request) + if err != nil { + return err } + defer resp.Body.Close() - tcpConn, brw := httpConn.Hijack() - defer tcpConn.Close() + // Check if the response status code indicates an upgrade (101 Switching Protocols) + if resp.StatusCode != http.StatusSwitchingProtocols { + return fmt.Errorf("unexpected response status code: %d", resp.StatusCode) + } errorChan := make(chan error, 1) - go streamFromReaderToWebsocket(websocketConn, brw, errorChan) - go streamFromWebsocketToWriter(websocketConn, tcpConn, errorChan) + go readWebSocketToTCP(websocketConn, conn, errorChan) + go writeTCPToWebSocket(websocketConn, conn, errorChan) - logoutCtx := logoutcontext.GetContext(token) - - select { - case <-logoutCtx.Done(): - return fmt.Errorf("Your session has been logged out.") - case err = <-errorChan: - if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseNoStatusReceived) { - return err - } + err = <-errorChan + if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseNoStatusReceived) { + log.Debug().Msgf("Unexpected close error: %v\n", err) + return err } + log.Debug().Msgf("session ended") return nil } + +// sendHTTPRequest sends an HTTP request over the provided net.Conn and parses the response. +func sendHTTPRequest(conn net.Conn, req *http.Request) (*http.Response, error) { + // Send the HTTP request to the server + if err := req.Write(conn); err != nil { + return nil, fmt.Errorf("error writing request: %w", err) + } + + // Read the response from the server + resp, err := http.ReadResponse(bufio.NewReader(conn), req) + if err != nil { + return nil, fmt.Errorf("error reading response: %w", err) + } + + return resp, nil +} + +func readWebSocketToTCP(websocketConn *websocket.Conn, tcpConn net.Conn, errorChan chan error) { + for { + messageType, p, err := websocketConn.ReadMessage() + if err != nil { + if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway) { + log.Debug().Msgf("Unexpected close error: %v\n", err) + } + errorChan <- err + return + } + + if messageType == websocket.TextMessage || messageType == websocket.BinaryMessage { + _, err := tcpConn.Write(p) + if err != nil { + log.Debug().Msgf("Error writing to TCP connection: %v\n", err) + errorChan <- err + return + } + } + } +} + +func writeTCPToWebSocket(websocketConn *websocket.Conn, tcpConn net.Conn, errorChan chan error) { + var mu sync.Mutex + out := make([]byte, readerBufferSize) + input := make(chan string) + pingTicker := time.NewTicker(pingPeriod) + defer pingTicker.Stop() + defer websocketConn.Close() + + websocketConn.SetReadLimit(2048) + websocketConn.SetPongHandler(func(string) error { + return nil + }) + + websocketConn.SetPingHandler(func(data string) error { + websocketConn.SetWriteDeadline(time.Now().Add(writeWait)) + return websocketConn.WriteMessage(websocket.PongMessage, []byte(data)) + }) + + reader := bufio.NewReader(tcpConn) + + go func() { + for { + n, err := reader.Read(out) + if err != nil { + errorChan <- err + if !errors.Is(err, io.EOF) { + log.Debug().Msgf("error reading from server: %v", err) + } + return + } + + processedOutput := validString(string(out[:n])) + input <- string(processedOutput) + } + }() + + for { + select { + case msg := <-input: + err := wswrite(websocketConn, &mu, msg) + if err != nil { + log.Debug().Msgf("error writing to websocket: %v", err) + errorChan <- err + return + } + case <-pingTicker.C: + if err := wsping(websocketConn, &mu); err != nil { + log.Debug().Msgf("error writing to websocket during pong response: %v", err) + errorChan <- err + return + } + } + } +} + +func wswrite(websocketConn *websocket.Conn, mu *sync.Mutex, msg string) error { + mu.Lock() + defer mu.Unlock() + + websocketConn.SetWriteDeadline(time.Now().Add(writeWait)) + return websocketConn.WriteMessage(websocket.TextMessage, []byte(msg)) +} + +func wsping(websocketConn *websocket.Conn, mu *sync.Mutex) error { + mu.Lock() + defer mu.Unlock() + + websocketConn.SetWriteDeadline(time.Now().Add(writeWait)) + return websocketConn.WriteMessage(websocket.PingMessage, nil) +}