package ws

import (
	"bufio"
	"errors"
	"fmt"
	"io"
	"net"
	"net/http"
	"sync"
	"time"

	"github.com/gorilla/websocket"
	"github.com/rs/zerolog/log"
)

const (
	// Time allowed to write a message to the peer
	WriteWait = 10 * time.Second

	// Send pings to peer with this period
	PingPeriod = 50 * time.Second
)

func HijackRequest(websocketConn *websocket.Conn, conn net.Conn, request *http.Request) error {
	resp, err := sendHTTPRequest(conn, request)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	// Check if the response status code indicates an upgrade (101 Switching Protocols)
	if resp.StatusCode != http.StatusSwitchingProtocols {
		return fmt.Errorf("unexpected response status code: %d", resp.StatusCode)
	}

	var mu sync.Mutex

	errorChan := make(chan error, 1)
	go StreamFromWebsocketToWriter(websocketConn, conn, errorChan)
	go WriteReaderToWebSocket(websocketConn, &mu, conn, errorChan)

	err = <-errorChan
	if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseNoStatusReceived) {
		log.Debug().Err(err).Msg("unexpected close error")

		return err
	}

	log.Info().Msg("session ended")

	return nil
}

// sendHTTPRequest sends an HTTP request over the provided net.Conn and parses the response.
func sendHTTPRequest(conn net.Conn, req *http.Request) (*http.Response, error) {
	// Send the HTTP request to the server
	if err := req.Write(conn); err != nil {
		return nil, fmt.Errorf("error writing request: %w", err)
	}

	// Read the response from the server
	resp, err := http.ReadResponse(bufio.NewReader(conn), req)
	if err != nil {
		return nil, fmt.Errorf("error reading response: %w", err)
	}

	return resp, nil
}

func WriteReaderToWebSocket(websocketConn *websocket.Conn, mu *sync.Mutex, reader io.Reader, errorChan chan error) {
	out := make([]byte, ReaderBufferSize)
	input := make(chan string)
	pingTicker := time.NewTicker(PingPeriod)
	defer pingTicker.Stop()
	defer websocketConn.Close()

	mu.Lock()
	websocketConn.SetReadLimit(ReaderBufferSize)
	websocketConn.SetPongHandler(func(string) error {
		return nil
	})

	websocketConn.SetPingHandler(func(data string) error {
		websocketConn.SetWriteDeadline(time.Now().Add(WriteWait))

		return websocketConn.WriteMessage(websocket.PongMessage, []byte(data))
	})
	mu.Unlock()

	go func() {
		for {
			n, err := reader.Read(out)
			if err != nil {
				errorChan <- err

				if !errors.Is(err, io.EOF) {
					log.Debug().Err(err).Msg("error reading from server")
				}

				return
			}

			processedOutput := ValidString(string(out[:n]))
			input <- processedOutput
		}
	}()

	for {
		select {
		case msg := <-input:
			if err := wsWrite(websocketConn, mu, msg); err != nil {
				log.Debug().Err(err).Msg("error writing to websocket")
				errorChan <- err

				return
			}
		case <-pingTicker.C:
			if err := wsPing(websocketConn, mu); err != nil {
				log.Debug().Err(err).Msg("error writing to websocket during pong response")
				errorChan <- err

				return
			}
		}
	}
}

func wsWrite(websocketConn *websocket.Conn, mu *sync.Mutex, msg string) error {
	mu.Lock()
	defer mu.Unlock()

	websocketConn.SetWriteDeadline(time.Now().Add(WriteWait))

	return websocketConn.WriteMessage(websocket.TextMessage, []byte(msg))
}

func wsPing(websocketConn *websocket.Conn, mu *sync.Mutex) error {
	mu.Lock()
	defer mu.Unlock()

	websocketConn.SetWriteDeadline(time.Now().Add(WriteWait))

	return websocketConn.WriteMessage(websocket.PingMessage, nil)
}