// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1

package connect

import (
	"crypto/tls"
	"crypto/x509"
	"fmt"
	"io"
	"log"
	"net"
	"net/http"
	"sync/atomic"

	"github.com/hashicorp/go-hclog"
	testing "github.com/mitchellh/go-testing-interface"

	"github.com/hashicorp/consul/agent/connect"
	"github.com/hashicorp/consul/agent/structs"
	"github.com/hashicorp/consul/sdk/freeport"
)

// TestService returns a Service instance based on a static TLS Config.
func TestService(t testing.T, service string, ca *structs.CARoot) *Service {
	t.Helper()

	// Don't need to talk to client since we are setting TLSConfig locally
	logger := hclog.New(&hclog.LoggerOptions{})
	svc, err := NewDevServiceWithTLSConfig(service,
		logger, TestTLSConfig(t, service, ca))
	if err != nil {
		t.Fatal(err)
	}
	return svc
}

// TestTLSConfig returns a *tls.Config suitable for use during tests.
func TestTLSConfig(t testing.T, service string, ca *structs.CARoot) *tls.Config {
	t.Helper()

	cfg := defaultTLSConfig()
	cfg.Certificates = []tls.Certificate{TestSvcKeyPair(t, service, ca)}
	cfg.RootCAs = TestCAPool(t, ca)
	cfg.ClientCAs = TestCAPool(t, ca)
	return cfg
}

// TestCAPool returns an *x509.CertPool containing the passed CA's root(s)
func TestCAPool(t testing.T, cas ...*structs.CARoot) *x509.CertPool {
	t.Helper()
	pool := x509.NewCertPool()
	for _, ca := range cas {
		pool.AppendCertsFromPEM([]byte(ca.RootCert))
	}
	return pool
}

// TestSvcKeyPair returns an tls.Certificate containing both cert and private
// key for a given service under a given CA from the testdata dir.
func TestSvcKeyPair(t testing.T, service string, ca *structs.CARoot) tls.Certificate {
	t.Helper()
	certPEM, keyPEM := connect.TestLeaf(t, service, ca)
	cert, err := tls.X509KeyPair([]byte(certPEM), []byte(keyPEM))
	if err != nil {
		t.Fatal(err)
	}
	return cert
}

// TestPeerCertificates returns a []*x509.Certificate as you'd get from
// tls.Conn.ConnectionState().PeerCertificates including the named certificate.
func TestPeerCertificates(t testing.T, service string, ca *structs.CARoot) []*x509.Certificate {
	t.Helper()
	certPEM, _ := connect.TestLeaf(t, service, ca)
	cert, err := connect.ParseCert(certPEM)
	if err != nil {
		t.Fatal(err)
	}
	return []*x509.Certificate{cert}
}

// TestServer runs a service listener that can be used to test clients. It's
// behavior can be controlled by the struct members.
type TestServer struct {
	// The service name to serve.
	Service string
	// The (test) CA to use for generating certs.
	CA *structs.CARoot
	// TimeoutHandshake controls whether the listening server will complete a TLS
	// handshake quickly enough.
	TimeoutHandshake bool
	// TLSCfg is the tls.Config that will be used. By default it's set up from the
	// service and ca set.
	TLSCfg *tls.Config
	// Addr is the listen address. It is set to a random free port on `localhost`
	// by default.
	Addr string
	// Listening is closed when the listener is run.
	Listening chan struct{}

	l        net.Listener
	stopFlag int32
	stopChan chan struct{}
}

// NewTestServer returns a TestServer. It should be closed when test is
// complete.
func NewTestServer(t testing.T, service string, ca *structs.CARoot) *TestServer {
	return &TestServer{
		Service:   service,
		CA:        ca,
		stopChan:  make(chan struct{}),
		TLSCfg:    TestTLSConfig(t, service, ca),
		Addr:      fmt.Sprintf("127.0.0.1:%d", freeport.GetOne(t)),
		Listening: make(chan struct{}),
	}
}

// Serve runs a tcp echo server and blocks until it is closed or errors. If
// TimeoutHandshake is set it won't start TLS handshake on new connections.
func (s *TestServer) Serve() error {
	// Just accept TCP conn but so we can control timing of accept/handshake
	l, err := net.Listen("tcp", s.Addr)
	if err != nil {
		return err
	}
	s.l = l
	close(s.Listening)
	log.Printf("test connect service listening on %s", s.Addr)

	for {
		conn, err := s.l.Accept()
		if err != nil {
			if atomic.LoadInt32(&s.stopFlag) == 1 {
				return nil
			}
			return err
		}

		// Ignore the conn if we are not actively handshaking
		if !s.TimeoutHandshake {
			// Upgrade conn to TLS
			conn = tls.Server(conn, s.TLSCfg)

			// Run an echo service
			log.Printf("test connect service accepted conn from %s, "+
				" running echo service", conn.RemoteAddr())
			go io.Copy(conn, conn)
		}

		// Close this conn when we stop
		go func(c net.Conn) {
			<-s.stopChan
			c.Close()
		}(conn)
	}
}

// ServeHTTPS runs an HTTPS server with the given config. It invokes the passed
// Handler for all requests.
func (s *TestServer) ServeHTTPS(h http.Handler) error {
	srv := http.Server{
		Addr:      s.Addr,
		TLSConfig: s.TLSCfg,
		Handler:   h,
	}
	log.Printf("starting test connect HTTPS server on %s", s.Addr)

	// Use our own listener so we can signal when it's ready.
	l, err := net.Listen("tcp", s.Addr)
	if err != nil {
		return err
	}
	close(s.Listening)
	s.l = l
	log.Printf("test connect service listening on %s", s.Addr)

	err = srv.ServeTLS(l, "", "")
	if atomic.LoadInt32(&s.stopFlag) == 1 {
		return nil
	}
	return err
}

// Close stops a TestServer
func (s *TestServer) Close() error {
	old := atomic.SwapInt32(&s.stopFlag, 1)
	if old == 0 {
		if s.l != nil {
			s.l.Close()
		}
		close(s.stopChan)
	}
	return nil
}