from typing import Callable, List, Optional, Tuple, Union

import torch

from colossalai.inference.struct import Sequence
from colossalai.utils import get_current_device


class BatchBucket:
    """Container for a batch of Sequences, which is used to manage the batch of sequences.

    Sequences occupy rows ``0 .. current_batch_size - 1`` of the per-batch tensors; the
    remaining rows are padding (length 0, block ids -1).

    Attrs:
        _sequences_dict (Dict[int, Sequence]): Map sequence uid to sequence struct
            seq_uid -> Sequence
        _sequences_indexes (Dict[int, int]): Map sequence uid to index in the batch
            seq_uid -> index in the batch (indexing used in sequence_lengths and block_tables)
        _sequence_lengths (torch.Tensor): Length of each sequence in the batch.
            The size of the tensor is (max_batch_size,)
        _block_tables (torch.Tensor): Block table of each sequence in the batch
            The size of the tensor is (max_batch_size, max_blocks_per_seq)
    """

    def __init__(
        self,
        num_heads,
        head_dim,
        max_batch_size,
        max_length,
        block_size,
        kv_max_split_num,
        fd_interm_tensor=None,
        device=None,
        dtype=torch.float16,
        enable_streamingllm: bool = False,
        start_token_size: int = 4,
        generated_token_size: int = 512,
    ):
        """Allocate the per-batch bookkeeping tensors.

        Args:
            num_heads: Number of attention heads (stored for modeling).
            head_dim: Size of each attention head (stored for modeling).
            max_batch_size: Maximum number of sequences the bucket can hold.
            max_length: Maximum total length (input + output) of a sequence.
            block_size: Number of tokens per KV-cache block.
            kv_max_split_num: Hint used for flash decoding.
            fd_interm_tensor: Optional intermediate buffer(s) for flash decoding.
            device: Device used for tensors handed to the model; defaults to the current device.
            dtype: Data type stored for modeling.
            enable_streamingllm: Whether StreamingLLM block eviction is used.
            start_token_size: Number of leading tokens kept by StreamingLLM.
            generated_token_size: Size of the StreamingLLM sliding window of generated tokens.
        """
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.max_batch_size = max_batch_size
        self.max_length = max_length  # in + out len
        self.block_size = block_size
        self.kv_max_split_num = kv_max_split_num  # Hint used for flash decoding
        self.fd_interm_tensor = fd_interm_tensor
        self.device = device or get_current_device()
        self.dtype = dtype

        self._use_spec_dec = False
        self._num_tokens_to_verify = None

        self.enable_streamingllm = enable_streamingllm
        self.start_token_size = start_token_size
        self.generated_token_size = generated_token_size

        self._current_batch_size = 0
        self._sequences_dict = dict()
        self._sequences_indexes = dict()
        # Bookkeeping tensors are created on the default (CPU) device; `get_*` helpers below
        # move them to `self.device` when handing them to the model.
        self._sequence_lengths = torch.zeros((self.max_batch_size,), dtype=torch.int32)
        self._sequence_lengths_helper = torch.zeros_like(self._sequence_lengths)
        if enable_streamingllm:
            # One extra block of headroom: a sequence may grow up to
            # start + generated + block_size - 1 tokens before `streamingllm_update_batch`
            # evicts a block.
            max_blocks_per_seq = (start_token_size + generated_token_size + block_size - 1) // block_size + 1
        else:
            max_blocks_per_seq = (self.max_length + block_size - 1) // block_size
        self._block_tables = torch.full((self.max_batch_size, max_blocks_per_seq), -1, dtype=torch.int32)
        self._block_tables_helper = torch.full_like(self._block_tables, -1)

    @property
    def is_empty(self):
        return self._current_batch_size == 0

    @property
    def current_batch_size(self):
        return self._current_batch_size

    def __len__(self):
        return self._current_batch_size

    @property
    def available_batch_size(self):
        return self.max_batch_size - self._current_batch_size

    @property
    def block_tables(self):
        return self._block_tables

    @property
    def seq_lengths(self):
        return self._sequence_lengths

    @property
    def seqs_ids(self):
        return list(self._sequences_dict.keys())

    @property
    def seqs_li(self):
        return list(self._sequences_dict.values())

    @property
    def is_compact(self):
        """Whether the number of tracked sequences matches the number of rows with a non-zero
        length and the number of rows whose first block id is allocated (>= 0).

        NOTE this is a count check only; it does not verify that those rows are contiguous.
        """
        assert len(self._sequences_dict) == len(self._sequences_indexes), "BatchBucket indexing is not consistent"
        return (
            len(self._sequences_dict)
            == torch.nonzero(self._sequence_lengths).view(-1).numel()
            == torch.nonzero(self._block_tables[:, 0] >= 0).numel()
        )

    @property
    def use_spec_dec(self) -> bool:
        return self._use_spec_dec

    @property
    def num_tokens_to_verify(self) -> int:
        return self._num_tokens_to_verify

    @property
    def batch_token_ids(self) -> List[List[int]]:
        """Full token ids (input followed by output) of every sequence, in dict order."""
        out = []
        for seq in self.seqs_li:
            out.append(seq.input_token_id + seq.output_token_id)
        return out

    def streamingllm_update_batch(self, start_token_size: int, generated_token_size: int):
        """Update sequence_lengths and block_tables when it is necessary to swap out a block.

        A sequence whose length reaches ``start_token_size + generated_token_size + block_size - 1``
        drops the block at position 1 of its block table (position 0 is kept for the start tokens).
        Its remaining blocks shift left, the last slot is padded with -1, and its length is
        reduced to ``start_token_size + generated_token_size - 1``.

        The tensors are updated in place, so ``_sequence_lengths`` and ``_block_tables`` keep their
        full ``max_batch_size`` shape and their device, and stay aligned with their helper buffers.

        Returns:
            List[int]: Ids of the blocks dropped from the block tables (presumably to be released
            by the caller).
        """
        updated_block_ids = []

        if self._current_batch_size > 0:
            need_update = False
            sequence_lengths_list = self._sequence_lengths.tolist()
            block_tables_list = self._block_tables[: self._current_batch_size].tolist()
            for batch_id in range(self._current_batch_size):
                # We assume that the start token occupies the entire first block.
                if sequence_lengths_list[batch_id] == start_token_size + generated_token_size + self.block_size - 1:
                    need_update = True
                    sequence_lengths_list[batch_id] = start_token_size + generated_token_size - 1
                    block_id = block_tables_list[batch_id].pop(1)
                    updated_block_ids.append(block_id)
                    block_tables_list[batch_id].append(-1)
            if need_update:
                # Write back in place. Rebuilding the tensors from the sliced list would shrink
                # `_block_tables` to (current_batch_size, ...) and break later row indexing.
                self._sequence_lengths.copy_(torch.tensor(sequence_lengths_list, dtype=self._sequence_lengths.dtype))
                self._block_tables[: self._current_batch_size] = torch.tensor(
                    block_tables_list, dtype=self._block_tables.dtype
                )

        return updated_block_ids

    def set_use_spec_dec(self, num_tokens_to_verify: int = 5) -> None:
        """Set the batch bucket to use speculative decoding.

        This adjusts the lengths of inputs prepared for modeling, so that the main model
        verifies ``num_tokens_to_verify`` tokens in parallel.
        """
        self._use_spec_dec = True
        self._num_tokens_to_verify = num_tokens_to_verify

    def reset_use_spec_dec(self) -> None:
        """Reset the usage of speculative decoding for the batch bucket"""
        self._use_spec_dec = False
        self._num_tokens_to_verify = None

    def _make_compact(self) -> None:
        """Compress the batch based on its sequences dict.

        Moves the rows of the remaining sequences to the front of the length and block-table
        tensors (in dict order), clears the other rows, and re-indexes the sequences.
        """
        # NOTE Prevent calling this method multiple times in a single step
        if self.is_compact:
            return

        valid_seq_ids = list(self._sequences_dict.keys())
        valid_num = len(valid_seq_ids)
        valid_indexes = [self._sequences_indexes[seq_id] for seq_id in valid_seq_ids]
        assert valid_num == len(self._sequences_indexes), "BatchBucket indexing is not consistent"
        # Gather the valid rows into the helpers, then copy the helpers back over the whole tensors
        # so that every row past `valid_num` is reset (0 / -1).
        self._sequence_lengths_helper[:valid_num] = self._sequence_lengths[valid_indexes]
        self._sequence_lengths[:] = self._sequence_lengths_helper[:]
        self._block_tables_helper[:valid_num, :] = self.block_tables[valid_indexes]
        self.block_tables[:] = self._block_tables_helper[:]
        for new_idx, seq_id in enumerate(valid_seq_ids):
            self._sequences_indexes[seq_id] = new_idx

        self._sequence_lengths_helper.fill_(0)
        self._block_tables_helper.fill_(-1)
        self._current_batch_size = valid_num

    def add_seq(
        self,
        seq: Sequence,
        alloc_block_table: torch.Tensor = None,
        alloc_block_table_fn: Callable[[torch.Tensor, int], None] = None,
    ) -> Union[torch.Tensor, None]:
        """Add a single sequence to the batch.

        User could opt to provide either a block table or a function to allocate block tables.

        Args:
            seq (Sequence): The sequence to be added to the batch
            alloc_block_table (torch.Tensor): The block tables to be copied and used for the sequence
            alloc_block_table_fn (Callable[[torch.Tensor, int], None]): The function to allocate blocks for the sequence,
                which is expected to reserve blocks and update status of kv-cache manager.
                It is called with the sequence's block-table row and its length.

        Returns:
            block_table (torch.Tensor): The block table of the added sequence (a view into the bucket's block tables),
                used for block allocation in kv-cache manager.
                None if the sequence cannot be added.
        """
        block_table = None
        # TODO might consider sorting by length
        if self._current_batch_size < self.max_batch_size:
            # The new sequence occupies the first free row; its index mapping, length and
            # block table must all refer to that same row.
            new_idx = self._current_batch_size
            self._sequences_dict[seq.request_id] = seq
            self._sequences_indexes[seq.request_id] = new_idx
            self._sequence_lengths[new_idx] = seq.sentence_len
            # NOTE the added seq still require block table allocation by kvcache manager
            block_table = self._block_tables[new_idx]
            if alloc_block_table is not None:
                # copy block ids from provided block tables
                self._block_tables[new_idx] = alloc_block_table
            elif alloc_block_table_fn:
                alloc_block_table_fn(block_table, self._sequence_lengths[new_idx].item())
            self._current_batch_size += 1
        return block_table

    def add_seqs(
        self,
        seqs: List[Sequence],
        alloc_block_tables: torch.Tensor = None,
        alloc_block_tables_fn: Callable[[torch.Tensor, torch.Tensor], None] = None,
    ) -> Union[torch.Tensor, None]:
        """Add a list of sequences to the batch.

        User could opt to provide either block tables or a function to allocate block tables.
        Only as many sequences as there are free slots are added. The added sequences are removed
        from ``seqs`` in place, so ``seqs`` holds the sequences that did not fit afterwards.

        Args:
            seqs (List[Sequence]): The sequences to be added to the batch (mutated in place)
            alloc_block_tables (torch.Tensor): The block tables to be copied and used for the sequence
            alloc_block_tables_fn (Callable[[torch.Tensor, torch.Tensor], None]): The function to allocate blocks for multiple sequences,
                which is expected to reserve blocks and update status of kv-cache manager.

        Returns:
            block_tables (torch.Tensor): The block tables of the added sequences (a view into the bucket's block tables),
                used for block allocation in kv-cache manager.
                None if the sequences cannot be added.
        """

        assert (
            alloc_block_tables is None or alloc_block_tables_fn is None
        ), "`alloc_block_tables` and `alloc_block_tables_fn` cannot be provided at the same time"

        num_seqs_to_add = min(self.max_batch_size - self._current_batch_size, len(seqs))
        block_tables = None
        if num_seqs_to_add > 0:
            start = self._current_batch_size
            end = start + num_seqs_to_add
            for i, seq in enumerate(seqs[:num_seqs_to_add]):
                self._sequences_dict[seq.request_id] = seq
                self._sequences_indexes[seq.request_id] = start + i
            # TODO external (rename): modify Sequence.sentence_len to seq_len
            self._sequence_lengths[start:end] = torch.tensor(
                [seq.sentence_len for seq in seqs[:num_seqs_to_add]], dtype=torch.int32
            )
            # NOTE block tables to be updated by kvcache manager
            block_tables = self._block_tables[start:end]
            if alloc_block_tables is not None:
                # copy block ids from provided block tables
                self._block_tables[start:end] = alloc_block_tables
            elif alloc_block_tables_fn:
                alloc_block_tables_fn(block_tables, self._sequence_lengths[start:end])

            self._current_batch_size += num_seqs_to_add
            seqs[:] = seqs[num_seqs_to_add:]

        return block_tables

    def pop_seq_update_batch(
        self, request_id: int, free_block_table_fn: Callable[[torch.Tensor], None] = None
    ) -> Tuple[Sequence, Union[torch.Tensor, None]]:
        """Pop a single sequence by id from the batch, and update the batch bucket status.

        To keep the batch compact, the last sequence in the batch is moved into the row
        freed by the popped sequence.

        Args:
            request_id (int): The uid of the sequence
            free_block_table_fn (Callable): The function to free the block table of a sequence,
                if not provided, then we have to release the block table manually after calling this method

        Returns:
            A tuple of: seq (Sequence): The target sequence (None if not found)
            and block_table (torch.Tensor): a copy of the block table of the target sequence indicating corresponding blocks,
                none if the sequence is not found or free_block_table_fn is provided.
        """
        seq: Sequence = self._sequences_dict.get(request_id)
        block_table = None
        if seq is not None:
            assert request_id in self._sequences_indexes, "Inconsistency in BatchBucket indexing"
            self._sequences_dict.pop(request_id)
            seq_b_idx = self._sequences_indexes.get(request_id)

            if self.current_batch_size > 1:
                # replace seq length of the target seq with that of the last seq in the batch.
                # If the target seq is itself the last one, `last_seq_id == request_id` and the
                # moves below degenerate to clearing its row.
                last_seq_b_idx = self.current_batch_size - 1
                last_seq_id = next(
                    (uid for uid, index in self._sequences_indexes.items() if index == last_seq_b_idx),
                    None,
                )
                assert last_seq_id is not None
                self._sequences_indexes[last_seq_id] = seq_b_idx
                self._sequence_lengths[seq_b_idx] = self._sequence_lengths[last_seq_b_idx]
                self._sequence_lengths[last_seq_b_idx].fill_(0)
                # free the block table of the seq, or return a copy of the block table (to be processed outside)
                if free_block_table_fn:
                    free_block_table_fn(self._block_tables[seq_b_idx])
                else:
                    block_table = self._block_tables[seq_b_idx].detach().clone()
                # replace block table of the target seq with that of the last seq in the batch
                self._block_tables[seq_b_idx] = self._block_tables[last_seq_b_idx]
                self._block_tables[last_seq_b_idx].fill_(-1)
            else:
                if free_block_table_fn:
                    free_block_table_fn(self._block_tables[0])
                else:
                    block_table = self._block_tables[0].detach().clone()
                self._sequence_lengths[0].fill_(0)
                self._block_tables[0].fill_(-1)
            self._sequences_indexes.pop(request_id)
            self._current_batch_size -= 1

        return seq, block_table

    def pop_seqs(
        self, request_ids: List[int], free_block_table_fn: Callable[[torch.Tensor], None] = None
    ) -> Tuple[List[Sequence], List[torch.Tensor]]:
        """Iteratively pop a list of sequences by uid.

        Args:
            request_ids (List[int]): The uids of the sequences
            free_block_table_fn (Callable): The function to free the block table of a sequence,
                if not provided, then we have to release the block table manually after calling this method
        Returns:
            A tuple of: seqs (List[Sequence]): The target sequences (uids not in the batch are skipped)
            and block_tables (List[torch.Tensor]): block tables of the target sequences indicating corresponding blocks
        """
        seqs = []
        block_tables = []
        for request_id in request_ids:
            seq, block_table = self.pop_seq_update_batch(request_id, free_block_table_fn)
            if seq is not None:
                seqs.append(seq)
            if block_table is not None:
                block_tables.append(block_table)
        return seqs, block_tables

    def pop_n_seqs(
        self, n: int, free_block_table_fn: Callable[[torch.Tensor], None] = None
    ) -> Tuple[List[Sequence], List[torch.Tensor]]:
        """Pop the first n sequences in the batch (FIFO).
        If n is greater than the current batch size, pop all the sequences in the batch.

        Args:
            n (int): The number of sequences to pop out
            free_block_table_fn (Callable): The function to free the block table of a single sequence
        Returns:
            A tuple of: seqs (List[Sequence]): The target sequences,
            and block_tables (List[torch.Tensor]): copies of the block tables of the target sequences
                (empty if free_block_table_fn is provided)
        """
        # NOTE Prevent calling this method multiple times in a single step
        seqs = []
        block_tables = []
        n = min(n, self.current_batch_size)
        seq_ids = list(self._sequences_dict.keys())[:n]
        for seq_id in seq_ids:
            seq = self._sequences_dict.pop(seq_id)
            seq_b_idx = self._sequences_indexes.pop(seq_id)
            if free_block_table_fn:
                free_block_table_fn(self.block_tables[seq_b_idx])
            else:
                block_tables.append(self.block_tables[seq_b_idx].detach().clone())
            seqs.append(seq)
        # Rows of the popped sequences are left in place above; compaction clears them and
        # updates `_current_batch_size`.
        if not self.is_compact:
            self._make_compact()

        return seqs, block_tables

    def pop_finished(
        self, free_block_table_fn: Callable[[torch.Tensor], None] = None
    ) -> Tuple[List[Sequence], List[torch.Tensor]]:
        """Pop finished sequences in the batch and a list of block tables of the finished sequences,
        if free_block_table_fn is not provided.

        Args:
            free_block_table_fn (Callable): The function to free the block table of a single sequence
        Returns:
            A tuple of: finished_seqs (List[Sequence]): The finished sequences,
            and finished_block_tables (List[torch.Tensor]): block tables of the finished sequences.
        """
        finished_seqs = []
        finished_block_tables = []
        for seq in self._sequences_dict.values():
            if seq.check_finish():
                finished_seqs.append(seq)
        # Use `pop_seq_update_batch` to update the batch status for just a few of finished seqs,
        # otherwise, pop seqs directly and then call `_make_compact` to compress the batch.
        # For now, the performance difference is not significant, so we use the first method to pop seqs.
        # Precise evaluations to be done.
        for seq in finished_seqs:
            _, block_table = self.pop_seq_update_batch(seq.request_id, free_block_table_fn)
            if block_table is not None:
                finished_block_tables.append(block_table)

        return finished_seqs, finished_block_tables

    # TODO arg type not support beam search sampling yet
    def append_batch_tokens(self, tokens: torch.Tensor) -> None:
        """Append a batch of tokens to the sequences in the batch.

        ``tokens`` is indexed by each sequence's row in the batch; each row may hold a single
        token or a list of tokens. The length of every active row is increased by 1.
        """
        assert self.current_batch_size == tokens.size(0), "Batch size mismatch"

        if self.current_batch_size > 0:
            tokens = tokens.tolist()
            for seq_id, seq in self._sequences_dict.items():
                index_in_b = self._sequences_indexes[seq_id]
                curr_tokens = tokens[index_in_b]
                if not isinstance(curr_tokens, list):
                    curr_tokens = [curr_tokens]
                seq.output_token_id += curr_tokens
                seq.check_finish()
            self._sequence_lengths[: self.current_batch_size] += 1

    def revoke_batch_tokens(self, n_tokens: int, n_seqs: int = 1) -> None:
        """Revoke the last n output tokens of the sequences in the batch

        Args:
            n_tokens (int): The number of output tokens to revoke from each sequence.
                It does not count in the context tokens (input tokens).
            n_seqs (int): The first n sequences (in insertion order) to revoke tokens from. Defaults to 1.
                For now, speculative decoding only supports batch size 1.
        """
        if n_tokens >= 1:
            seqs_iter = iter(self._sequences_dict.items())
            for _ in range(n_seqs):
                seq_id, seq = next(seqs_iter)
                assert seq.output_len >= n_tokens, "Revoking len exceeds the current output len of the sequence"
                seq.output_token_id = seq.output_token_id[:-n_tokens]
                seq.revoke_finished_status()
                self._sequence_lengths[self._sequences_indexes[seq_id]] -= n_tokens

    def clear(self, free_block_tables_fn: Optional[Callable[[torch.Tensor, int], None]] = None) -> List[Sequence]:
        """Clear all the sequences in the batch.

        Args:
            free_block_tables_fn (Optional[Callable]): The function to free the block tables of all the sequences in a batch.
                It is called with the full block-table tensor and the current batch size.

        Returns:
            List[Sequence]: The sequences that were in the batch.
        """
        seqs = list(self._sequences_dict.values())
        self._sequences_dict.clear()
        self._sequences_indexes.clear()
        if free_block_tables_fn:
            free_block_tables_fn(self.block_tables, self._current_batch_size)
        self._block_tables.fill_(-1)
        self._sequence_lengths.fill_(0)
        self._current_batch_size = 0
        return seqs

    def merge(self, other: "BatchBucket") -> List[int]:
        """Merge the sequences in the other batch into the current batch.

        Merges as many sequences as the current batch can hold, if it does not have enough
        free slots for all the sequences in the other batch.

        Usage:
            > New incoming sequence added to prefill batch
                prefill bb curr batch size < prefill_ratio * prefill bb max batch size
            > New incoming sequence added to prefill batch
                prefill bb curr batch size == prefill_ratio * prefill bb max batch size
            > Pause Decoding
            > Prefill
            > Move sequences in prefill bb => decoding bb
            > Put back the out-of-volume sequences into the running pool

        Returns:
            unmerged_ids (List[int]): a list of sequence uids that are not merged into the current batch
        """
        num_seqs_to_merge = min(self.available_batch_size, other.current_batch_size)
        if num_seqs_to_merge > 0:
            seqs, block_tables_li = other.pop_n_seqs(num_seqs_to_merge)
            block_tables = torch.stack(block_tables_li)
            self.add_seqs(seqs, alloc_block_tables=block_tables)

        # Whatever is still in `other` was not merged -- including all of it when this batch is full.
        return other.seqs_ids

    ########## The following methods are expected to be used in modeling ###########

    # For compatibility.
    # NOTE: This is an assumption way to determine the stage of the batch.
    @property
    def is_prompts(self) -> bool:
        """True if the first sequence in the batch has no output tokens yet (prefill stage)."""
        assert len(self._sequences_dict) > 0, "No sequence in the batch"
        first_seq = next(iter(self._sequences_dict.values()))
        if first_seq.output_len == 0:
            return True
        return False

    def get_1D_inputs_spec_dec(self, n: int) -> torch.Tensor:
        """Flattened last ``n`` output tokens of every sequence, ordered by batch index."""
        # Used for main model verification in **Decoding Stage**
        # `n` is the number of tokens to be verified,
        # and so that prepare the last `n` tokens of each sequence as the inputs
        assert len(self._sequences_dict) > 0, "No sequence in the batch"
        assert all(
            seq.output_len >= n for seq in self._sequences_dict.values()
        ), "Sequence output tokens must be greater than or equal to the number of tokens to be verified."
        out_li = []
        seq_ids = sorted(self._sequences_indexes.keys(), key=lambda x: self._sequences_indexes[x])
        for seq_id in seq_ids:
            seq: Sequence = self._sequences_dict[seq_id]
            out_li.extend(seq.output_token_id[-n:])
        return torch.tensor(out_li, dtype=torch.long, device=self.device)

    # For compatibility
    def get_1D_inputs(self) -> torch.Tensor:
        """Flattened model inputs for the current stage, ordered by batch index.

        Prefill: all input tokens of every sequence. Decoding: the last output token of every
        sequence, or the last ``num_tokens_to_verify + 1`` tokens when speculative decoding is on.
        """
        assert len(self._sequences_dict) > 0, "No sequence in the batch"
        first_seq = next(iter(self._sequences_dict.values()))  # not exactly the first sequence
        if first_seq.output_len == 0:
            # Assume prefill stage
            assert all(
                seq.output_len == 0 for seq in self._sequences_dict.values()
            ), "Sequence stage (Prefill/Decoding) must be the same in the batch"
            out_li = []
            seq_ids = sorted(self._sequences_indexes.keys(), key=lambda x: self._sequences_indexes[x])
            for seq_id in seq_ids:
                seq: Sequence = self._sequences_dict[seq_id]
                out_li.extend(seq.input_token_id)
            return torch.tensor(out_li, dtype=torch.long, device=self.device)
        else:
            # Assume decoding stage
            if self.use_spec_dec:
                # For Speculative Decoding
                # the number of tokens to be verified in parallel plus the correct token in the last step
                return self.get_1D_inputs_spec_dec(self.num_tokens_to_verify + 1)
            assert all(
                seq.output_len > 0 for seq in self._sequences_dict.values()
            ), "Sequence stage (Prefill/Decoding) must be the same in the batch"
            assert self.is_compact, "BatchBucket is not compact"
            out = torch.empty([self.current_batch_size], dtype=torch.long)
            for seq_id, index_in_b in self._sequences_indexes.items():
                seq: Sequence = self._sequences_dict[seq_id]
                out[index_in_b] = seq.output_token_id[-1]
            return out.to(device=self.device)

    # For compatibility
    def get_block_table_tensor(self) -> torch.Tensor:
        """Block tables of the active rows, moved to ``self.device``."""
        assert self.is_compact  # Debug usage
        block_table = self.block_tables[: self.current_batch_size]
        return block_table.to(device=self.device)

    # For compatibility
    def get_sequence_lengths(self) -> torch.Tensor:
        """Sequence lengths of the active rows, moved to ``self.device``."""
        assert self.is_compact  # Debug usage
        sequence_lengths = self.seq_lengths[: self.current_batch_size]
        return sequence_lengths.to(device=self.device)

    # For compatibility
    @property
    def fd_inter_tensor(self):
        """The flash-decoding intermediate buffer passed at construction (must not be None)."""
        assert self.fd_interm_tensor is not None, "fd_interm_tensor is not provided"
        return self.fd_interm_tensor

    def __repr__(self) -> str:
        # `is_prompts` asserts on an empty batch; avoid it so repr() is always safe to call.
        is_prompts = self.is_prompts if self._sequences_dict else None
        return f"(sequences_dict={self._sequences_dict}, is_prompts={is_prompts})"