You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
portainer/api/ws/hijack.go

146 lines
3.4 KiB

package ws
import (
"bufio"
"errors"
"fmt"
"io"
"net"
"net/http"
"sync"
"time"
"github.com/gorilla/websocket"
"github.com/rs/zerolog/log"
)
const (
// Time allowed to write a message to the peer
WriteWait = 10 * time.Second
// Send pings to peer with this period
PingPeriod = 50 * time.Second
)
func HijackRequest(websocketConn *websocket.Conn, conn net.Conn, request *http.Request) error {
resp, err := sendHTTPRequest(conn, request)
if err != nil {
return err
}
defer resp.Body.Close()
// Check if the response status code indicates an upgrade (101 Switching Protocols)
if resp.StatusCode != http.StatusSwitchingProtocols {
return fmt.Errorf("unexpected response status code: %d", resp.StatusCode)
}
var mu sync.Mutex
errorChan := make(chan error, 1)
go StreamFromWebsocketToWriter(websocketConn, conn, errorChan)
go WriteReaderToWebSocket(websocketConn, &mu, conn, errorChan)
err = <-errorChan
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseNoStatusReceived) {
log.Debug().Err(err).Msg("unexpected close error")
return err
}
log.Info().Msg("session ended")
return nil
}
// sendHTTPRequest sends an HTTP request over the provided net.Conn and parses the response.
func sendHTTPRequest(conn net.Conn, req *http.Request) (*http.Response, error) {
// Send the HTTP request to the server
if err := req.Write(conn); err != nil {
return nil, fmt.Errorf("error writing request: %w", err)
}
// Read the response from the server
resp, err := http.ReadResponse(bufio.NewReader(conn), req)
if err != nil {
return nil, fmt.Errorf("error reading response: %w", err)
}
return resp, nil
}
func WriteReaderToWebSocket(websocketConn *websocket.Conn, mu *sync.Mutex, reader io.Reader, errorChan chan error) {
out := make([]byte, ReaderBufferSize)
input := make(chan string)
pingTicker := time.NewTicker(PingPeriod)
defer pingTicker.Stop()
defer websocketConn.Close()
mu.Lock()
websocketConn.SetReadLimit(ReaderBufferSize)
websocketConn.SetPongHandler(func(string) error {
return nil
})
websocketConn.SetPingHandler(func(data string) error {
websocketConn.SetWriteDeadline(time.Now().Add(WriteWait))
return websocketConn.WriteMessage(websocket.PongMessage, []byte(data))
})
mu.Unlock()
go func() {
for {
n, err := reader.Read(out)
if err != nil {
errorChan <- err
if !errors.Is(err, io.EOF) {
log.Debug().Err(err).Msg("error reading from server")
}
return
}
processedOutput := ValidString(string(out[:n]))
input <- processedOutput
}
}()
for {
select {
case msg := <-input:
if err := wsWrite(websocketConn, mu, msg); err != nil {
log.Debug().Err(err).Msg("error writing to websocket")
errorChan <- err
return
}
case <-pingTicker.C:
if err := wsPing(websocketConn, mu); err != nil {
log.Debug().Err(err).Msg("error writing to websocket during pong response")
errorChan <- err
return
}
}
}
}
func wsWrite(websocketConn *websocket.Conn, mu *sync.Mutex, msg string) error {
mu.Lock()
defer mu.Unlock()
websocketConn.SetWriteDeadline(time.Now().Add(WriteWait))
return websocketConn.WriteMessage(websocket.TextMessage, []byte(msg))
}
func wsPing(websocketConn *websocket.Conn, mu *sync.Mutex) error {
mu.Lock()
defer mu.Unlock()
websocketConn.SetWriteDeadline(time.Now().Add(WriteWait))
return websocketConn.WriteMessage(websocket.PingMessage, nil)
}