feat(websocket): improve websocket code sharing BE-11340 (#61)

pull/12336/head^2
andres-portainer 2024-10-25 11:21:49 -03:00 committed by GitHub
parent b2d67795b3
commit 1d037f2f1f
6 changed files with 122 additions and 153 deletions

View File

@ -5,10 +5,8 @@ import (
"net/http" "net/http"
"time" "time"
"github.com/portainer/portainer/api/http/security"
"github.com/rs/zerolog/log"
portainer "github.com/portainer/portainer/api" portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/ws"
httperror "github.com/portainer/portainer/pkg/libhttp/error" httperror "github.com/portainer/portainer/pkg/libhttp/error"
"github.com/portainer/portainer/pkg/libhttp/request" "github.com/portainer/portainer/pkg/libhttp/request"
@ -76,14 +74,6 @@ func (handler *Handler) websocketAttach(w http.ResponseWriter, r *http.Request)
} }
func (handler *Handler) handleAttachRequest(w http.ResponseWriter, r *http.Request, params *webSocketRequestParams) error { func (handler *Handler) handleAttachRequest(w http.ResponseWriter, r *http.Request, params *webSocketRequestParams) error {
tokenData, err := security.RetrieveTokenData(r)
if err != nil {
log.Warn().
Err(err).
Msg("unable to retrieve user details from authentication token")
return err
}
r.Header.Del("Origin") r.Header.Del("Origin")
if params.endpoint.Type == portainer.AgentOnDockerEnvironment { if params.endpoint.Type == portainer.AgentOnDockerEnvironment {
@ -98,14 +88,13 @@ func (handler *Handler) handleAttachRequest(w http.ResponseWriter, r *http.Reque
} }
defer websocketConn.Close() defer websocketConn.Close()
return hijackAttachStartOperation(websocketConn, params.endpoint, params.ID, tokenData.Token) return hijackAttachStartOperation(websocketConn, params.endpoint, params.ID)
} }
func hijackAttachStartOperation( func hijackAttachStartOperation(
websocketConn *websocket.Conn, websocketConn *websocket.Conn,
endpoint *portainer.Endpoint, endpoint *portainer.Endpoint,
attachID string, attachID string,
token string,
) error { ) error {
conn, err := initDial(endpoint) conn, err := initDial(endpoint)
if err != nil { if err != nil {
@ -127,7 +116,7 @@ func hijackAttachStartOperation(
return err return err
} }
return hijackRequest(websocketConn, conn, attachStartRequest, token) return ws.HijackRequest(websocketConn, conn, attachStartRequest)
} }
func createAttachStartRequest(attachID string) (*http.Request, error) { func createAttachStartRequest(attachID string) (*http.Request, error) {

View File

@ -5,13 +5,12 @@ import (
"net/http" "net/http"
portainer "github.com/portainer/portainer/api" portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/http/security" "github.com/portainer/portainer/api/ws"
httperror "github.com/portainer/portainer/pkg/libhttp/error" httperror "github.com/portainer/portainer/pkg/libhttp/error"
"github.com/portainer/portainer/pkg/libhttp/request" "github.com/portainer/portainer/pkg/libhttp/request"
"github.com/asaskevich/govalidator" "github.com/asaskevich/govalidator"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"github.com/rs/zerolog/log"
"github.com/segmentio/encoding/json" "github.com/segmentio/encoding/json"
) )
@ -79,14 +78,6 @@ func (handler *Handler) websocketExec(w http.ResponseWriter, r *http.Request) *h
} }
func (handler *Handler) handleExecRequest(w http.ResponseWriter, r *http.Request, params *webSocketRequestParams) error { func (handler *Handler) handleExecRequest(w http.ResponseWriter, r *http.Request, params *webSocketRequestParams) error {
tokenData, err := security.RetrieveTokenData(r)
if err != nil {
log.Warn().
Err(err).
Msg("unable to retrieve user details from authentication token")
return err
}
r.Header.Del("Origin") r.Header.Del("Origin")
if params.endpoint.Type == portainer.AgentOnDockerEnvironment { if params.endpoint.Type == portainer.AgentOnDockerEnvironment {
@ -102,14 +93,13 @@ func (handler *Handler) handleExecRequest(w http.ResponseWriter, r *http.Request
defer websocketConn.Close() defer websocketConn.Close()
return hijackExecStartOperation(websocketConn, params.endpoint, params.ID, tokenData.Token) return hijackExecStartOperation(websocketConn, params.endpoint, params.ID)
} }
func hijackExecStartOperation( func hijackExecStartOperation(
websocketConn *websocket.Conn, websocketConn *websocket.Conn,
endpoint *portainer.Endpoint, endpoint *portainer.Endpoint,
execID string, execID string,
token string,
) error { ) error {
conn, err := initDial(endpoint) conn, err := initDial(endpoint)
if err != nil { if err != nil {
@ -121,7 +111,7 @@ func hijackExecStartOperation(
return err return err
} }
return hijackRequest(websocketConn, conn, execStartRequest, token) return ws.HijackRequest(websocketConn, conn, execStartRequest)
} }
func createExecStartRequest(execID string) (*http.Request, error) { func createExecStartRequest(execID string) (*http.Request, error) {

View File

@ -9,6 +9,7 @@ import (
portainer "github.com/portainer/portainer/api" portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/http/proxy/factory/kubernetes" "github.com/portainer/portainer/api/http/proxy/factory/kubernetes"
"github.com/portainer/portainer/api/http/security" "github.com/portainer/portainer/api/http/security"
"github.com/portainer/portainer/api/ws"
httperror "github.com/portainer/portainer/pkg/libhttp/error" httperror "github.com/portainer/portainer/pkg/libhttp/error"
"github.com/portainer/portainer/pkg/libhttp/request" "github.com/portainer/portainer/pkg/libhttp/request"
@ -136,8 +137,8 @@ func (handler *Handler) hijackPodExecStartOperation(
// errorChan is used to propagate errors from the go routines to the caller. // errorChan is used to propagate errors from the go routines to the caller.
errorChan := make(chan error, 1) errorChan := make(chan error, 1)
go streamFromWebsocketToWriter(websocketConn, stdinWriter, errorChan) go ws.StreamFromWebsocketToWriter(websocketConn, stdinWriter, errorChan)
go streamFromReaderToWebsocket(websocketConn, stdoutReader, errorChan) go ws.StreamFromReaderToWebsocket(websocketConn, stdoutReader, errorChan)
// StartExecProcess is a blocking operation which streams IO to/from pod; // StartExecProcess is a blocking operation which streams IO to/from pod;
// this must execute in asynchronously, since the websocketConn could return errors (e.g. client disconnects) before // this must execute in asynchronously, since the websocketConn could return errors (e.g. client disconnects) before

View File

@ -1,70 +0,0 @@
package websocket
import (
"io"
"unicode/utf8"
"github.com/gorilla/websocket"
)
const readerBufferSize = 2048
func streamFromWebsocketToWriter(websocketConn *websocket.Conn, writer io.Writer, errorChan chan error) {
for {
_, in, err := websocketConn.ReadMessage()
if err != nil {
errorChan <- err
break
}
_, err = writer.Write(in)
if err != nil {
errorChan <- err
break
}
}
}
func streamFromReaderToWebsocket(websocketConn *websocket.Conn, reader io.Reader, errorChan chan error) {
out := make([]byte, readerBufferSize)
for {
n, err := reader.Read(out)
if err != nil {
errorChan <- err
break
}
processedOutput := validString(string(out[:n]))
err = websocketConn.WriteMessage(websocket.TextMessage, []byte(processedOutput))
if err != nil {
errorChan <- err
break
}
}
}
func validString(s string) string {
if utf8.ValidString(s) {
return s
}
v := make([]rune, 0, len(s))
for i, r := range s {
if r == utf8.RuneError {
_, size := utf8.DecodeRuneInString(s[i:])
if size == 1 {
continue
}
}
v = append(v, r)
}
return string(v)
}

View File

@ -1,4 +1,4 @@
package websocket package ws
import ( import (
"bufio" "bufio"
@ -16,18 +16,13 @@ import (
const ( const (
// Time allowed to write a message to the peer // Time allowed to write a message to the peer
writeWait = 10 * time.Second WriteWait = 10 * time.Second
// Send pings to peer with this period // Send pings to peer with this period
pingPeriod = 50 * time.Second PingPeriod = 50 * time.Second
) )
func hijackRequest( func HijackRequest(websocketConn *websocket.Conn, conn net.Conn, request *http.Request) error {
websocketConn *websocket.Conn,
conn net.Conn,
request *http.Request,
token string,
) error {
resp, err := sendHTTPRequest(conn, request) resp, err := sendHTTPRequest(conn, request)
if err != nil { if err != nil {
return err return err
@ -39,17 +34,21 @@ func hijackRequest(
return fmt.Errorf("unexpected response status code: %d", resp.StatusCode) return fmt.Errorf("unexpected response status code: %d", resp.StatusCode)
} }
var mu sync.Mutex
errorChan := make(chan error, 1) errorChan := make(chan error, 1)
go readWebSocketToTCP(websocketConn, conn, errorChan) go StreamFromWebsocketToWriter(websocketConn, conn, errorChan)
go writeTCPToWebSocket(websocketConn, conn, errorChan) go WriteReaderToWebSocket(websocketConn, &mu, conn, errorChan)
err = <-errorChan err = <-errorChan
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseNoStatusReceived) { if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseNoStatusReceived) {
log.Debug().Msgf("Unexpected close error: %v\n", err) log.Debug().Err(err).Msg("unexpected close error")
return err return err
} }
log.Debug().Msgf("session ended") log.Info().Msg("session ended")
return nil return nil
} }
@ -69,60 +68,40 @@ func sendHTTPRequest(conn net.Conn, req *http.Request) (*http.Response, error) {
return resp, nil return resp, nil
} }
func readWebSocketToTCP(websocketConn *websocket.Conn, tcpConn net.Conn, errorChan chan error) { func WriteReaderToWebSocket(websocketConn *websocket.Conn, mu *sync.Mutex, reader io.Reader, errorChan chan error) {
for { out := make([]byte, ReaderBufferSize)
messageType, p, err := websocketConn.ReadMessage()
if err != nil {
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway) {
log.Debug().Msgf("Unexpected close error: %v\n", err)
}
errorChan <- err
return
}
if messageType == websocket.TextMessage || messageType == websocket.BinaryMessage {
_, err := tcpConn.Write(p)
if err != nil {
log.Debug().Msgf("Error writing to TCP connection: %v\n", err)
errorChan <- err
return
}
}
}
}
func writeTCPToWebSocket(websocketConn *websocket.Conn, tcpConn net.Conn, errorChan chan error) {
var mu sync.Mutex
out := make([]byte, readerBufferSize)
input := make(chan string) input := make(chan string)
pingTicker := time.NewTicker(pingPeriod) pingTicker := time.NewTicker(PingPeriod)
defer pingTicker.Stop() defer pingTicker.Stop()
defer websocketConn.Close() defer websocketConn.Close()
websocketConn.SetReadLimit(2048) mu.Lock()
websocketConn.SetReadLimit(ReaderBufferSize)
websocketConn.SetPongHandler(func(string) error { websocketConn.SetPongHandler(func(string) error {
return nil return nil
}) })
websocketConn.SetPingHandler(func(data string) error { websocketConn.SetPingHandler(func(data string) error {
websocketConn.SetWriteDeadline(time.Now().Add(writeWait)) websocketConn.SetWriteDeadline(time.Now().Add(WriteWait))
return websocketConn.WriteMessage(websocket.PongMessage, []byte(data)) return websocketConn.WriteMessage(websocket.PongMessage, []byte(data))
}) })
mu.Unlock()
reader := bufio.NewReader(tcpConn)
go func() { go func() {
for { for {
n, err := reader.Read(out) n, err := reader.Read(out)
if err != nil { if err != nil {
errorChan <- err errorChan <- err
if !errors.Is(err, io.EOF) { if !errors.Is(err, io.EOF) {
log.Debug().Msgf("error reading from server: %v", err) log.Debug().Err(err).Msg("error reading from server")
} }
return return
} }
processedOutput := validString(string(out[:n])) processedOutput := ValidString(string(out[:n]))
input <- processedOutput input <- processedOutput
} }
}() }()
@ -130,34 +109,37 @@ func writeTCPToWebSocket(websocketConn *websocket.Conn, tcpConn net.Conn, errorC
for { for {
select { select {
case msg := <-input: case msg := <-input:
err := wswrite(websocketConn, &mu, msg) if err := wsWrite(websocketConn, mu, msg); err != nil {
if err != nil { log.Debug().Err(err).Msg("error writing to websocket")
log.Debug().Msgf("error writing to websocket: %v", err)
errorChan <- err errorChan <- err
return return
} }
case <-pingTicker.C: case <-pingTicker.C:
if err := wsping(websocketConn, &mu); err != nil { if err := wsPing(websocketConn, mu); err != nil {
log.Debug().Msgf("error writing to websocket during pong response: %v", err) log.Debug().Err(err).Msg("error writing to websocket during pong response")
errorChan <- err errorChan <- err
return return
} }
} }
} }
} }
func wswrite(websocketConn *websocket.Conn, mu *sync.Mutex, msg string) error { func wsWrite(websocketConn *websocket.Conn, mu *sync.Mutex, msg string) error {
mu.Lock() mu.Lock()
defer mu.Unlock() defer mu.Unlock()
websocketConn.SetWriteDeadline(time.Now().Add(writeWait)) websocketConn.SetWriteDeadline(time.Now().Add(WriteWait))
return websocketConn.WriteMessage(websocket.TextMessage, []byte(msg)) return websocketConn.WriteMessage(websocket.TextMessage, []byte(msg))
} }
func wsping(websocketConn *websocket.Conn, mu *sync.Mutex) error { func wsPing(websocketConn *websocket.Conn, mu *sync.Mutex) error {
mu.Lock() mu.Lock()
defer mu.Unlock() defer mu.Unlock()
websocketConn.SetWriteDeadline(time.Now().Add(writeWait)) websocketConn.SetWriteDeadline(time.Now().Add(WriteWait))
return websocketConn.WriteMessage(websocket.PingMessage, nil) return websocketConn.WriteMessage(websocket.PingMessage, nil)
} }

77
api/ws/stream.go Normal file
View File

@ -0,0 +1,77 @@
package ws
import (
"io"
"unicode/utf8"
"github.com/gorilla/websocket"
"github.com/rs/zerolog/log"
)
const ReaderBufferSize = 2048
func StreamFromWebsocketToWriter(websocketConn *websocket.Conn, writer io.Writer, errorChan chan error) {
for {
messageType, in, err := websocketConn.ReadMessage()
if err != nil {
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway) {
log.Debug().Err(err).Msg("unexpected close error")
}
errorChan <- err
return
}
if messageType != websocket.TextMessage && messageType != websocket.BinaryMessage {
continue
}
if _, err := writer.Write(in); err != nil {
log.Debug().Err(err).Msg("writing error")
errorChan <- err
return
}
}
}
func StreamFromReaderToWebsocket(websocketConn *websocket.Conn, reader io.Reader, errorChan chan error) {
out := make([]byte, ReaderBufferSize)
for {
n, err := reader.Read(out)
if err != nil {
errorChan <- err
break
}
processedOutput := ValidString(string(out[:n]))
if err := websocketConn.WriteMessage(websocket.TextMessage, []byte(processedOutput)); err != nil {
errorChan <- err
break
}
}
}
func ValidString(s string) string {
if utf8.ValidString(s) {
return s
}
v := make([]rune, 0, len(s))
for i, r := range s {
if r == utf8.RuneError {
_, size := utf8.DecodeRuneInString(s[i:])
if size == 1 {
continue
}
}
v = append(v, r)
}
return string(v)
}