From 1d037f2f1f7a824903a27401f9fb26fb5bc3adf4 Mon Sep 17 00:00:00 2001 From: andres-portainer <91705312+andres-portainer@users.noreply.github.com> Date: Fri, 25 Oct 2024 11:21:49 -0300 Subject: [PATCH] feat(websocket): improve websocket code sharing BE-11340 (#61) --- api/http/handler/websocket/attach.go | 17 +--- api/http/handler/websocket/exec.go | 16 +--- api/http/handler/websocket/pod.go | 5 +- api/http/handler/websocket/stream.go | 70 --------------- api/{http/handler/websocket => ws}/hijack.go | 90 ++++++++------------ api/ws/stream.go | 77 +++++++++++++++++ 6 files changed, 122 insertions(+), 153 deletions(-) delete mode 100644 api/http/handler/websocket/stream.go rename api/{http/handler/websocket => ws}/hijack.go (52%) create mode 100644 api/ws/stream.go diff --git a/api/http/handler/websocket/attach.go b/api/http/handler/websocket/attach.go index 17c94a29f..d0cb7746f 100644 --- a/api/http/handler/websocket/attach.go +++ b/api/http/handler/websocket/attach.go @@ -5,10 +5,8 @@ import ( "net/http" "time" - "github.com/portainer/portainer/api/http/security" - "github.com/rs/zerolog/log" - portainer "github.com/portainer/portainer/api" + "github.com/portainer/portainer/api/ws" httperror "github.com/portainer/portainer/pkg/libhttp/error" "github.com/portainer/portainer/pkg/libhttp/request" @@ -76,14 +74,6 @@ func (handler *Handler) websocketAttach(w http.ResponseWriter, r *http.Request) } func (handler *Handler) handleAttachRequest(w http.ResponseWriter, r *http.Request, params *webSocketRequestParams) error { - tokenData, err := security.RetrieveTokenData(r) - if err != nil { - log.Warn(). - Err(err). - Msg("unable to retrieve user details from authentication token") - return err - } - r.Header.Del("Origin") if params.endpoint.Type == portainer.AgentOnDockerEnvironment { @@ -98,14 +88,13 @@ func (handler *Handler) handleAttachRequest(w http.ResponseWriter, r *http.Reque } defer websocketConn.Close() - return hijackAttachStartOperation(websocketConn, params.endpoint, params.ID, tokenData.Token) + return hijackAttachStartOperation(websocketConn, params.endpoint, params.ID) } func hijackAttachStartOperation( websocketConn *websocket.Conn, endpoint *portainer.Endpoint, attachID string, - token string, ) error { conn, err := initDial(endpoint) if err != nil { @@ -127,7 +116,7 @@ func hijackAttachStartOperation( return err } - return hijackRequest(websocketConn, conn, attachStartRequest, token) + return ws.HijackRequest(websocketConn, conn, attachStartRequest) } func createAttachStartRequest(attachID string) (*http.Request, error) { diff --git a/api/http/handler/websocket/exec.go b/api/http/handler/websocket/exec.go index 1c2528dba..ab04b0702 100644 --- a/api/http/handler/websocket/exec.go +++ b/api/http/handler/websocket/exec.go @@ -5,13 +5,12 @@ import ( "net/http" portainer "github.com/portainer/portainer/api" - "github.com/portainer/portainer/api/http/security" + "github.com/portainer/portainer/api/ws" httperror "github.com/portainer/portainer/pkg/libhttp/error" "github.com/portainer/portainer/pkg/libhttp/request" "github.com/asaskevich/govalidator" "github.com/gorilla/websocket" - "github.com/rs/zerolog/log" "github.com/segmentio/encoding/json" ) @@ -79,14 +78,6 @@ func (handler *Handler) websocketExec(w http.ResponseWriter, r *http.Request) *h } func (handler *Handler) handleExecRequest(w http.ResponseWriter, r *http.Request, params *webSocketRequestParams) error { - tokenData, err := security.RetrieveTokenData(r) - if err != nil { - log.Warn(). - Err(err). - Msg("unable to retrieve user details from authentication token") - return err - } - r.Header.Del("Origin") if params.endpoint.Type == portainer.AgentOnDockerEnvironment { @@ -102,14 +93,13 @@ func (handler *Handler) handleExecRequest(w http.ResponseWriter, r *http.Request defer websocketConn.Close() - return hijackExecStartOperation(websocketConn, params.endpoint, params.ID, tokenData.Token) + return hijackExecStartOperation(websocketConn, params.endpoint, params.ID) } func hijackExecStartOperation( websocketConn *websocket.Conn, endpoint *portainer.Endpoint, execID string, - token string, ) error { conn, err := initDial(endpoint) if err != nil { @@ -121,7 +111,7 @@ func hijackExecStartOperation( return err } - return hijackRequest(websocketConn, conn, execStartRequest, token) + return ws.HijackRequest(websocketConn, conn, execStartRequest) } func createExecStartRequest(execID string) (*http.Request, error) { diff --git a/api/http/handler/websocket/pod.go b/api/http/handler/websocket/pod.go index eb17ff9f7..9d7969682 100644 --- a/api/http/handler/websocket/pod.go +++ b/api/http/handler/websocket/pod.go @@ -9,6 +9,7 @@ import ( portainer "github.com/portainer/portainer/api" "github.com/portainer/portainer/api/http/proxy/factory/kubernetes" "github.com/portainer/portainer/api/http/security" + "github.com/portainer/portainer/api/ws" httperror "github.com/portainer/portainer/pkg/libhttp/error" "github.com/portainer/portainer/pkg/libhttp/request" @@ -136,8 +137,8 @@ func (handler *Handler) hijackPodExecStartOperation( // errorChan is used to propagate errors from the go routines to the caller. errorChan := make(chan error, 1) - go streamFromWebsocketToWriter(websocketConn, stdinWriter, errorChan) - go streamFromReaderToWebsocket(websocketConn, stdoutReader, errorChan) + go ws.StreamFromWebsocketToWriter(websocketConn, stdinWriter, errorChan) + go ws.StreamFromReaderToWebsocket(websocketConn, stdoutReader, errorChan) // StartExecProcess is a blocking operation which streams IO to/from pod; // this must execute in asynchronously, since the websocketConn could return errors (e.g. client disconnects) before diff --git a/api/http/handler/websocket/stream.go b/api/http/handler/websocket/stream.go deleted file mode 100644 index 77ce78566..000000000 --- a/api/http/handler/websocket/stream.go +++ /dev/null @@ -1,70 +0,0 @@ -package websocket - -import ( - "io" - "unicode/utf8" - - "github.com/gorilla/websocket" -) - -const readerBufferSize = 2048 - -func streamFromWebsocketToWriter(websocketConn *websocket.Conn, writer io.Writer, errorChan chan error) { - for { - _, in, err := websocketConn.ReadMessage() - if err != nil { - errorChan <- err - - break - } - - _, err = writer.Write(in) - if err != nil { - errorChan <- err - - break - } - } -} - -func streamFromReaderToWebsocket(websocketConn *websocket.Conn, reader io.Reader, errorChan chan error) { - out := make([]byte, readerBufferSize) - - for { - n, err := reader.Read(out) - if err != nil { - errorChan <- err - - break - } - - processedOutput := validString(string(out[:n])) - err = websocketConn.WriteMessage(websocket.TextMessage, []byte(processedOutput)) - if err != nil { - errorChan <- err - - break - } - } -} - -func validString(s string) string { - if utf8.ValidString(s) { - return s - } - - v := make([]rune, 0, len(s)) - - for i, r := range s { - if r == utf8.RuneError { - _, size := utf8.DecodeRuneInString(s[i:]) - if size == 1 { - continue - } - } - - v = append(v, r) - } - - return string(v) -} diff --git a/api/http/handler/websocket/hijack.go b/api/ws/hijack.go similarity index 52% rename from api/http/handler/websocket/hijack.go rename to api/ws/hijack.go index b080c2095..75f01ea2e 100644 --- a/api/http/handler/websocket/hijack.go +++ b/api/ws/hijack.go @@ -1,4 +1,4 @@ -package websocket +package ws import ( "bufio" @@ -16,18 +16,13 @@ import ( const ( // Time allowed to write a message to the peer - writeWait = 10 * time.Second + WriteWait = 10 * time.Second // Send pings to peer with this period - pingPeriod = 50 * time.Second + PingPeriod = 50 * time.Second ) -func hijackRequest( - websocketConn *websocket.Conn, - conn net.Conn, - request *http.Request, - token string, -) error { +func HijackRequest(websocketConn *websocket.Conn, conn net.Conn, request *http.Request) error { resp, err := sendHTTPRequest(conn, request) if err != nil { return err @@ -39,17 +34,21 @@ func hijackRequest( return fmt.Errorf("unexpected response status code: %d", resp.StatusCode) } + var mu sync.Mutex + errorChan := make(chan error, 1) - go readWebSocketToTCP(websocketConn, conn, errorChan) - go writeTCPToWebSocket(websocketConn, conn, errorChan) + go StreamFromWebsocketToWriter(websocketConn, conn, errorChan) + go WriteReaderToWebSocket(websocketConn, &mu, conn, errorChan) err = <-errorChan if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseNoStatusReceived) { - log.Debug().Msgf("Unexpected close error: %v\n", err) + log.Debug().Err(err).Msg("unexpected close error") + return err } - log.Debug().Msgf("session ended") + log.Info().Msg("session ended") + return nil } @@ -69,60 +68,40 @@ func sendHTTPRequest(conn net.Conn, req *http.Request) (*http.Response, error) { return resp, nil } -func readWebSocketToTCP(websocketConn *websocket.Conn, tcpConn net.Conn, errorChan chan error) { - for { - messageType, p, err := websocketConn.ReadMessage() - if err != nil { - if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway) { - log.Debug().Msgf("Unexpected close error: %v\n", err) - } - errorChan <- err - return - } - - if messageType == websocket.TextMessage || messageType == websocket.BinaryMessage { - _, err := tcpConn.Write(p) - if err != nil { - log.Debug().Msgf("Error writing to TCP connection: %v\n", err) - errorChan <- err - return - } - } - } -} - -func writeTCPToWebSocket(websocketConn *websocket.Conn, tcpConn net.Conn, errorChan chan error) { - var mu sync.Mutex - out := make([]byte, readerBufferSize) +func WriteReaderToWebSocket(websocketConn *websocket.Conn, mu *sync.Mutex, reader io.Reader, errorChan chan error) { + out := make([]byte, ReaderBufferSize) input := make(chan string) - pingTicker := time.NewTicker(pingPeriod) + pingTicker := time.NewTicker(PingPeriod) defer pingTicker.Stop() defer websocketConn.Close() - websocketConn.SetReadLimit(2048) + mu.Lock() + websocketConn.SetReadLimit(ReaderBufferSize) websocketConn.SetPongHandler(func(string) error { return nil }) websocketConn.SetPingHandler(func(data string) error { - websocketConn.SetWriteDeadline(time.Now().Add(writeWait)) + websocketConn.SetWriteDeadline(time.Now().Add(WriteWait)) + return websocketConn.WriteMessage(websocket.PongMessage, []byte(data)) }) - - reader := bufio.NewReader(tcpConn) + mu.Unlock() go func() { for { n, err := reader.Read(out) if err != nil { errorChan <- err + if !errors.Is(err, io.EOF) { - log.Debug().Msgf("error reading from server: %v", err) + log.Debug().Err(err).Msg("error reading from server") } + return } - processedOutput := validString(string(out[:n])) + processedOutput := ValidString(string(out[:n])) input <- processedOutput } }() @@ -130,34 +109,37 @@ func writeTCPToWebSocket(websocketConn *websocket.Conn, tcpConn net.Conn, errorC for { select { case msg := <-input: - err := wswrite(websocketConn, &mu, msg) - if err != nil { - log.Debug().Msgf("error writing to websocket: %v", err) + if err := wsWrite(websocketConn, mu, msg); err != nil { + log.Debug().Err(err).Msg("error writing to websocket") errorChan <- err + return } case <-pingTicker.C: - if err := wsping(websocketConn, &mu); err != nil { - log.Debug().Msgf("error writing to websocket during pong response: %v", err) + if err := wsPing(websocketConn, mu); err != nil { + log.Debug().Err(err).Msg("error writing to websocket during pong response") errorChan <- err + return } } } } -func wswrite(websocketConn *websocket.Conn, mu *sync.Mutex, msg string) error { +func wsWrite(websocketConn *websocket.Conn, mu *sync.Mutex, msg string) error { mu.Lock() defer mu.Unlock() - websocketConn.SetWriteDeadline(time.Now().Add(writeWait)) + websocketConn.SetWriteDeadline(time.Now().Add(WriteWait)) + return websocketConn.WriteMessage(websocket.TextMessage, []byte(msg)) } -func wsping(websocketConn *websocket.Conn, mu *sync.Mutex) error { +func wsPing(websocketConn *websocket.Conn, mu *sync.Mutex) error { mu.Lock() defer mu.Unlock() - websocketConn.SetWriteDeadline(time.Now().Add(writeWait)) + websocketConn.SetWriteDeadline(time.Now().Add(WriteWait)) + return websocketConn.WriteMessage(websocket.PingMessage, nil) } diff --git a/api/ws/stream.go b/api/ws/stream.go new file mode 100644 index 000000000..b4e6bc96a --- /dev/null +++ b/api/ws/stream.go @@ -0,0 +1,77 @@ +package ws + +import ( + "io" + "unicode/utf8" + + "github.com/gorilla/websocket" + "github.com/rs/zerolog/log" +) + +const ReaderBufferSize = 2048 + +func StreamFromWebsocketToWriter(websocketConn *websocket.Conn, writer io.Writer, errorChan chan error) { + for { + messageType, in, err := websocketConn.ReadMessage() + if err != nil { + if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway) { + log.Debug().Err(err).Msg("unexpected close error") + } + errorChan <- err + + return + } + + if messageType != websocket.TextMessage && messageType != websocket.BinaryMessage { + continue + } + + if _, err := writer.Write(in); err != nil { + log.Debug().Err(err).Msg("writing error") + errorChan <- err + + return + } + } +} + +func StreamFromReaderToWebsocket(websocketConn *websocket.Conn, reader io.Reader, errorChan chan error) { + out := make([]byte, ReaderBufferSize) + + for { + n, err := reader.Read(out) + if err != nil { + errorChan <- err + + break + } + + processedOutput := ValidString(string(out[:n])) + if err := websocketConn.WriteMessage(websocket.TextMessage, []byte(processedOutput)); err != nil { + errorChan <- err + + break + } + } +} + +func ValidString(s string) string { + if utf8.ValidString(s) { + return s + } + + v := make([]rune, 0, len(s)) + + for i, r := range s { + if r == utf8.RuneError { + _, size := utf8.DecodeRuneInString(s[i:]) + if size == 1 { + continue + } + } + + v = append(v, r) + } + + return string(v) +}