# TODO: consider combining this with InferenceConfig in colossalai/ppinference/inference_config.py later
from dataclasses import dataclass
from typing import Optional

import torch

from .kvcache_manager import MemoryManager


@dataclass
class BatchInferState:
    r"""
    Information to be passed and used for a batch of inputs during
    a single model forward
    """

    batch_size: int
    max_len_in_batch: int

    # Can be attached after construction through ``set_cache_manager``.
    cache_manager: Optional[MemoryManager] = None

    # NOTE(review): presumably the block-location mapping that ``init_block_loc``
    # fills in place -- confirm against callers.
    block_loc: Optional[torch.Tensor] = None
    start_loc: Optional[torch.Tensor] = None
    # Summed by ``total_token_num``; must be non-empty along dim 0 when that
    # property is read.
    seq_len: Optional[torch.Tensor] = None
    past_key_values_len: Optional[int] = None

    # NOTE(review): the fields below look like bookkeeping for the context
    # (prefill) and decode stages; their exact semantics are defined by the
    # code that populates them, not by this class.
    is_context_stage: bool = False
    context_mem_index: Optional[torch.Tensor] = None
    decode_is_contiguous: Optional[bool] = None
    decode_mem_start: Optional[int] = None
    decode_mem_end: Optional[int] = None
    decode_mem_index: Optional[torch.Tensor] = None
    decode_layer_id: Optional[int] = None

    device: torch.device = torch.device("cuda")

    @property
    def total_token_num(self) -> int:
        """Return the sum of ``seq_len`` as a Python int.

        This is the sum of the individual lengths, not the product
        ``batch_size * max_len_in_batch``.

        Raises:
            AssertionError: if ``seq_len`` is unset or empty along dim 0.
        """
        assert self.seq_len is not None and self.seq_len.size(0) > 0, "seq_len must be a non-empty tensor"
        return int(torch.sum(self.seq_len))

    def set_cache_manager(self, manager: MemoryManager) -> None:
        """Attach ``manager`` as this state's ``cache_manager``."""
        self.cache_manager = manager

    @staticmethod
    def init_block_loc(
        b_loc: torch.Tensor, seq_len: torch.Tensor, max_len_in_batch: int, alloc_mem_index: torch.Tensor
    ) -> None:
        """In-place update of the block loc mapping based on the sequence lengths of the current batch.

        For row ``i`` with length ``n = seq_len[i]``, the next ``n`` entries of
        ``alloc_mem_index`` are written to ``b_loc[i, max_len_in_batch - n : max_len_in_batch]``,
        i.e. each row is right-aligned against column ``max_len_in_batch``.
        Entries of ``alloc_mem_index`` are consumed consecutively, row by row.

        Args:
            b_loc: 2-D tensor updated in place, indexed as ``[row, column]``.
            seq_len: 1-D tensor of per-row lengths.
            max_len_in_batch: exclusive right edge of the written column range.
            alloc_mem_index: 1-D tensor supplying the values to write.

        Raises:
            ValueError: if a length is negative or exceeds ``max_len_in_batch``.
                Such a length would make the slice start wrap around or come
                out empty instead of addressing the intended columns.
        """
        offset = 0
        # ``tolist()`` copies the lengths to the host once and yields plain
        # Python ints, so the slice bounds and running offset stay Python ints.
        for row, length in enumerate(seq_len.tolist()):
            if not 0 <= length <= max_len_in_batch:
                raise ValueError(
                    f"sequence length {length} at index {row} is outside [0, {max_len_in_batch}]"
                )
            b_loc[row, max_len_in_batch - length : max_len_in_batch] = alloc_mem_index[offset : offset + length]
            offset += length