package snapshot

import (
	"bytes"
	"crypto/rand"
	"fmt"
	"io"
	"os"
	"path/filepath"
	"strings"
	"sync"
	"testing"
	"time"

	"github.com/hashicorp/consul/agent/structs"
	"github.com/hashicorp/consul/sdk/testutil"
	"github.com/hashicorp/go-msgpack/codec"
	"github.com/hashicorp/raft"
	"github.com/stretchr/testify/require"
)

// MockFSM is a simple FSM for testing that simply stores its logs in a slice of
// byte slices.
//
// The embedded mutex guards logs: Apply, Snapshot and Restore all take it, and
// the tests in this file lock it directly before inspecting logs. logs holds
// the data of each applied log in the order it was applied.
type MockFSM struct {
	sync.Mutex
	logs [][]byte
}

// MockSnapshot is an FSM snapshot for testing (the value MockFSM.Snapshot
// returns) that encodes the contents of a MockFSM using msgpack.
//
// logs is the log slice captured from the MockFSM when the snapshot was taken
// and maxIndex is its length at that moment; Persist writes only
// logs[:maxIndex]. Field order matters because MockFSM.Snapshot builds this
// struct with a positional literal.
type MockSnapshot struct {
	logs     [][]byte
	maxIndex int
}

// Apply implements raft.FSM. It records the log's data at the end of the
// stored logs and returns how many logs are held after the append.
func (m *MockFSM) Apply(log *raft.Log) interface{} {
	m.Lock()
	defer m.Unlock()

	m.logs = append(m.logs, log.Data)
	applied := len(m.logs)
	return applied
}

// Snapshot implements raft.FSM. It captures the current log slice together
// with its length in a MockSnapshot; the error result is always nil.
func (m *MockFSM) Snapshot() (raft.FSMSnapshot, error) {
	m.Lock()
	defer m.Unlock()

	captured := &MockSnapshot{
		logs:     m.logs,
		maxIndex: len(m.logs),
	}
	return captured, nil
}

// Restore implements raft.FSM. It discards the stored logs and replaces them
// with the msgpack-encoded logs read from in, returning any decode error. The
// reader is closed before returning.
func (m *MockFSM) Restore(in io.ReadCloser) error {
	m.Lock()
	defer m.Unlock()
	defer in.Close()

	// Start from a clean slate so nothing from before the restore survives.
	m.logs = nil

	decoder := codec.NewDecoder(in, structs.MsgpackHandle)
	err := decoder.Decode(&m.logs)
	return err
}

// Persist implements raft.FSMSnapshot. It msgpack-encodes the first maxIndex
// captured logs into sink.
//
// If encoding fails the sink is cancelled and the encoding error is returned.
// Otherwise the sink is closed and the result of closing it is returned, so a
// failure to finalize the snapshot is not reported as success.
func (m *MockSnapshot) Persist(sink raft.SnapshotSink) error {
	enc := codec.NewEncoder(sink, structs.MsgpackHandle)
	if err := enc.Encode(m.logs[:m.maxIndex]); err != nil {
		// Best-effort cleanup: the encode error is the one worth reporting,
		// so a secondary failure from Cancel is deliberately ignored.
		_ = sink.Cancel()
		return err
	}
	return sink.Close()
}

// Release implements raft.FSMSnapshot. This mock does nothing on release.
func (m *MockSnapshot) Release() {
	// Intentionally a no-op.
}

// makeRaft returns a Raft and its FSM, with snapshots based in the given dir.
//
// The node is bootstrapped as the only voter of its cluster, using an
// in-memory store for logs and stable state and an in-memory transport. The
// function blocks until the node reports a leader and fails the test if none
// appears within 10 seconds. Callers are expected to shut the returned Raft
// down when they are done with it.
func makeRaft(t *testing.T, dir string) (*raft.Raft, *MockFSM) {
	// Attribute failures to the calling test rather than to this helper.
	t.Helper()

	snaps, err := raft.NewFileSnapshotStore(dir, 5, nil)
	if err != nil {
		t.Fatalf("err: %v", err)
	}

	fsm := &MockFSM{}
	store := raft.NewInmemStore()
	addr, trans := raft.NewInmemTransport("")

	config := raft.DefaultConfig()
	config.LocalID = raft.ServerID(fmt.Sprintf("server-%s", addr))

	var members raft.Configuration
	members.Servers = append(members.Servers, raft.Server{
		Suffrage: raft.Voter,
		ID:       config.LocalID,
		Address:  addr,
	})

	err = raft.BootstrapCluster(config, store, store, snaps, trans, members)
	if err != nil {
		t.Fatalf("err: %v", err)
	}

	// Named node (not raft) so the imported raft package is not shadowed for
	// the remainder of the function.
	node, err := raft.NewRaft(config, fsm, store, store, snaps, trans)
	if err != nil {
		t.Fatalf("err: %v", err)
	}

	timeout := time.After(10 * time.Second)
	for node.Leader() == "" {
		select {
		case <-node.LeaderCh():
		case <-time.After(1 * time.Second):
			// Need to poll because we might have missed the first
			// go with the leader channel.
		case <-timeout:
			// The caller never receives the node on this path, so stop it
			// here instead of leaking it.
			node.Shutdown()
			t.Fatalf("timed out waiting for leader")
		}
	}

	return node, fsm
}

// TestSnapshot takes a snapshot of a populated Raft, verifies its metadata,
// restores it into a second, independent Raft that already holds unrelated
// data, and checks that the second FSM ends up with exactly the logs applied
// to the first.
func TestSnapshot(t *testing.T) {
	dir := testutil.TempDir(t, "snapshot")
	defer os.RemoveAll(dir)

	// Make a Raft and populate it with some data. We tee everything we
	// apply off to a buffer for checking post-snapshot.
	var expected []bytes.Buffer
	entries := 64 * 1024
	before, _ := makeRaft(t, filepath.Join(dir, "before"))
	defer before.Shutdown()
	for i := 0; i < entries; i++ {
		var entry bytes.Buffer
		var dup bytes.Buffer
		both := io.MultiWriter(&entry, &dup)
		if _, err := io.CopyN(both, rand.Reader, 256); err != nil {
			t.Fatalf("err: %v", err)
		}
		future := before.Apply(entry.Bytes(), time.Second)
		if err := future.Error(); err != nil {
			t.Fatalf("err: %v", err)
		}
		expected = append(expected, dup)
	}

	// Take a snapshot.
	logger := testutil.Logger(t)
	snap, err := New(logger, before)
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	defer snap.Close()

	// Verify the snapshot. We have to rewind it after for the restore.
	metadata, err := Verify(snap)
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	if _, err := snap.file.Seek(0, io.SeekStart); err != nil {
		t.Fatalf("err: %v", err)
	}
	// NOTE(review): the +2 presumably accounts for entries Raft writes on its
	// own (bootstrap configuration and the new leader's first entry) — confirm.
	if int(metadata.Index) != entries+2 {
		t.Fatalf("bad: %d", metadata.Index)
	}
	if metadata.Term != 2 {
		// Report the value that was actually checked.
		t.Fatalf("bad: %d", metadata.Term)
	}
	if metadata.Version != raft.SnapshotVersionMax {
		t.Fatalf("bad: %d", metadata.Version)
	}

	// Make a new, independent Raft.
	after, fsm := makeRaft(t, filepath.Join(dir, "after"))
	defer after.Shutdown()

	// Put some initial data in there that the snapshot should overwrite.
	for i := 0; i < 16; i++ {
		var entry bytes.Buffer
		if _, err := io.CopyN(&entry, rand.Reader, 256); err != nil {
			t.Fatalf("err: %v", err)
		}
		future := after.Apply(entry.Bytes(), time.Second)
		if err := future.Error(); err != nil {
			t.Fatalf("err: %v", err)
		}
	}

	// Restore the snapshot.
	if err := Restore(logger, snap, after); err != nil {
		t.Fatalf("err: %v", err)
	}

	// Compare the contents.
	fsm.Lock()
	defer fsm.Unlock()
	if len(fsm.logs) != len(expected) {
		t.Fatalf("bad: %d vs. %d", len(fsm.logs), len(expected))
	}
	for i := range fsm.logs {
		if !bytes.Equal(fsm.logs[i], expected[i].Bytes()) {
			t.Fatalf("bad: log %d doesn't match", i)
		}
	}
}

// TestSnapshot_Nil checks that the methods of a nil *Snapshot are safe to
// call: Index reports 0, Read yields no bytes and io.EOF, and Close succeeds.
func TestSnapshot_Nil(t *testing.T) {
	var snap *Snapshot

	idx := snap.Index()
	if idx != 0 {
		t.Fatalf("bad: %d", idx)
	}

	scratch := make([]byte, 16)
	n, err := snap.Read(scratch)
	if !(n == 0 && err == io.EOF) {
		t.Fatalf("bad: %d %v", n, err)
	}

	err = snap.Close()
	if err != nil {
		t.Fatalf("err: %v", err)
	}
}

// TestSnapshot_BadVerify feeds Verify four bytes that are not a snapshot and
// expects it to fail with an error mentioning "unexpected EOF".
func TestSnapshot_BadVerify(t *testing.T) {
	garbage := bytes.NewBufferString("nope")

	_, err := Verify(garbage)
	if err == nil {
		t.Fatalf("err: %v", err)
	}
	if !strings.Contains(err.Error(), "unexpected EOF") {
		t.Fatalf("err: %v", err)
	}
}

// TestSnapshot_TruncatedVerify checks that Verify rejects a snapshot whose
// tail has been cut off, for several truncation sizes.
func TestSnapshot_TruncatedVerify(t *testing.T) {
	dir := testutil.TempDir(t, "snapshot")
	defer os.RemoveAll(dir)

	// Make a Raft and populate it with some data. The contents are never
	// compared afterwards, so no copy of what was applied is kept.
	entries := 64 * 1024
	before, _ := makeRaft(t, filepath.Join(dir, "before"))
	defer before.Shutdown()
	for i := 0; i < entries; i++ {
		var entry bytes.Buffer

		_, err := io.CopyN(&entry, rand.Reader, 256)
		require.NoError(t, err)

		future := before.Apply(entry.Bytes(), time.Second)
		require.NoError(t, future.Error())
	}

	// Take a snapshot.
	logger := testutil.Logger(t)
	snap, err := New(logger, before)
	require.NoError(t, err)
	defer snap.Close()

	// Read the whole snapshot into memory so it can be truncated at will.
	var full bytes.Buffer
	_, err = io.Copy(&full, snap)
	require.NoError(t, err)
	data := full.Bytes()

	for _, removeBytes := range []int{200, 16, 8, 4, 2, 1} {
		t.Run(fmt.Sprintf("truncate %d bytes from end", removeBytes), func(t *testing.T) {
			// Fail cleanly instead of panicking on the slice below if the
			// snapshot is unexpectedly tiny.
			require.True(t, len(data) > removeBytes, "snapshot too small to truncate")

			// Lop off part of the end.
			truncated := bytes.NewReader(data[0 : len(data)-removeBytes])

			// Use a subtest-local err so the outer variable is not clobbered
			// from inside the closure.
			_, err := Verify(truncated)
			require.Error(t, err)
		})
	}
}

// TestSnapshot_BadRestore attempts to restore a snapshot truncated to its
// first 512 bytes into a Raft that already holds data, expects the restore to
// fail with an "unexpected EOF" error, and checks that the target FSM's
// existing logs are left untouched.
func TestSnapshot_BadRestore(t *testing.T) {
	dir := testutil.TempDir(t, "snapshot")
	defer os.RemoveAll(dir)

	// Make a Raft and populate it with some data.
	before, _ := makeRaft(t, filepath.Join(dir, "before"))
	defer before.Shutdown()
	for i := 0; i < 16*1024; i++ {
		var entry bytes.Buffer
		if _, err := io.CopyN(&entry, rand.Reader, 256); err != nil {
			t.Fatalf("err: %v", err)
		}
		future := before.Apply(entry.Bytes(), time.Second)
		if err := future.Error(); err != nil {
			t.Fatalf("err: %v", err)
		}
	}

	// Take a snapshot.
	logger := testutil.Logger(t)
	snap, err := New(logger, before)
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	// Close the snapshot when the test ends, as the other tests in this file
	// do, so it is not left open.
	defer snap.Close()

	// Make a new, independent Raft.
	after, fsm := makeRaft(t, filepath.Join(dir, "after"))
	defer after.Shutdown()

	// Put some initial data in there that should not be harmed by the
	// failed restore attempt.
	var expected []bytes.Buffer
	for i := 0; i < 16; i++ {
		var entry bytes.Buffer
		var dup bytes.Buffer
		both := io.MultiWriter(&entry, &dup)
		if _, err := io.CopyN(both, rand.Reader, 256); err != nil {
			t.Fatalf("err: %v", err)
		}
		future := after.Apply(entry.Bytes(), time.Second)
		if err := future.Error(); err != nil {
			t.Fatalf("err: %v", err)
		}
		expected = append(expected, dup)
	}

	// Attempt to restore a truncated version of the snapshot. This is
	// expected to fail.
	err = Restore(logger, io.LimitReader(snap, 512), after)
	if err == nil || !strings.Contains(err.Error(), "unexpected EOF") {
		t.Fatalf("err: %v", err)
	}

	// Compare the contents to make sure the aborted restore didn't harm
	// anything.
	fsm.Lock()
	defer fsm.Unlock()
	if len(fsm.logs) != len(expected) {
		t.Fatalf("bad: %d vs. %d", len(fsm.logs), len(expected))
	}
	for i := range fsm.logs {
		if !bytes.Equal(fsm.logs[i], expected[i].Bytes()) {
			t.Fatalf("bad: log %d doesn't match", i)
		}
	}
}