from typing import Optional

import torch

from colossalai.context.singleton_meta import SingletonMeta
from colossalai.utils import get_current_device


class FDIntermTensors(metaclass=SingletonMeta):
    """Singleton class to hold tensors used for storing intermediate values in flash-decoding.

    It holds the per-split intermediate output and its logsumexp (which will be used in the
    reduction step along kv), as well as the ``exp_sums`` and ``max_logits`` buffers.
    All buffers are allocated by :meth:`initialize`, sized for the maximum batch size, and
    released by :meth:`_reset`.
    """

    # Attribute names of the buffers allocated in ``initialize`` and released in ``_reset``.
    _BUFFER_ATTRS = ("_mid_output", "_mid_output_lse", "_exp_sums", "_max_logits")

    def __init__(self):
        self._tensors_initialized = False

    def _reset(self):
        """Release the intermediate tensors and mark the holder as uninitialized.

        Calling this before ``initialize`` (or twice in a row) is a no-op rather than an error.
        """
        self._tensors_initialized = False
        for attr in self._BUFFER_ATTRS:
            # Only delete buffers that were actually allocated; a bare ``del`` would raise
            # AttributeError when ``initialize`` has not been called yet.
            if hasattr(self, attr):
                delattr(self, attr)

    @property
    def is_initialized(self):
        """Whether ``initialize`` has allocated the intermediate tensors."""
        return self._tensors_initialized

    @property
    def mid_output(self):
        """Intermediate output, shape (max_batch_size, num_attn_heads, kv_max_split_num, head_dim)."""
        assert self.is_initialized, "Intermediate tensors not initialized yet"
        return self._mid_output

    @property
    def mid_output_lse(self):
        """Intermediate logsumexp, shape (max_batch_size, num_attn_heads, kv_max_split_num)."""
        assert self.is_initialized, "Intermediate tensors not initialized yet"
        return self._mid_output_lse

    @property
    def exp_sums(self):
        """Exp-sum buffer, shape (max_batch_size, num_attn_heads, kv_max_split_num)."""
        assert self.is_initialized, "Intermediate tensors not initialized yet"
        return self._exp_sums

    @property
    def max_logits(self):
        """Max-logit buffer, shape (max_batch_size, num_attn_heads, kv_max_split_num)."""
        assert self.is_initialized, "Intermediate tensors not initialized yet"
        return self._max_logits

    def initialize(
        self,
        max_batch_size: int,
        num_attn_heads: int,
        kv_max_split_num: int,
        head_dim: int,
        dtype: torch.dtype = torch.float32,
        device: Optional[torch.device] = None,
    ) -> None:
        """Initialize tensors.

        Args:
            max_batch_size (int): The maximum batch size over all the model forward.
                This could be greater than the batch size in attention forward func when using dynamic batch size.
            num_attn_heads (int): Number of attention heads.
            kv_max_split_num (int): The maximum number of blocks split on kv in flash-decoding algorithm.
                **The maximum length/size of blocks split on kv should be the kv cache block size.**
            head_dim (int): Head dimension.
            dtype (torch.dtype, optional): Data type to be assigned to intermediate tensors.
                Defaults to ``torch.float32``.
            device (torch.device, optional): Device used to initialize intermediate tensors.
                Defaults to ``get_current_device()`` evaluated when this method is called.
        """
        assert not self.is_initialized, "Intermediate tensors used for Flash-Decoding have been initialized."

        # Resolve the device at call time. A ``get_current_device()`` default in the signature
        # is evaluated only once, when the class is defined at import time, which may precede
        # any per-process ``torch.cuda.set_device`` call.
        if device is None:
            device = get_current_device()

        self._mid_output = torch.empty(
            size=(max_batch_size, num_attn_heads, kv_max_split_num, head_dim), dtype=dtype, device=device
        )
        # The remaining buffers hold one scalar per (batch, head, kv split).
        split_shape = (max_batch_size, num_attn_heads, kv_max_split_num)
        self._mid_output_lse = torch.empty(size=split_shape, dtype=dtype, device=device)
        self._exp_sums = torch.empty(size=split_shape, dtype=dtype, device=device)
        self._max_logits = torch.empty(size=split_shape, dtype=dtype, device=device)

        self._tensors_initialized = True