import enum
from dataclasses import dataclass
from typing import Any, List

from colossalai.inference.config import DiffusionGenerationConfig
from colossalai.logging import get_dist_logger

logger = get_dist_logger(__name__)

"""
The abstraction of request and sequence are defined here.
"""


class RequestStatus(enum.Enum):
    """
    The status of Sentences
    """

    # running status
    WAITING = enum.auto()
    RUNNING = enum.auto()
    ABORTED = enum.auto()

    # completion status
    OVERLENGTH = enum.auto()
    COMPLETED = enum.auto()
    LENGTH_CAPPED = enum.auto()

    # recycle status
    RECYCLED = enum.auto()

    @staticmethod
    def is_finished(status: "RequestStatus") -> bool:
        return status in [
            RequestStatus.OVERLENGTH,
            RequestStatus.COMPLETED,
            RequestStatus.LENGTH_CAPPED,
        ]

    @staticmethod
    def is_running(status: "RequestStatus") -> bool:
        return status == RequestStatus.RUNNING

    @staticmethod
    def is_waiting(status: "RequestStatus") -> bool:
        return status == RequestStatus.WAITING


@dataclass
class DiffusionSequence:
    """
    parameters for diffusion
    """

    request_id: int
    prompt: str
    generation_config: DiffusionGenerationConfig


@dataclass
class Sequence:
    """Store information of input sequence.

    Args:
        request_id (int): The ID of input sequence.
        prompt (str): The prompt of input sequence.
        input_token_id (List[int]): The tokens ID of input sequence.
        block_size (int): The block size of input sequence.
        sample_params (SampleParams): The sample_params of input sequence.
        block_table (torch.Tensor): The index of input sequence in block_table.
        eos_token_id (int): The eos token id for this inference process.
        pad_token_id (int): The pad token id for this inference process.
        max_output_len (int): Maximum output length.
        ignore_eos(bool): Whether to ignore the EOS token and continue generating tokens when encountering the EOS token.
        output(str): The output of sequence
    """

    request_id: int
    prompt: str
    input_token_id: List[int]
    block_size: int
    sample_params: Any  # SampleParams needs to be imported later.
    eos_token_id: int
    pad_token_id: int
    max_output_len: int = 256
    # NOTE(caidi) This is a temporary solution. It's better to move the logic to turn on or off the flag in sampling module in future.
    ignore_eos: bool = False
    output: str = None

    def __post_init__(self):
        self.output_token_id = []
        self.status = RequestStatus.WAITING

    @property
    def sentence_len(self) -> int:
        """
        Get length of current sentence.
        """
        return len(self.input_token_id) + len(self.output_token_id)

    @property
    def input_len(self) -> int:
        """
        Get length of input sentence.
        """
        return len(self.input_token_id)

    @property
    def output_len(self) -> int:
        """
        Get length of output sentence.
        """
        return len(self.output_token_id)

    def check_finish(self) -> bool:
        """
        Check whether the inference is finished.

        Returns:
            bool: Whether the inference is finished.
        """
        if RequestStatus.is_finished(self.status):
            return True

        if self.output_token_id:
            if (
                self.output_token_id[-1] == self.eos_token_id and not self.ignore_eos
            ) or self.output_len >= self.max_output_len:
                self.status = RequestStatus.COMPLETED
                return True

        return False

    def revoke_finished_status(self) -> None:
        """
        Revoke the finished status of the sequence.
        This is only used by speculative decoding for now.
        """
        if RequestStatus.is_finished(self.status):
            self.status = RequestStatus.RUNNING

    def __hash__(self):
        return hash(self.request_id)

    def mark_running(self) -> None:
        """
        Set status for prefill reqs.
        """
        assert (
            self.status == RequestStatus.WAITING or RequestStatus.RECYCLED
        ), "Sequence is not in WAITTING/RECYCLED STATUS"
        self.status = RequestStatus.RUNNING

    def mark_finished(self) -> None:
        """
        Set status for finished reqs.
        """
        self.status = RequestStatus.COMPLETED

    def mark_aborted(self) -> None:
        """
        Set status for aborted reqs.
        """
        self.status = RequestStatus.ABORTED

    def recycle(self) -> None:
        """
        Recycle a running sequnce to waiitting list
        """
        assert (
            not self.check_finish() and not self.status == RequestStatus.ABORTED
        ), "The running sequence \
        is already done but it still in running list"
        self.status = RequestStatus.RECYCLED

    def __repr__(self) -> str:
        return (
            f"(request_id={self.request_id}, "
            f"prompt={self.prompt},\n"
            f"output_token_id={self.output_token_id},\n"
            f"output={self.output},\n"
            f"status={self.status.name},\n"
            f"sample_params={self.sample_params},\n"
            f"input_len={self.input_len},\n"
            f"output_len={self.output_len})\n"
        )


def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
    assert len(x) <= max_len
    return [pad] * (max_len - len(x)) + x