package boltdb

import (
	"bytes"
	"crypto/aes"
	"crypto/cipher"
	"crypto/rand"
	"io"

	"github.com/pkg/errors"
	"github.com/segmentio/encoding/json"
)

var (
	// errEncryptedStringTooShort is returned by decrypt when its input is
	// shorter than the GCM nonce that encrypt prepends to every ciphertext.
	errEncryptedStringTooShort = errors.New("encrypted string too short")
)

// MarshalObject encodes an object to binary format.
//
// A string value is written verbatim, because the VERSION bucket stores its
// value as raw text rather than JSON. Every other value is JSON-encoded with
// map-key sorting and the trailing newline both disabled. When the connection
// has an encryption key, the encoded bytes are passed through encrypt before
// being returned.
func (connection *DbConnection) MarshalObject(object any) ([]byte, error) {
	var out bytes.Buffer

	switch v := object.(type) {
	case string:
		// VERSION bucket special case: not JSON, store the raw string.
		out.WriteString(v)
	default:
		encoder := json.NewEncoder(&out)
		encoder.SetSortMapKeys(false)
		encoder.SetAppendNewline(false)

		if err := encoder.Encode(object); err != nil {
			return nil, err
		}
	}

	key := connection.getEncryptionKey()
	if key == nil {
		return out.Bytes(), nil
	}

	return encrypt(out.Bytes(), key)
}

// UnmarshalObject decodes an object from binary data.
//
// When the connection has an encryption key, data is first decrypted with it.
// The resulting bytes are then JSON-decoded into object. If JSON decoding fails
// and object is a *string, the raw bytes are stored into it instead, because
// the VERSION bucket holds plain text rather than JSON. A decoding failure for
// any other target type is returned as an error.
func (connection *DbConnection) UnmarshalObject(data []byte, object any) error {
	if key := connection.getEncryptionKey(); key != nil {
		var err error
		if data, err = decrypt(data, key); err != nil {
			return errors.Wrap(err, "Failed decrypting object")
		}
	}

	if err := json.Unmarshal(data, object); err != nil {
		// Special case for the VERSION bucket. Here we're not using json,
		// so we need to return it as a string.
		s, ok := object.(*string)
		if !ok {
			// Wrap the unmarshal error itself. Wrapping the (always nil)
			// decryption error yields nil from errors.Wrap, which would
			// silently report success on a failed decode.
			return errors.Wrap(err, "Failed unmarshalling object")
		}

		*s = string(data)
	}

	return nil
}

// No KMS is available, so objects are encrypted locally with AES-GCM, chosen
// based on the symmetric-encryption guidance at
// https://gist.github.com/atoponce/07d8d4c833873be2f68c34f9afc5a78a#symmetric-encryption

// encrypt seals plaintext with AES-GCM, using passphrase directly as the AES
// key. aes.NewCipher only accepts 16-, 24- or 32-byte keys; any other length
// yields an error. A fresh random nonce is generated on every call and
// prepended to the sealed output. The result is nonce || ciphertext+tag, the
// layout decrypt splits apart.
func encrypt(plaintext []byte, passphrase []byte) (encrypted []byte, err error) {
	block, err := aes.NewCipher(passphrase)
	if err != nil {
		// Must not be ignored: with an invalid key block is nil and
		// cipher.NewGCM would panic instead of returning an error.
		return nil, err
	}

	gcm, err := cipher.NewGCM(block)
	if err != nil {
		return nil, err
	}

	nonce := make([]byte, gcm.NonceSize())
	if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
		return nil, err
	}

	// Seal appends to its first argument, so the nonce ends up as the prefix.
	return gcm.Seal(nonce, nonce, plaintext, nil), nil
}

// decrypt reverses encrypt. It treats the first NonceSize bytes of encrypted
// as the GCM nonce and authenticates/decrypts the remainder, using passphrase
// as the AES key.
//
// The literal input "false" is passed through unchanged rather than being
// treated as ciphertext. On any failure the original encrypted slice is
// returned together with the error.
func decrypt(encrypted []byte, passphrase []byte) (plaintextByte []byte, err error) {
	if string(encrypted) == "false" {
		return []byte("false"), nil
	}

	var block cipher.Block
	if block, err = aes.NewCipher(passphrase); err != nil {
		return encrypted, errors.Wrap(err, "Error creating cypher block")
	}

	var aead cipher.AEAD
	if aead, err = cipher.NewGCM(block); err != nil {
		return encrypted, errors.Wrap(err, "Error creating GCM")
	}

	n := aead.NonceSize()
	if len(encrypted) < n {
		return encrypted, errEncryptedStringTooShort
	}

	plain, openErr := aead.Open(nil, encrypted[:n], encrypted[n:], nil)
	if openErr != nil {
		return encrypted, errors.Wrap(openErr, "Error decrypting text")
	}

	return plain, nil
}