package raftchunking

import (
	"bytes"
	"crypto/rand"
	"encoding/binary"
	"fmt"
	"io"
	"time"

	proto "github.com/golang/protobuf/proto"
	"github.com/hashicorp/errwrap"
	"github.com/hashicorp/go-raftchunking/types"
	"github.com/hashicorp/raft"
)

// errorFuture is an already-resolved future that carries a fixed error. It
// lets code that must return a raft-style future report a failure without
// ever having submitted anything, so it has no index or response to offer.
type errorFuture struct {
	err error
}

// Error returns the stored error immediately; it never blocks.
func (f errorFuture) Error() error { return f.err }

// Response always reports nil: a static error future has no FSM response.
func (f errorFuture) Response() interface{} { return nil }

// Index always reports 0: a static error future has no log index.
func (f errorFuture) Index() uint64 { return 0 }

// multiFuture is a future specialized for the chunking case. It contains some
// number of other futures in the order in which data was chunked and sent to
// apply.
type multiFuture []raft.ApplyFuture

// Error blocks until the Error method of every contained future has returned,
// waiting on them in order, and then reports the first non-nil error seen (or
// nil if every chunk succeeded).
//
// All futures are drained even after a failure. Returning early would leave
// later chunk futures unresolved, and Index/Response read the last future.
// The raft.ApplyFuture contract only permits calling those after that
// future's Error has returned.
func (m multiFuture) Error() error {
	var firstErr error
	for _, f := range m {
		if err := f.Error(); err != nil && firstErr == nil {
			firstErr = err
		}
	}
	return firstErr
}

// Index returns the index of the last chunk. Callers must call Error first.
// Error waits for every chunk, so the last Index corresponds to the Apply of
// the final chunk.
func (m multiFuture) Index() uint64 {
	// This shouldn't happen but need an escape hatch
	if len(m) == 0 {
		return 0
	}

	return m[len(m)-1].Index()
}

// Response returns the response from the underlying Apply of the last chunk.
// As with Index, it must only be called after Error has returned.
func (m multiFuture) Response() interface{} {
	// This shouldn't happen but need an escape hatch
	if len(m) == 0 {
		return nil
	}

	return m[len(m)-1].Response()
}

// ApplyFunc submits a single raft.Log, bounded by the given timeout, and
// returns a future for that one log. ChunkingApply invokes it once per chunk,
// in chunk order, passing the same per-operation timeout each time.
// NOTE(review): presumably backed by the raft instance's log-apply method;
// confirm against callers.
type ApplyFunc func(log raft.Log, timeout time.Duration) raft.ApplyFuture

// ChunkingApply takes in a byte slice and splits it into chunks of at most
// raft.SuggestedMaxDataSize bytes (the last one may be shorter). It calls
// applyFunc on each chunk in order. It requires a corresponding wrapper around
// the FSM to handle reconstructing on the other end. Timeout will be the
// timeout for each individual operation, not total.
//
// An empty cmd is still sent as a single, empty chunk. Otherwise nothing
// would be applied, extensions would be dropped, and the returned future
// would still report success.
//
// A failure before any chunk reaches applyFunc returns a future carrying
// that error, and nothing is applied. Such failures come from generating the
// op num, reading a chunk, or marshaling chunk metadata. Otherwise the
// returned future aggregates the per-chunk Apply futures in order, and its
// Error() reports failures from them.
//
// Note that any error indicates that the entire operation will not be
// applied, assuming the correct FSM wrapper is used. If extensions is passed
// in, it will be set as the Extensions value on the Apply once all chunks are
// received.
func ChunkingApply(cmd, extensions []byte, timeout time.Duration, applyFunc ApplyFunc) raft.ApplyFuture {
	// Generate a random op num via 64 random bits. These only have to be
	// unique across _in flight_ chunk operations until a Term changes so
	// should be fine.
	rb := make([]byte, 8)
	n, err := rand.Read(rb)
	if err != nil {
		return errorFuture{err: err}
	}
	if n != len(rb) {
		return errorFuture{err: fmt.Errorf("expected to read %d bytes for op num, read %d", len(rb), n)}
	}
	opNum := binary.BigEndian.Uint64(rb)

	// Every chunk records NumChunks so the FSM wrapper knows when the
	// operation is complete, so the count must be known before any log is
	// built. An empty command still counts as one (empty) chunk.
	numChunks := len(cmd) / raft.SuggestedMaxDataSize
	if len(cmd)%raft.SuggestedMaxDataSize != 0 {
		numChunks++
	}
	if numChunks == 0 {
		numChunks = 1
	}

	// Build every log before applying any of them. A failure while building
	// then leaves nothing half-submitted.
	reader := bytes.NewReader(cmd)
	logs := make([]raft.Log, 0, numChunks)
	for i := 0; i < numChunks; i++ {
		size := reader.Len()
		if size > raft.SuggestedMaxDataSize {
			size = raft.SuggestedMaxDataSize
		}

		// Copy the chunk out rather than reslicing cmd, so the submitted logs
		// do not alias the caller's buffer. The size comes from reader.Len(),
		// so this read cannot come up short; the check is defensive.
		data := make([]byte, size)
		if _, err := io.ReadFull(reader, data); err != nil {
			return errorFuture{err: errwrap.Wrapf("error reading command chunk: {{err}}", err)}
		}

		chunkInfo := &types.ChunkInfo{
			OpNum:       opNum,
			SequenceNum: uint32(i),
			NumChunks:   uint32(numChunks),
		}

		// If extensions were passed in attach them to the last chunk so it
		// will go through Apply at the end.
		if i == numChunks-1 {
			chunkInfo.NextExtensions = extensions
		}

		chunkBytes, err := proto.Marshal(chunkInfo)
		if err != nil {
			return errorFuture{err: errwrap.Wrapf("error marshaling chunk info: {{err}}", err)}
		}
		logs = append(logs, raft.Log{
			Data:       data,
			Extensions: chunkBytes,
		})
	}

	mf := make(multiFuture, 0, len(logs))
	for _, l := range logs {
		mf = append(mf, applyFunc(l, timeout))
	}

	return mf
}