mirror of https://github.com/portainer/portainer
feat(websocket): improve websocket code sharing BE-11340 (#61)
parent
b2d67795b3
commit
1d037f2f1f
|
@ -5,10 +5,8 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/portainer/portainer/api/http/security"
|
|
||||||
"github.com/rs/zerolog/log"
|
|
||||||
|
|
||||||
portainer "github.com/portainer/portainer/api"
|
portainer "github.com/portainer/portainer/api"
|
||||||
|
"github.com/portainer/portainer/api/ws"
|
||||||
httperror "github.com/portainer/portainer/pkg/libhttp/error"
|
httperror "github.com/portainer/portainer/pkg/libhttp/error"
|
||||||
"github.com/portainer/portainer/pkg/libhttp/request"
|
"github.com/portainer/portainer/pkg/libhttp/request"
|
||||||
|
|
||||||
|
@ -76,14 +74,6 @@ func (handler *Handler) websocketAttach(w http.ResponseWriter, r *http.Request)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (handler *Handler) handleAttachRequest(w http.ResponseWriter, r *http.Request, params *webSocketRequestParams) error {
|
func (handler *Handler) handleAttachRequest(w http.ResponseWriter, r *http.Request, params *webSocketRequestParams) error {
|
||||||
tokenData, err := security.RetrieveTokenData(r)
|
|
||||||
if err != nil {
|
|
||||||
log.Warn().
|
|
||||||
Err(err).
|
|
||||||
Msg("unable to retrieve user details from authentication token")
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
r.Header.Del("Origin")
|
r.Header.Del("Origin")
|
||||||
|
|
||||||
if params.endpoint.Type == portainer.AgentOnDockerEnvironment {
|
if params.endpoint.Type == portainer.AgentOnDockerEnvironment {
|
||||||
|
@ -98,14 +88,13 @@ func (handler *Handler) handleAttachRequest(w http.ResponseWriter, r *http.Reque
|
||||||
}
|
}
|
||||||
defer websocketConn.Close()
|
defer websocketConn.Close()
|
||||||
|
|
||||||
return hijackAttachStartOperation(websocketConn, params.endpoint, params.ID, tokenData.Token)
|
return hijackAttachStartOperation(websocketConn, params.endpoint, params.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func hijackAttachStartOperation(
|
func hijackAttachStartOperation(
|
||||||
websocketConn *websocket.Conn,
|
websocketConn *websocket.Conn,
|
||||||
endpoint *portainer.Endpoint,
|
endpoint *portainer.Endpoint,
|
||||||
attachID string,
|
attachID string,
|
||||||
token string,
|
|
||||||
) error {
|
) error {
|
||||||
conn, err := initDial(endpoint)
|
conn, err := initDial(endpoint)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -127,7 +116,7 @@ func hijackAttachStartOperation(
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return hijackRequest(websocketConn, conn, attachStartRequest, token)
|
return ws.HijackRequest(websocketConn, conn, attachStartRequest)
|
||||||
}
|
}
|
||||||
|
|
||||||
func createAttachStartRequest(attachID string) (*http.Request, error) {
|
func createAttachStartRequest(attachID string) (*http.Request, error) {
|
||||||
|
|
|
@ -5,13 +5,12 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
portainer "github.com/portainer/portainer/api"
|
portainer "github.com/portainer/portainer/api"
|
||||||
"github.com/portainer/portainer/api/http/security"
|
"github.com/portainer/portainer/api/ws"
|
||||||
httperror "github.com/portainer/portainer/pkg/libhttp/error"
|
httperror "github.com/portainer/portainer/pkg/libhttp/error"
|
||||||
"github.com/portainer/portainer/pkg/libhttp/request"
|
"github.com/portainer/portainer/pkg/libhttp/request"
|
||||||
|
|
||||||
"github.com/asaskevich/govalidator"
|
"github.com/asaskevich/govalidator"
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
"github.com/rs/zerolog/log"
|
|
||||||
"github.com/segmentio/encoding/json"
|
"github.com/segmentio/encoding/json"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -79,14 +78,6 @@ func (handler *Handler) websocketExec(w http.ResponseWriter, r *http.Request) *h
|
||||||
}
|
}
|
||||||
|
|
||||||
func (handler *Handler) handleExecRequest(w http.ResponseWriter, r *http.Request, params *webSocketRequestParams) error {
|
func (handler *Handler) handleExecRequest(w http.ResponseWriter, r *http.Request, params *webSocketRequestParams) error {
|
||||||
tokenData, err := security.RetrieveTokenData(r)
|
|
||||||
if err != nil {
|
|
||||||
log.Warn().
|
|
||||||
Err(err).
|
|
||||||
Msg("unable to retrieve user details from authentication token")
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
r.Header.Del("Origin")
|
r.Header.Del("Origin")
|
||||||
|
|
||||||
if params.endpoint.Type == portainer.AgentOnDockerEnvironment {
|
if params.endpoint.Type == portainer.AgentOnDockerEnvironment {
|
||||||
|
@ -102,14 +93,13 @@ func (handler *Handler) handleExecRequest(w http.ResponseWriter, r *http.Request
|
||||||
|
|
||||||
defer websocketConn.Close()
|
defer websocketConn.Close()
|
||||||
|
|
||||||
return hijackExecStartOperation(websocketConn, params.endpoint, params.ID, tokenData.Token)
|
return hijackExecStartOperation(websocketConn, params.endpoint, params.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func hijackExecStartOperation(
|
func hijackExecStartOperation(
|
||||||
websocketConn *websocket.Conn,
|
websocketConn *websocket.Conn,
|
||||||
endpoint *portainer.Endpoint,
|
endpoint *portainer.Endpoint,
|
||||||
execID string,
|
execID string,
|
||||||
token string,
|
|
||||||
) error {
|
) error {
|
||||||
conn, err := initDial(endpoint)
|
conn, err := initDial(endpoint)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -121,7 +111,7 @@ func hijackExecStartOperation(
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return hijackRequest(websocketConn, conn, execStartRequest, token)
|
return ws.HijackRequest(websocketConn, conn, execStartRequest)
|
||||||
}
|
}
|
||||||
|
|
||||||
func createExecStartRequest(execID string) (*http.Request, error) {
|
func createExecStartRequest(execID string) (*http.Request, error) {
|
||||||
|
|
|
@ -9,6 +9,7 @@ import (
|
||||||
portainer "github.com/portainer/portainer/api"
|
portainer "github.com/portainer/portainer/api"
|
||||||
"github.com/portainer/portainer/api/http/proxy/factory/kubernetes"
|
"github.com/portainer/portainer/api/http/proxy/factory/kubernetes"
|
||||||
"github.com/portainer/portainer/api/http/security"
|
"github.com/portainer/portainer/api/http/security"
|
||||||
|
"github.com/portainer/portainer/api/ws"
|
||||||
httperror "github.com/portainer/portainer/pkg/libhttp/error"
|
httperror "github.com/portainer/portainer/pkg/libhttp/error"
|
||||||
"github.com/portainer/portainer/pkg/libhttp/request"
|
"github.com/portainer/portainer/pkg/libhttp/request"
|
||||||
|
|
||||||
|
@ -136,8 +137,8 @@ func (handler *Handler) hijackPodExecStartOperation(
|
||||||
|
|
||||||
// errorChan is used to propagate errors from the go routines to the caller.
|
// errorChan is used to propagate errors from the go routines to the caller.
|
||||||
errorChan := make(chan error, 1)
|
errorChan := make(chan error, 1)
|
||||||
go streamFromWebsocketToWriter(websocketConn, stdinWriter, errorChan)
|
go ws.StreamFromWebsocketToWriter(websocketConn, stdinWriter, errorChan)
|
||||||
go streamFromReaderToWebsocket(websocketConn, stdoutReader, errorChan)
|
go ws.StreamFromReaderToWebsocket(websocketConn, stdoutReader, errorChan)
|
||||||
|
|
||||||
// StartExecProcess is a blocking operation which streams IO to/from pod;
|
// StartExecProcess is a blocking operation which streams IO to/from pod;
|
||||||
// this must execute in asynchronously, since the websocketConn could return errors (e.g. client disconnects) before
|
// this must execute in asynchronously, since the websocketConn could return errors (e.g. client disconnects) before
|
||||||
|
|
|
@ -1,70 +0,0 @@
|
||||||
package websocket
|
|
||||||
|
|
||||||
import (
|
|
||||||
"io"
|
|
||||||
"unicode/utf8"
|
|
||||||
|
|
||||||
"github.com/gorilla/websocket"
|
|
||||||
)
|
|
||||||
|
|
||||||
const readerBufferSize = 2048
|
|
||||||
|
|
||||||
func streamFromWebsocketToWriter(websocketConn *websocket.Conn, writer io.Writer, errorChan chan error) {
|
|
||||||
for {
|
|
||||||
_, in, err := websocketConn.ReadMessage()
|
|
||||||
if err != nil {
|
|
||||||
errorChan <- err
|
|
||||||
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = writer.Write(in)
|
|
||||||
if err != nil {
|
|
||||||
errorChan <- err
|
|
||||||
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func streamFromReaderToWebsocket(websocketConn *websocket.Conn, reader io.Reader, errorChan chan error) {
|
|
||||||
out := make([]byte, readerBufferSize)
|
|
||||||
|
|
||||||
for {
|
|
||||||
n, err := reader.Read(out)
|
|
||||||
if err != nil {
|
|
||||||
errorChan <- err
|
|
||||||
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
processedOutput := validString(string(out[:n]))
|
|
||||||
err = websocketConn.WriteMessage(websocket.TextMessage, []byte(processedOutput))
|
|
||||||
if err != nil {
|
|
||||||
errorChan <- err
|
|
||||||
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func validString(s string) string {
|
|
||||||
if utf8.ValidString(s) {
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
v := make([]rune, 0, len(s))
|
|
||||||
|
|
||||||
for i, r := range s {
|
|
||||||
if r == utf8.RuneError {
|
|
||||||
_, size := utf8.DecodeRuneInString(s[i:])
|
|
||||||
if size == 1 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
v = append(v, r)
|
|
||||||
}
|
|
||||||
|
|
||||||
return string(v)
|
|
||||||
}
|
|
|
@ -1,4 +1,4 @@
|
||||||
package websocket
|
package ws
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
|
@ -16,18 +16,13 @@ import (
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// Time allowed to write a message to the peer
|
// Time allowed to write a message to the peer
|
||||||
writeWait = 10 * time.Second
|
WriteWait = 10 * time.Second
|
||||||
|
|
||||||
// Send pings to peer with this period
|
// Send pings to peer with this period
|
||||||
pingPeriod = 50 * time.Second
|
PingPeriod = 50 * time.Second
|
||||||
)
|
)
|
||||||
|
|
||||||
func hijackRequest(
|
func HijackRequest(websocketConn *websocket.Conn, conn net.Conn, request *http.Request) error {
|
||||||
websocketConn *websocket.Conn,
|
|
||||||
conn net.Conn,
|
|
||||||
request *http.Request,
|
|
||||||
token string,
|
|
||||||
) error {
|
|
||||||
resp, err := sendHTTPRequest(conn, request)
|
resp, err := sendHTTPRequest(conn, request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -39,17 +34,21 @@ func hijackRequest(
|
||||||
return fmt.Errorf("unexpected response status code: %d", resp.StatusCode)
|
return fmt.Errorf("unexpected response status code: %d", resp.StatusCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var mu sync.Mutex
|
||||||
|
|
||||||
errorChan := make(chan error, 1)
|
errorChan := make(chan error, 1)
|
||||||
go readWebSocketToTCP(websocketConn, conn, errorChan)
|
go StreamFromWebsocketToWriter(websocketConn, conn, errorChan)
|
||||||
go writeTCPToWebSocket(websocketConn, conn, errorChan)
|
go WriteReaderToWebSocket(websocketConn, &mu, conn, errorChan)
|
||||||
|
|
||||||
err = <-errorChan
|
err = <-errorChan
|
||||||
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseNoStatusReceived) {
|
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseNoStatusReceived) {
|
||||||
log.Debug().Msgf("Unexpected close error: %v\n", err)
|
log.Debug().Err(err).Msg("unexpected close error")
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debug().Msgf("session ended")
|
log.Info().Msg("session ended")
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -69,60 +68,40 @@ func sendHTTPRequest(conn net.Conn, req *http.Request) (*http.Response, error) {
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func readWebSocketToTCP(websocketConn *websocket.Conn, tcpConn net.Conn, errorChan chan error) {
|
func WriteReaderToWebSocket(websocketConn *websocket.Conn, mu *sync.Mutex, reader io.Reader, errorChan chan error) {
|
||||||
for {
|
out := make([]byte, ReaderBufferSize)
|
||||||
messageType, p, err := websocketConn.ReadMessage()
|
|
||||||
if err != nil {
|
|
||||||
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway) {
|
|
||||||
log.Debug().Msgf("Unexpected close error: %v\n", err)
|
|
||||||
}
|
|
||||||
errorChan <- err
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if messageType == websocket.TextMessage || messageType == websocket.BinaryMessage {
|
|
||||||
_, err := tcpConn.Write(p)
|
|
||||||
if err != nil {
|
|
||||||
log.Debug().Msgf("Error writing to TCP connection: %v\n", err)
|
|
||||||
errorChan <- err
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func writeTCPToWebSocket(websocketConn *websocket.Conn, tcpConn net.Conn, errorChan chan error) {
|
|
||||||
var mu sync.Mutex
|
|
||||||
out := make([]byte, readerBufferSize)
|
|
||||||
input := make(chan string)
|
input := make(chan string)
|
||||||
pingTicker := time.NewTicker(pingPeriod)
|
pingTicker := time.NewTicker(PingPeriod)
|
||||||
defer pingTicker.Stop()
|
defer pingTicker.Stop()
|
||||||
defer websocketConn.Close()
|
defer websocketConn.Close()
|
||||||
|
|
||||||
websocketConn.SetReadLimit(2048)
|
mu.Lock()
|
||||||
|
websocketConn.SetReadLimit(ReaderBufferSize)
|
||||||
websocketConn.SetPongHandler(func(string) error {
|
websocketConn.SetPongHandler(func(string) error {
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
|
||||||
websocketConn.SetPingHandler(func(data string) error {
|
websocketConn.SetPingHandler(func(data string) error {
|
||||||
websocketConn.SetWriteDeadline(time.Now().Add(writeWait))
|
websocketConn.SetWriteDeadline(time.Now().Add(WriteWait))
|
||||||
|
|
||||||
return websocketConn.WriteMessage(websocket.PongMessage, []byte(data))
|
return websocketConn.WriteMessage(websocket.PongMessage, []byte(data))
|
||||||
})
|
})
|
||||||
|
mu.Unlock()
|
||||||
reader := bufio.NewReader(tcpConn)
|
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
for {
|
for {
|
||||||
n, err := reader.Read(out)
|
n, err := reader.Read(out)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errorChan <- err
|
errorChan <- err
|
||||||
|
|
||||||
if !errors.Is(err, io.EOF) {
|
if !errors.Is(err, io.EOF) {
|
||||||
log.Debug().Msgf("error reading from server: %v", err)
|
log.Debug().Err(err).Msg("error reading from server")
|
||||||
}
|
}
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
processedOutput := validString(string(out[:n]))
|
processedOutput := ValidString(string(out[:n]))
|
||||||
input <- processedOutput
|
input <- processedOutput
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
@ -130,34 +109,37 @@ func writeTCPToWebSocket(websocketConn *websocket.Conn, tcpConn net.Conn, errorC
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case msg := <-input:
|
case msg := <-input:
|
||||||
err := wswrite(websocketConn, &mu, msg)
|
if err := wsWrite(websocketConn, mu, msg); err != nil {
|
||||||
if err != nil {
|
log.Debug().Err(err).Msg("error writing to websocket")
|
||||||
log.Debug().Msgf("error writing to websocket: %v", err)
|
|
||||||
errorChan <- err
|
errorChan <- err
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
case <-pingTicker.C:
|
case <-pingTicker.C:
|
||||||
if err := wsping(websocketConn, &mu); err != nil {
|
if err := wsPing(websocketConn, mu); err != nil {
|
||||||
log.Debug().Msgf("error writing to websocket during pong response: %v", err)
|
log.Debug().Err(err).Msg("error writing to websocket during pong response")
|
||||||
errorChan <- err
|
errorChan <- err
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func wswrite(websocketConn *websocket.Conn, mu *sync.Mutex, msg string) error {
|
func wsWrite(websocketConn *websocket.Conn, mu *sync.Mutex, msg string) error {
|
||||||
mu.Lock()
|
mu.Lock()
|
||||||
defer mu.Unlock()
|
defer mu.Unlock()
|
||||||
|
|
||||||
websocketConn.SetWriteDeadline(time.Now().Add(writeWait))
|
websocketConn.SetWriteDeadline(time.Now().Add(WriteWait))
|
||||||
|
|
||||||
return websocketConn.WriteMessage(websocket.TextMessage, []byte(msg))
|
return websocketConn.WriteMessage(websocket.TextMessage, []byte(msg))
|
||||||
}
|
}
|
||||||
|
|
||||||
func wsping(websocketConn *websocket.Conn, mu *sync.Mutex) error {
|
func wsPing(websocketConn *websocket.Conn, mu *sync.Mutex) error {
|
||||||
mu.Lock()
|
mu.Lock()
|
||||||
defer mu.Unlock()
|
defer mu.Unlock()
|
||||||
|
|
||||||
websocketConn.SetWriteDeadline(time.Now().Add(writeWait))
|
websocketConn.SetWriteDeadline(time.Now().Add(WriteWait))
|
||||||
|
|
||||||
return websocketConn.WriteMessage(websocket.PingMessage, nil)
|
return websocketConn.WriteMessage(websocket.PingMessage, nil)
|
||||||
}
|
}
|
|
@ -0,0 +1,77 @@
|
||||||
|
package ws
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"unicode/utf8"
|
||||||
|
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
const ReaderBufferSize = 2048
|
||||||
|
|
||||||
|
func StreamFromWebsocketToWriter(websocketConn *websocket.Conn, writer io.Writer, errorChan chan error) {
|
||||||
|
for {
|
||||||
|
messageType, in, err := websocketConn.ReadMessage()
|
||||||
|
if err != nil {
|
||||||
|
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway) {
|
||||||
|
log.Debug().Err(err).Msg("unexpected close error")
|
||||||
|
}
|
||||||
|
errorChan <- err
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if messageType != websocket.TextMessage && messageType != websocket.BinaryMessage {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := writer.Write(in); err != nil {
|
||||||
|
log.Debug().Err(err).Msg("writing error")
|
||||||
|
errorChan <- err
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func StreamFromReaderToWebsocket(websocketConn *websocket.Conn, reader io.Reader, errorChan chan error) {
|
||||||
|
out := make([]byte, ReaderBufferSize)
|
||||||
|
|
||||||
|
for {
|
||||||
|
n, err := reader.Read(out)
|
||||||
|
if err != nil {
|
||||||
|
errorChan <- err
|
||||||
|
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
processedOutput := ValidString(string(out[:n]))
|
||||||
|
if err := websocketConn.WriteMessage(websocket.TextMessage, []byte(processedOutput)); err != nil {
|
||||||
|
errorChan <- err
|
||||||
|
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func ValidString(s string) string {
|
||||||
|
if utf8.ValidString(s) {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
v := make([]rune, 0, len(s))
|
||||||
|
|
||||||
|
for i, r := range s {
|
||||||
|
if r == utf8.RuneError {
|
||||||
|
_, size := utf8.DecodeRuneInString(s[i:])
|
||||||
|
if size == 1 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
v = append(v, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
return string(v)
|
||||||
|
}
|
Loading…
Reference in New Issue