// Mirror of https://github.com/portainer/portainer (Go source, 199 lines, 4.4 KiB).

package handler
|
|
|
|
import (
|
|
"bytes"
|
|
"crypto/tls"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"net"
|
|
"net/http"
|
|
"net/http/httputil"
|
|
"net/url"
|
|
"os"
|
|
"strconv"
|
|
"time"
|
|
|
|
"github.com/gorilla/mux"
|
|
"github.com/portainer/portainer"
|
|
"github.com/portainer/portainer/crypto"
|
|
"golang.org/x/net/websocket"
|
|
)
|
|
|
|
// WebSocketHandler represents an HTTP API handler for proxying requests to a web socket.
// Routes are registered on the embedded mux.Router (see NewWebSocketHandler).
type WebSocketHandler struct {
	*mux.Router

	// Logger is the handler's logger; NewWebSocketHandler points it at stderr.
	Logger *log.Logger

	// EndpointService resolves the endpoint targeted by a request. It is not
	// set by NewWebSocketHandler and must be assigned before requests are served.
	EndpointService portainer.EndpointService
}
|
|
|
|
// NewWebSocketHandler returns a new instance of WebSocketHandler.
|
|
func NewWebSocketHandler() *WebSocketHandler {
|
|
h := &WebSocketHandler{
|
|
Router: mux.NewRouter(),
|
|
Logger: log.New(os.Stderr, "", log.LstdFlags),
|
|
}
|
|
h.Handle("/websocket/exec", websocket.Handler(h.webSocketDockerExec))
|
|
return h
|
|
}
|
|
|
|
func (handler *WebSocketHandler) webSocketDockerExec(ws *websocket.Conn) {
|
|
qry := ws.Request().URL.Query()
|
|
execID := qry.Get("id")
|
|
edpID := qry.Get("endpointId")
|
|
|
|
parsedID, err := strconv.Atoi(edpID)
|
|
if err != nil {
|
|
log.Printf("Unable to parse endpoint ID: %s", err)
|
|
return
|
|
}
|
|
|
|
endpointID := portainer.EndpointID(parsedID)
|
|
endpoint, err := handler.EndpointService.Endpoint(endpointID)
|
|
if err != nil {
|
|
log.Printf("Unable to retrieve endpoint: %s", err)
|
|
return
|
|
}
|
|
|
|
endpointURL, err := url.Parse(endpoint.URL)
|
|
if err != nil {
|
|
log.Printf("Unable to parse endpoint URL: %s", err)
|
|
return
|
|
}
|
|
|
|
var host string
|
|
if endpointURL.Scheme == "tcp" {
|
|
host = endpointURL.Host
|
|
} else if endpointURL.Scheme == "unix" {
|
|
host = endpointURL.Path
|
|
}
|
|
|
|
// Should not be managed here
|
|
var tlsConfig *tls.Config
|
|
if endpoint.TLS {
|
|
tlsConfig, err = crypto.CreateTLSConfiguration(endpoint.TLSCACertPath,
|
|
endpoint.TLSCertPath,
|
|
endpoint.TLSKeyPath)
|
|
if err != nil {
|
|
log.Fatalf("Unable to create TLS configuration: %s", err)
|
|
return
|
|
}
|
|
}
|
|
|
|
if err := hijack(host, endpointURL.Scheme, "POST", "/exec/"+execID+"/start", tlsConfig, true, ws, ws, ws, nil, nil); err != nil {
|
|
log.Fatalf("error during hijack: %s", err)
|
|
return
|
|
}
|
|
}
|
|
|
|
// execConfig is the JSON body hijack sends when starting an exec instance.
// It has no json tags, so the fields are encoded under their Go names,
// "Tty" and "Detach", in that order.
type execConfig struct {
	Tty, Detach bool
}
|
|
|
|
// hijack allows to upgrade an HTTP connection to a TCP connection
|
|
// It redirects IO streams for stdin, stdout and stderr to a websocket
|
|
func hijack(addr, scheme, method, path string, tlsConfig *tls.Config, setRawTerminal bool, in io.ReadCloser, stdout, stderr io.Writer, started chan io.Closer, data interface{}) error {
|
|
execConfig := &execConfig{
|
|
Tty: true,
|
|
Detach: false,
|
|
}
|
|
|
|
buf, err := json.Marshal(execConfig)
|
|
if err != nil {
|
|
return fmt.Errorf("error marshaling exec config: %s", err)
|
|
}
|
|
|
|
rdr := bytes.NewReader(buf)
|
|
|
|
req, err := http.NewRequest(method, path, rdr)
|
|
if err != nil {
|
|
return fmt.Errorf("error during hijack request: %s", err)
|
|
}
|
|
|
|
req.Header.Set("User-Agent", "Docker-Client")
|
|
req.Header.Set("Content-Type", "application/json")
|
|
req.Header.Set("Connection", "Upgrade")
|
|
req.Header.Set("Upgrade", "tcp")
|
|
req.Host = addr
|
|
|
|
var (
|
|
dial net.Conn
|
|
dialErr error
|
|
)
|
|
|
|
if tlsConfig == nil {
|
|
dial, dialErr = net.Dial(scheme, addr)
|
|
} else {
|
|
dial, dialErr = tls.Dial(scheme, addr, tlsConfig)
|
|
}
|
|
|
|
if dialErr != nil {
|
|
return dialErr
|
|
}
|
|
|
|
// When we set up a TCP connection for hijack, there could be long periods
|
|
// of inactivity (a long running command with no output) that in certain
|
|
// network setups may cause ECONNTIMEOUT, leaving the client in an unknown
|
|
// state. Setting TCP KeepAlive on the socket connection will prohibit
|
|
// ECONNTIMEOUT unless the socket connection truly is broken
|
|
if tcpConn, ok := dial.(*net.TCPConn); ok {
|
|
tcpConn.SetKeepAlive(true)
|
|
tcpConn.SetKeepAlivePeriod(30 * time.Second)
|
|
}
|
|
if err != nil {
|
|
return err
|
|
}
|
|
clientconn := httputil.NewClientConn(dial, nil)
|
|
defer clientconn.Close()
|
|
|
|
// Server hijacks the connection, error 'connection closed' expected
|
|
clientconn.Do(req)
|
|
|
|
rwc, br := clientconn.Hijack()
|
|
defer rwc.Close()
|
|
|
|
if started != nil {
|
|
started <- rwc
|
|
}
|
|
|
|
var receiveStdout chan error
|
|
|
|
if stdout != nil || stderr != nil {
|
|
go func() (err error) {
|
|
if setRawTerminal && stdout != nil {
|
|
_, err = io.Copy(stdout, br)
|
|
}
|
|
return err
|
|
}()
|
|
}
|
|
|
|
go func() error {
|
|
if in != nil {
|
|
io.Copy(rwc, in)
|
|
}
|
|
|
|
if conn, ok := rwc.(interface {
|
|
CloseWrite() error
|
|
}); ok {
|
|
if err := conn.CloseWrite(); err != nil {
|
|
}
|
|
}
|
|
return nil
|
|
}()
|
|
|
|
if stdout != nil || stderr != nil {
|
|
if err := <-receiveStdout; err != nil {
|
|
return err
|
|
}
|
|
}
|
|
go func() {
|
|
for {
|
|
fmt.Println(br)
|
|
}
|
|
}()
|
|
|
|
return nil
|
|
}
|