package handler

import (
	"bytes"
	"crypto/tls"
	"encoding/json"
	"fmt"
	"io"
	"log"
	"net"
	"net/http"
	"net/http/httputil"
	"net/url"
	"os"
	"strconv"
	"time"

	"github.com/gorilla/mux"
	"github.com/portainer/portainer"
	"github.com/portainer/portainer/crypto"
	"golang.org/x/net/websocket"
)

// WebSocketHandler represents an HTTP API handler for proxying requests to a web socket.
//
// It embeds a gorilla/mux Router, on which NewWebSocketHandler registers the
// "/websocket/exec" route that relays Docker exec sessions. Logger receives the
// handler's diagnostic output (NewWebSocketHandler points it at stderr), and
// EndpointService is used to look up the endpoint named by each websocket
// request; it is not set by NewWebSocketHandler, so the caller must assign it.
type WebSocketHandler struct {
	*mux.Router
	Logger          *log.Logger
	EndpointService portainer.EndpointService
}

// NewWebSocketHandler builds a WebSocketHandler backed by a fresh mux router and
// a stderr logger, and wires the Docker exec websocket endpoint
// ("/websocket/exec") onto that router. EndpointService is left unset for the
// caller to provide.
func NewWebSocketHandler() *WebSocketHandler {
	router := mux.NewRouter()
	logger := log.New(os.Stderr, "", log.LstdFlags)

	handler := &WebSocketHandler{Router: router, Logger: logger}
	execHandler := websocket.Handler(handler.webSocketDockerExec)
	handler.Handle("/websocket/exec", execHandler)

	return handler
}

// webSocketDockerExec relays a Docker exec session between the client's
// websocket and the Docker endpoint it targets.
//
// The websocket request's query string must carry "id" (the exec instance ID)
// and "endpointId" (the numeric ID of a registered endpoint). The endpoint URL
// must use the "tcp" or "unix" scheme. Every failure is logged through
// handler.Logger and ends only this session by returning; nothing here exits
// the process.
func (handler *WebSocketHandler) webSocketDockerExec(ws *websocket.Conn) {
	query := ws.Request().URL.Query()

	execID := query.Get("id")
	if execID == "" {
		handler.Logger.Printf("Missing exec ID in websocket request")
		return
	}

	parsedID, err := strconv.Atoi(query.Get("endpointId"))
	if err != nil {
		handler.Logger.Printf("Unable to parse endpoint ID: %s", err)
		return
	}

	endpoint, err := handler.EndpointService.Endpoint(portainer.EndpointID(parsedID))
	if err != nil {
		handler.Logger.Printf("Unable to retrieve endpoint: %s", err)
		return
	}

	endpointURL, err := url.Parse(endpoint.URL)
	if err != nil {
		handler.Logger.Printf("Unable to parse endpoint URL: %s", err)
		return
	}

	// Dial target: host:port for TCP endpoints, socket path for unix ones.
	// Any other scheme used to fall through with an empty host and fail later.
	var host string
	switch endpointURL.Scheme {
	case "tcp":
		host = endpointURL.Host
	case "unix":
		host = endpointURL.Path
	default:
		handler.Logger.Printf("Unsupported endpoint URL scheme: %q", endpointURL.Scheme)
		return
	}

	// TODO: Should not be managed here
	var tlsConfig *tls.Config
	if endpoint.TLSConfig.TLS {
		tlsConfig, err = crypto.CreateTLSConfiguration(&endpoint.TLSConfig)
		if err != nil {
			// Never log.Fatal inside a request handler: one bad endpoint would
			// take the whole server down. Drop this session only.
			handler.Logger.Printf("Unable to create TLS configuration: %s", err)
			return
		}
	}

	// The exec ID is client-supplied; escape it so it stays a single path
	// segment of the Docker API request.
	execPath := "/exec/" + url.PathEscape(execID) + "/start"
	if err := hijack(host, endpointURL.Scheme, "POST", execPath, tlsConfig, true, ws, ws, ws, nil, nil); err != nil {
		handler.Logger.Printf("error during hijack: %s", err)
	}
}

// execConfig is the JSON request body sent when starting an exec instance.
// The explicit tags pin the wire names to the current field names, so renaming
// a field cannot silently change the payload.
type execConfig struct {
	Tty    bool `json:"Tty"`
	Detach bool `json:"Detach"`
}

// hijack upgrades an HTTP connection to a raw TCP stream and relays I/O over it.
//
// It sends a request (method, path) carrying an attached, TTY-enabled
// execConfig body to the daemon at addr over scheme ("tcp" or "unix"), wrapped
// in TLS when tlsConfig is non-nil, and then takes over the underlying
// connection:
//   - everything read from in (when non-nil) is copied to the daemon; once in is
//     exhausted, the write side is half-closed if the connection supports it;
//   - when setRawTerminal is true and stdout is non-nil, everything the daemon
//     sends is copied to stdout as-is. stderr is never written to.
//
// If started is non-nil, the hijacked connection is sent on it before streaming
// begins. If stdout or stderr is non-nil, hijack blocks until the daemon's
// output stream ends and returns the copy error, if any; otherwise it returns
// immediately. data is currently unused.
//
// It returns an error if the request cannot be built, the dial fails, the
// exchange fails with anything other than httputil.ErrPersistEOF, or the daemon
// answers with an HTTP status >= 400.
func hijack(addr, scheme, method, path string, tlsConfig *tls.Config, setRawTerminal bool, in io.ReadCloser, stdout, stderr io.Writer, started chan io.Closer, data interface{}) error {
	body, err := json.Marshal(&execConfig{Tty: true, Detach: false})
	if err != nil {
		return fmt.Errorf("error marshaling exec config: %s", err)
	}

	req, err := http.NewRequest(method, path, bytes.NewReader(body))
	if err != nil {
		return fmt.Errorf("error during hijack request: %s", err)
	}
	req.Header.Set("User-Agent", "Docker-Client")
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Connection", "Upgrade")
	req.Header.Set("Upgrade", "tcp")
	req.Host = addr

	// A hijacked connection can stay idle for a long time (a long-running
	// command with no output), which in some network setups leads to
	// ECONNTIMEOUT and leaves the client in an unknown state. TCP keep-alive
	// prevents that unless the connection is truly broken. Setting it on the
	// dialer also covers TLS connections, which a *net.TCPConn type assertion
	// on the returned conn would miss (it is a *tls.Conn then). The setting is
	// ignored for unix sockets.
	dialer := &net.Dialer{KeepAlive: 30 * time.Second}
	var conn net.Conn
	if tlsConfig == nil {
		conn, err = dialer.Dial(scheme, addr)
	} else {
		conn, err = tls.DialWithDialer(dialer, scheme, addr, tlsConfig)
	}
	if err != nil {
		return err
	}

	clientconn := httputil.NewClientConn(conn, nil)
	defer clientconn.Close()

	// ErrPersistEOF ("persistent connection closed") only means the connection
	// will not be reused for further requests, which is expected since it is
	// about to be hijacked. Any other error, or an error status, is a failure.
	resp, err := clientconn.Do(req)
	if err != nil && err != httputil.ErrPersistEOF {
		return fmt.Errorf("error during hijack request: %s", err)
	}
	if resp != nil && resp.StatusCode >= http.StatusBadRequest {
		return fmt.Errorf("error during hijack request: unexpected status %d", resp.StatusCode)
	}

	rwc, br := clientconn.Hijack()
	defer rwc.Close()

	if started != nil {
		started <- rwc
	}

	// Buffered so the output goroutine can always deliver its result and exit,
	// even if nothing ends up receiving it.
	receiveStdout := make(chan error, 1)
	if stdout != nil || stderr != nil {
		go func() {
			var copyErr error
			if setRawTerminal && stdout != nil {
				_, copyErr = io.Copy(stdout, br)
			}
			receiveStdout <- copyErr
		}()
	}

	go func() {
		if in != nil {
			// Input errors are not reported: the session's lifetime is driven
			// by the output stream, not by stdin.
			io.Copy(rwc, in)
		}
		// Signal end of input to the daemon without closing the read side.
		if cw, ok := rwc.(interface {
			CloseWrite() error
		}); ok {
			// Best effort: the connection may already have been closed.
			_ = cw.CloseWrite()
		}
	}()

	if stdout != nil || stderr != nil {
		if err := <-receiveStdout; err != nil {
			return err
		}
	}
	return nil
}