from dataclasses import dataclass
from typing import Optional, Tuple

import torch


@dataclass
class DrafterOutput:
    """
    Dataclass for drafter model outputs.

    Args:
        speculated_length (int): Speculated length of the output sequence.
            It is always less than or equal to spec_num during drafter's speculation process.
            Although the field defaults to None, validation rejects None and negative values.
        logits (torch.FloatTensor): Logits of the output sequence
        next_tokens (torch.Tensor): Next token ids
        past_key_values (Optional[Tuple[Tuple[torch.FloatTensor]]]): Past key values of the output sequence.
            When provided, it must be a tuple whose elements are all tuples.

    Raises:
        AssertionError: If `speculated_length` is None or negative, or if `past_key_values`
            is provided but is not a tuple of tuples. The checks are `assert` statements,
            so they are skipped when Python runs with `-O`.
    """

    # The None defaults act as "not provided" markers; `speculated_length` is still
    # required to be a non-negative value by `__post_init__`.
    speculated_length: Optional[int] = None
    logits: Optional[torch.FloatTensor] = None
    next_tokens: Optional[torch.Tensor] = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None

    def __post_init__(self):
        """Validate the fields right after the dataclass-generated `__init__` runs."""
        assert (
            self.speculated_length is not None and self.speculated_length >= 0
        ), f"speculated_length should be a non-negative integer, got {self.speculated_length}"
        if self.past_key_values is not None:
            assert isinstance(self.past_key_values, tuple), "Past key values should be a tuple"
            # A generator lets `all` stop at the first offending element without
            # materializing an intermediate list.
            assert all(
                isinstance(past_key_value, tuple) for past_key_value in self.past_key_values
            ), "Each element of past key values should be a tuple"


@dataclass
class GlideInput:
    """Input bundle for Glide models (e.g. `colossalai/inference/modeling/models/glide_llama.py`).
    Packs the data that is used while glimpsing the KV caches of the main model.

    Args:
        block_tables (torch.Tensor): [num_seqs, max_blocks_per_seq] The block table of KV Caches.
        large_k_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_size]
            Blocked key cache of the main model
        large_v_cache (torch.Tensor): Blocked value cache of the main model. It has the same shape as k cache.
        sequence_lengths (torch.Tensor): [num_seqs] Sequence lengths of the current batch.
        n_spec_tokens (int): Defaults to 5. Presumably the number of speculated tokens,
            judging by the name only; confirm against the callers.
    """

    block_tables: torch.Tensor = None
    large_k_cache: torch.Tensor = None
    large_v_cache: torch.Tensor = None
    sequence_lengths: torch.Tensor = None
    n_spec_tokens: int = 5

    @property
    def glimpse_ready(self) -> bool:
        """Return True only when all four tensor fields have been provided (none is None)."""
        required_fields = (
            self.block_tables,
            self.large_k_cache,
            self.large_v_cache,
            self.sequence_lengths,
        )
        # Identity checks only: tensors are never compared by value here.
        for candidate in required_fields:
            if candidate is None:
                return False
        return True