import enum
from dataclasses import dataclass
from typing import Any, List, Tuple, Union

import torch
from ordered_set import OrderedSet

from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.logging import get_dist_logger

logger = get_dist_logger(__name__)

# The abstractions of requests (``Sequence``) and batches of requests (``BatchInfo``) are defined here.


class RequestStatus(enum.Enum):
    """
    Lifecycle status of a request/sequence.
    """

    # running status
    WAITING = enum.auto()
    RUNNING = enum.auto()
    ABORTED = enum.auto()

    # completion status
    OVERLENGTH = enum.auto()
    COMPLETED = enum.auto()
    LENGTH_CAPPED = enum.auto()

    # recycle status
    RECYCLED = enum.auto()

    @staticmethod
    def is_finished(status: "RequestStatus") -> bool:
        """Return True for the completion statuses (OVERLENGTH, COMPLETED, LENGTH_CAPPED).

        Note that ABORTED is *not* treated as finished.
        """
        return status in [
            RequestStatus.OVERLENGTH,
            RequestStatus.COMPLETED,
            RequestStatus.LENGTH_CAPPED,
        ]

    @staticmethod
    def is_running(status: "RequestStatus") -> bool:
        """Return True iff ``status`` is RUNNING."""
        return status == RequestStatus.RUNNING

    @staticmethod
    def is_waiting(status: "RequestStatus") -> bool:
        """Return True iff ``status`` is WAITING."""
        return status == RequestStatus.WAITING


@dataclass
class Sequence:
    """Store information of input sequence.

    Args:
        request_id (int): The ID of input sequence.
        prompt (str): The prompt of input sequence.
        input_token_id (List[int]): The tokens ID of input sequence.
        block_size (int): The block size of input sequence.
        sample_params (SampleParams): The sample_params of input sequence.
        eos_token_id (int): The eos token id for this inference process.
        pad_token_id (int): The pad token id for this inference process.
        max_output_len (int): Maximum output length.

    Note:
        ``block_table`` is not a constructor field. ``BatchInfo.get_block_table_tensor`` reads it as an
        attribute, so it must be attached to the instance elsewhere before that method is called.
    """

    request_id: int
    prompt: str
    input_token_id: List[int]
    block_size: int
    sample_params: Any  # SampleParams needs to be imported later.
    eos_token_id: int
    pad_token_id: int
    max_output_len: int = 256

    def __post_init__(self):
        # Generated tokens are appended here by ``BatchInfo.update_batch_tokens``.
        self.output_token_id: List[int] = []
        self.status = RequestStatus.WAITING

    @property
    def sentence_len(self) -> int:
        """
        Get length of current sentence (input tokens + generated tokens).
        """
        return len(self.input_token_id) + len(self.output_token_id)

    @property
    def input_len(self) -> int:
        """
        Get length of input sentence.
        """
        return len(self.input_token_id)

    @property
    def output_len(self) -> int:
        """
        Get length of output sentence.
        """
        return len(self.output_token_id)

    def check_finish(self) -> bool:
        """
        Check whether the inference is finished.

        Side effect: if the last generated token is ``eos_token_id`` or ``output_len`` has reached
        ``max_output_len``, the status is set to COMPLETED.

        Returns:
            bool: Whether the inference is finished.
        """
        if RequestStatus.is_finished(self.status):
            return True

        if self.output_token_id:
            if self.output_token_id[-1] == self.eos_token_id or self.output_len >= self.max_output_len:
                self.status = RequestStatus.COMPLETED
                return True

        return False

    def __hash__(self):
        # Explicitly defined so the dataclass (eq=True) keeps instances hashable; identity is the request id.
        return hash(self.request_id)

    def mark_running(self) -> None:
        """
        Set status for prefill reqs.

        Raises:
            AssertionError: If the sequence is neither WAITING nor RECYCLED.
        """
        # A plain ``status == WAITING or RequestStatus.RECYCLED`` is always truthy (enum members are
        # truthy), so a membership test is required for the guard to have any effect.
        assert self.status in (
            RequestStatus.WAITING,
            RequestStatus.RECYCLED,
        ), "Sequence is not in WAITTING/RECYCLED STATUS"
        self.status = RequestStatus.RUNNING

    def mark_finished(self) -> None:
        """
        Set status for finished reqs.
        """
        self.status = RequestStatus.COMPLETED

    def mark_aborted(self) -> None:
        """
        Set status for aborted reqs.
        """
        self.status = RequestStatus.ABORTED

    def recycle(self) -> None:
        """
        Recycle a running sequence back to the waiting list.

        Raises:
            AssertionError: If the sequence is already finished or aborted.
        """
        assert (
            not self.check_finish() and not self.status == RequestStatus.ABORTED
        ), "The running sequence is already done but it still in running list"
        self.status = RequestStatus.RECYCLED

    def __repr__(self) -> str:
        return (
            f"(request_id={self.request_id}, "
            f"prompt={self.prompt}, "
            f"status={self.status.name}, "
            f"sample_params={self.sample_params}, "
            f"input_len={self.input_len},"
            f"output_len={self.output_len})"
        )


@dataclass
class BatchInfo:
    """
    Information to be passed and used for a batch of sequences.

    ``is_prompts`` selects between the prefill stage (True: full prompts are fed) and the decoding
    stage (False: only the last generated token of each sequence is fed).
    """

    max_batch_size: int
    kv_max_split_num: int
    num_heads: int
    head_dim: int
    sequences_set: OrderedSet[Sequence] = None
    is_prompts: bool = True
    device: torch.device = None
    dtype: torch.dtype = None
    fd_inter_tensor: FDIntermTensors = None

    def __post_init__(self):
        if self.device is None:
            self.device = torch.cuda.current_device()
        if self.sequences_set is None:
            self.sequences_set = OrderedSet()
        if self.fd_inter_tensor is None:
            self.fd_inter_tensor = FDIntermTensors()

    def init_fd_tensors(self):
        """Initialize the flash-decoding intermediate tensors once, using this batch's configuration."""
        if not self.fd_inter_tensor.is_initialized:
            self.fd_inter_tensor.initialize(
                max_batch_size=self.max_batch_size,
                num_attn_heads=self.num_heads,
                kv_max_split_num=self.kv_max_split_num,
                head_dim=self.head_dim,
                dtype=self.dtype,
                device=self.device,
            )

    def get_block_table_tensor(self) -> torch.Tensor:
        """
        Stack the block tables of all sequences in the batch.

        Returns:
            torch.Tensor: ``torch.stack`` of every sequence's ``block_table``, in batch order.

        Raises:
            AssertionError: If the batch is empty or any sequence has no ``block_table``.
        """
        assert len(self.sequences_set) > 0, "Batch has not been initialized yet. Please initialize batch first."
        tensor_list = []
        for seq in self.sequences_set:
            # ``block_table`` is not a declared ``Sequence`` field; ``getattr`` makes a missing
            # attribute hit the descriptive assertion below instead of raising AttributeError.
            block_table = getattr(seq, "block_table", None)
            assert (
                block_table is not None
            ), f"The sequence(request_id {seq.request_id}) has not initialized the block_table."
            tensor_list.append(block_table)
        return torch.stack(tensor_list)

    def clear_batch(self) -> None:
        """
        Clear sequence set and block table if we need to abort this batch.

        Prefill: clear sequence set and move them to running batch (external).
        Decoding: mark unfinished sequences as aborted (finished ones end up COMPLETED via
        ``check_finish``), then clear the sequence set.
        """
        if not self.is_prompts:
            for seq in self.sequences_set:
                # check_finish() itself promotes naturally finished sequences to COMPLETED.
                if not seq.check_finish():
                    seq.mark_aborted()
        self.sequences_set.clear()

    def fliter_batch(self) -> List["Sequence"]:
        """
        Remove completed sentences from a batch.

        Returns:
            List["Sequence"]: List of finished sequences.
        """
        # Collect first, then discard, so the set is not mutated while being iterated.
        finish_seqs = [seq for seq in self.sequences_set if seq.check_finish()]
        for finish_seq in finish_seqs:
            self.sequences_set.discard(finish_seq)
        return finish_seqs

    def abort_seq(self, seq: "Sequence") -> "Sequence":
        """
        Remove sequence from the batch, marking it ABORTED unless it has already finished.
        """
        if not seq.check_finish():
            seq.status = RequestStatus.ABORTED
        self.sequences_set.discard(seq)
        return seq

    def add_seqs(self, seqs: Union[Sequence, List[Sequence]]) -> None:
        """
        Add new sequence(s) to batch. Sequences already present are skipped with a warning.

        Args:
            seqs (Union[Sequence, List[Sequence]]): A single sequence or a list of new sequences.
        """
        # convert single sequence to list
        if isinstance(seqs, Sequence):
            seqs = [seqs]

        for seq in seqs:
            if seq in self.sequences_set:
                logger.warning(f"The sequence(request_id {seq.request_id}) is already in sequences_set.")
                continue
            self.sequences_set.add(seq)

    def del_seq(self, seq: Sequence) -> Sequence:
        """
        Delete sequence in batch (no-op if it is not present).

        Returns:
            Sequence: The sequence passed in.
        """
        self.sequences_set.discard(seq)
        return seq

    @property
    def is_empty(self) -> bool:
        """
        Check whether sequences_set is empty.
        """
        return not self.sequences_set

    def update_batch_tokens(self, tokens: Union[List[int], List[List[int]], torch.Tensor]) -> None:
        """
        Add output token(s) for each sentence in the batch, then refresh each sequence's finish status.

        Args:
            tokens (Union[List[int], List[List[int]], torch.Tensor]): One entry per sequence, in batch
                order; each entry is either a single token id or a list of token ids.

        Raises:
            TypeError: If an entry is neither an int nor a list.
        """
        if isinstance(tokens, torch.Tensor):
            tokens = tokens.tolist()

        assert self.get_batch_size() == len(tokens), "The number of tokens does not match batch_size."

        for seq, token in zip(self.sequences_set, tokens):
            if not isinstance(token, list):
                if not isinstance(token, int):
                    raise TypeError(f"The token type must be List[int] or int, but got {type(token)}.")
                token = [token]
            seq.output_token_id += token
            seq.check_finish()

    def get_batch_size(self) -> int:
        """
        Get batch_size of this batch
        """
        return len(self.sequences_set)

    def get_batch_inputs(self) -> torch.Tensor:
        """
        Get batch inputs for forward inference computation.

        Prefill: each row is the full token history (input + any output, e.g. for recycled sequences).
        Decoding: each row is the last generated token.
        Rows are left-padded with the first sequence's ``pad_token_id``.

        Returns:
            torch.Tensor: ``torch.int`` tensor of shape (batch_size, max_seq_len) on ``self.device``.
        """
        assert len(self.sequences_set) > 0, "Batch has not been initialized yet. Please initialize batch first."
        input_list = []
        for seq in self.sequences_set:
            if self.is_prompts:
                if seq.output_len > 0:
                    input_list.append(seq.input_token_id + seq.output_token_id)
                else:
                    input_list.append(seq.input_token_id)
            else:
                input_list.append([seq.output_token_id[-1]])

        max_seq_len = max(len(sub_list) for sub_list in input_list)

        # We assume that all the padding_id in seq are the same at present.
        return _make_tensor_with_pad(
            input_list, max_seq_len, self.sequences_set[0].pad_token_id, dtype=torch.int, device=self.device
        )

    def get_1D_inputs(self) -> torch.Tensor:
        """
        Flattening the input tokens.

        Prefill: concatenation of every sequence's full token history (input + any output), which
        matches the per-sequence lengths reported by ``get_sequence_lengths``.
        Decoding: the last generated token of every sequence.

        Returns:
            torch.Tensor: 1-D ``torch.long`` tensor on ``self.device``.
        """
        assert len(self.sequences_set) > 0, "Batch has not been initialized yet. Please initialize batch first."
        input_list = []
        for seq in self.sequences_set:
            if self.is_prompts:
                # Recycled sequences already carry output tokens that must be re-prefilled too.
                input_list.extend(seq.input_token_id)
                input_list.extend(seq.output_token_id)
            else:
                input_list.append(seq.output_token_id[-1])
        return torch.tensor(input_list, dtype=torch.long, device=self.device)

    def get_sequence_lengths(self) -> torch.Tensor:
        """
        Get the current length (input + output tokens) of each sentence in this batch.

        Returns:
            torch.Tensor: 1-D ``torch.int`` tensor on ``self.device``.
        """
        assert len(self.sequences_set) > 0, "Batch has not been initialized yet. Please initialize batch first."
        len_list = [seq.sentence_len for seq in self.sequences_set]
        return torch.tensor(len_list, dtype=torch.int, device=self.device)

    def get_attn_mask(self) -> torch.Tensor:
        """
        Generate and return attention mask.

        Each sequence's full token history is left-padded with the first sequence's pad token; positions
        equal to that pad id become 0 and all others 1. Note that a genuine token equal to the pad id is
        masked out as well.

        Returns:
            torch.Tensor: ``torch.long`` tensor of shape (batch_size, max_seq_len) on ``self.device``.
        """
        assert len(self.sequences_set) > 0, "Batch has not been initialized yet. Please initialize batch first."

        # We assume that all the padding_id in seq are the same at present.
        padding_id = self.sequences_set[0].pad_token_id

        past_values = [seq.input_token_id + seq.output_token_id for seq in self.sequences_set]
        max_seq_len = max(len(sub_list) for sub_list in past_values)

        attn_mask = _make_tensor_with_pad(past_values, max_seq_len, padding_id, dtype=torch.int, device=self.device)
        return attn_mask.ne(padding_id).long()

    def __repr__(self) -> str:
        return f"(sequences_set={self.sequences_set}, " f"is_prompts={self.is_prompts})"


def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
    """Left-pad ``x`` with ``pad`` up to ``max_len`` elements (``x`` must not be longer than ``max_len``)."""
    assert len(x) <= max_len
    return [pad] * (max_len - len(x)) + x


def _make_tensor_with_pad(
    x: Union[List[List[int]], List[int]],
    max_len: int,
    pad: int,
    dtype: torch.dtype,
    device: Union[str, torch.device] = "cuda",
    pin_memory: bool = False,
) -> torch.Tensor:
    """Left-pad every row of ``x`` to ``max_len`` and build a tensor from the result.

    ``pin_memory`` is only honoured when ``device`` is ``"cpu"``.
    """
    padded_x = [_pad_to_max(x_i, max_len, pad) for x_i in x]
    return torch.tensor(padded_x, dtype=dtype, device=device, pin_memory=pin_memory and str(device) == "cpu")