diff --git a/colossalai/inference/batch_bucket.py b/colossalai/inference/batch_bucket.py new file mode 100644 index 000000000..93d4c2004 --- /dev/null +++ b/colossalai/inference/batch_bucket.py @@ -0,0 +1,449 @@ +from typing import Callable, List, Optional, Tuple, Union + +import torch + +from colossalai.inference.struct import Sequence +from colossalai.utils import get_current_device + + +class BatchBucket: + """Container for a batch of Sequences, which is used to manage the batch of sequences. + + Attrs: + _sequences_dict (Dict[int, Sequence]): Map sequence uid to sequence struct + seq_uid -> Sequence + _sequences_indexes (Dict[int, int]): Map sequence uid to index in the batch + seq_uid -> index in the batch (indexing used in sequence_lengths and block_tables) + _sequence_lengths (torch.Tensor): Length of each sequence in the batch. + The size of the tensor is (max_batch_size,) + _block_tables (torch.Tensor): Block table of each sequence in the batch + The size of the tensor is (max_batch_size, max_blocks_per_seq) + """ + + def __init__( + self, + num_heads, + head_dim, + max_batch_size, + max_length, + block_size, + kv_max_split_num, + fd_interm_tensor=None, + device=None, + dtype=torch.float16, + ): + self.num_heads = num_heads + self.head_dim = head_dim + self.max_batch_size = max_batch_size + self.max_length = max_length # in + out len + self.block_size = block_size + self.kv_max_split_num = kv_max_split_num # Hint used for flash decoding + self.fd_interm_tensor = fd_interm_tensor + self.device = device or get_current_device() + self.dtype = dtype + + self._current_batch_size = 0 + self._sequences_dict = dict() + self._sequences_indexes = dict() # deque(maxlen=self.max_batch_size) + self._sequence_lengths = torch.zeros((self.max_batch_size,), dtype=torch.int32) + self._sequence_lengths_helper = torch.zeros_like(self._sequence_lengths) + max_blocks_per_seq = (self.max_length + block_size - 1) // block_size + self._block_tables = torch.full((self.max_batch_size, max_blocks_per_seq), -1, dtype=torch.int32) + self._block_tables_helper = torch.full_like(self._block_tables, -1) + + @property + def is_empty(self): + return self._current_batch_size == 0 + + @property + def current_batch_size(self): + return self._current_batch_size + + @property + def available_batch_size(self): + return self.max_batch_size - self._current_batch_size + + @property + def block_tables(self): + return self._block_tables + + @property + def seq_lengths(self): + return self._sequence_lengths + + @property + def seqs_ids(self): + return list(self._sequences_dict.keys()) + + @property + def seqs_li(self): + return list(self._sequences_dict.values()) + + @property + def is_compact(self): + assert len(self._sequences_dict) == len(self._sequences_indexes), "BatchBucket indexing is not consistent" + return ( + len(self._sequences_dict) + == torch.nonzero(self._sequence_lengths).view(-1).numel() + == torch.nonzero(self._block_tables[:, 0] >= 0).numel() + ) + + def _make_compact(self) -> None: + # Clean and Compress the batch based on its sequences dict. + # Namely,compress sequences to the front and clean the seq lengths and block tables tensors. 
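A minimal usage sketch (not part of the patch), assuming arbitrary sizes; it only illustrates the constructor arguments and the invariant behind `is_compact`: live sequences always occupy the leading rows of `seq_lengths` and `block_tables`.

import torch
from colossalai.inference.batch_bucket import BatchBucket

bb = BatchBucket(
    num_heads=4,
    head_dim=32,
    max_batch_size=8,
    max_length=40,       # max_input_len + max_output_len
    block_size=4,
    kv_max_split_num=2,  # hint for flash-decoding intermediate tensors
)
assert bb.is_empty and bb.is_compact
assert bb.available_batch_size == 8
# rows stay unallocated (-1) and lengths stay 0 until sequences are admitted
assert torch.all(bb.block_tables == -1) and torch.all(bb.seq_lengths == 0)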
+ # NOTE Prevent calling this method multiple times in a single step + if self.is_compact: + return + valid_seq_ids = self._sequences_dict.keys() + valid_num = len(valid_seq_ids) + valid_indexes = [self._sequences_indexes[seq_id] for seq_id in valid_seq_ids] + assert valid_num == len(self._sequences_indexes), "BatchBucket indexing is not consistent" + self._sequence_lengths_helper[:valid_num] = self._sequence_lengths[valid_indexes] + self._sequence_lengths[:] = self._sequence_lengths_helper[:] + self._block_tables_helper[:valid_num, :] = self.block_tables[valid_indexes] + self.block_tables[:] = self._block_tables_helper[:] + new_idx = 0 + for seq_id in valid_seq_ids: + self._sequences_indexes[seq_id] = new_idx + new_idx += 1 + self._sequence_lengths_helper.fill_(0) + self._block_tables_helper.fill_(-1) + self._current_batch_size = valid_num + + def add_seq( + self, + seq: Sequence, + alloc_block_table: torch.Tensor = None, + alloc_block_table_fn: Callable[[torch.Tensor, int], None] = None, + ) -> Union[torch.Tensor, None]: + """Add a single sequence to the batch. + User could opt to provide either a block table or a function to allocate block tables. + + Args: + seq (Sequence): The sequence to be added to the batch + alloc_block_table (torch.Tensor): The block tables to be copied and used for the sequence + alloc_block_table_fn (Callable[[torch.Tensor, int], None]): The function to allocate blocks for the sequence, + which is expected to reserve blocks and update status of kv-cache manager. + + Returns: + block_table (torch.Tensor): The block table of the added sequence, used for block allocation in kv-cache manager. + None if the sequence cannot be added. + """ + block_table = None + # TODO might consider sorting by length + if self._current_batch_size < self.max_batch_size: + self._sequences_dict[seq.request_id] = seq + self._sequences_indexes[seq.request_id] = self._current_batch_size + self._sequence_lengths[self._current_batch_size] = seq.sentence_len + # NOTE the added seq still require block table allocation by kvcache manager + block_table = self._block_tables[self._current_batch_size - 1] + if alloc_block_table is not None: + # copy block ids from provided block tables + self._block_tables[self._current_batch_size - 1] = alloc_block_table + elif alloc_block_table_fn: + alloc_block_table_fn(block_table, self._sequence_lengths[self._current_batch_size - 1].item()) + self._current_batch_size += 1 + return block_table + + def add_seqs( + self, + seqs: List[Sequence], + alloc_block_tables: torch.Tensor = None, + alloc_block_tables_fn: Callable[[torch.Tensor, torch.Tensor], None] = None, + ) -> Union[torch.Tensor, None]: + """Add a list of sequences to the batch. + User could opt to provide either block tables or a function to allocate block tables. + + Args: + seqs (List[Sequence]): The sequences to be added to the batch + alloc_block_tables (torch.Tensor): The block tables to be copied and used for the sequence + alloc_block_table_fn (Callable[[torch.Tensor, torch.Tensor], None]): The function to allocate blocks for multiple sequences, + which is expected to reserve blocks and update status of kv-cache manager. + + Returns: + block_tables (torch.Tensor): The block tables of the added sequences, used for block allocation in kv-cache manager. + None if the sequences cannot be added. 
+ """ + + assert ( + alloc_block_tables is None or alloc_block_tables_fn is None + ), "`alloc_block_tables` and `alloc_block_tables_fn` cannot be provided at the same time" + + num_seqs_to_add = min(self.max_batch_size - self._current_batch_size, len(seqs)) + block_tables = None + if num_seqs_to_add > 0: + for i, seq in enumerate(seqs[:num_seqs_to_add]): + self._sequences_dict[seq.request_id] = seq + self._sequences_indexes[seq.request_id] = self._current_batch_size + i + # TODO external (rename): modify Sequence.sentence_len to seq_len + self._sequence_lengths[ + self._current_batch_size : self._current_batch_size + num_seqs_to_add + ] = torch.tensor([seq.sentence_len for seq in seqs[:num_seqs_to_add]], dtype=torch.int32) + # NOTE block tables to be updated by kvcache manager + block_tables = self._block_tables[self._current_batch_size : self._current_batch_size + num_seqs_to_add] + if alloc_block_tables is not None: + # copy block ids from provided block tables + self._block_tables[ + self._current_batch_size : self._current_batch_size + num_seqs_to_add + ] = alloc_block_tables + elif alloc_block_tables_fn: + alloc_block_tables_fn( + block_tables, + self._sequence_lengths[self._current_batch_size : self._current_batch_size + num_seqs_to_add], + ) + + self._current_batch_size += num_seqs_to_add + seqs[:] = seqs[num_seqs_to_add:] + + return block_tables + + def pop_seq_update_batch( + self, request_id: int, free_block_table_fn: Callable[[torch.Tensor], None] = None + ) -> Tuple[Sequence, Union[torch.Tensor, None]]: + """Pop a single sequence by id from the batch, and update the batch bucket status. + + Args: + request_id (int): The uid of the sequence + free_block_table_fn (Callable): The function to free the block table of a sequence, + if not provided, then we have to release the block table manually after calling this method + + Returns: + A tuple of: seq (Sequence): The target sequence + and block_table (torch.Tensor): block table of the target sequence indicating corresponding blocks, + none if the sequence is not found or free_block_table_fn is provided. 
+ """ + seq: Sequence = self._sequences_dict.get(request_id) + block_table = None + if seq is not None: + assert request_id in self._sequences_indexes, "Inconsistency in BatchBucket indexing" + self._sequences_dict.pop(request_id) + seq_b_idx = self._sequences_indexes.get(request_id) + + if self.current_batch_size > 1: + # replace seq length of the target seq with that of the last seq in the batch + last_seq_b_idx = self.current_batch_size - 1 + last_seq_id = next( + (uid for uid, index in self._sequences_indexes.items() if index == last_seq_b_idx), + None, + ) + assert last_seq_id is not None + self._sequences_indexes[last_seq_id] = seq_b_idx + self._sequence_lengths[seq_b_idx] = self._sequence_lengths[last_seq_b_idx] + self._sequence_lengths[last_seq_b_idx].fill_(0) + # free the block table of the seq, or return a copy of the block table (to be processed outside) + if free_block_table_fn: + free_block_table_fn(self._block_tables[seq_b_idx]) + else: + block_table = self._block_tables[seq_b_idx].detach().clone() + # replace block table of the target seq with that of the last seq in the batch + self._block_tables[seq_b_idx] = self._block_tables[last_seq_b_idx] + self._block_tables[last_seq_b_idx].fill_(-1) + else: + if free_block_table_fn: + free_block_table_fn(self._block_tables[0]) + else: + block_table = self._block_tables[0].detach().clone() + self._sequence_lengths[0].fill_(0) + self._block_tables[0].fill_(-1) + self._sequences_indexes.pop(request_id) + self._current_batch_size -= 1 + + return seq, block_table + + def pop_seqs( + self, request_ids: List[int], free_block_table_fn: Callable[[torch.Tensor], None] = None + ) -> Tuple[List[Sequence], List[torch.Tensor]]: + """Iteratively pop a list of sequences by uid. + + Args: + request_ids (List[int]): The uids of the sequences + free_block_table_fn (Callable): The function to free the block table of a sequence, + if not provided, then we have to release the block table manually after calling this method + Returns: + A tuple of: seqs (List[Sequence]): The target sequences + and block_tables (List[torch.Tensor]): block tables of the target sequences indicating corresponding blocks + """ + seqs = [] + block_tables = [] + for request_id in request_ids: + seq, block_table = self.pop_seq_update_batch(request_id, free_block_table_fn) + if seq is not None: + seqs.append(seq) + if block_table is not None: + block_tables.append(block_table) + return seqs, block_tables + + def pop_n_seqs( + self, n: int, free_block_table_fn: Callable[[torch.Tensor], None] = None + ) -> Tuple[List[Sequence], List[torch.Tensor]]: + """Pop the first n sequences in the batch (FIFO). + If n is greater than the current batch szie, pop all the sequences in the batch. 
+ + Args: + n (int): The number of sequences to pop out + free_block_table_fn (Callable): The function to free the block table of a single sequence + Returns: + A tuple of: seqs (List[Sequence]): The target sequences, + and block_tables (List[torch.Tensor]): block tables of the target sequences indicating corresponding blocks + """ + # NOTE Prevent calling this method multiple times in a single step + seqs = [] + block_tables = [] + n = min(n, self.current_batch_size) + seq_ids = list(self._sequences_dict.keys())[:n] + for seq_id in seq_ids: + seq = self._sequences_dict.pop(seq_id) + seq_b_idx = self._sequences_indexes.pop(seq_id) + if free_block_table_fn: + free_block_table_fn(self.block_tables[seq_b_idx]) + else: + block_tables.append(self.block_tables[seq_b_idx].detach().clone()) + seqs.append(seq) + if not self.is_compact: + self._make_compact() + return seqs, block_tables + + def pop_finished( + self, free_block_table_fn: Callable[[torch.Tensor], None] = None + ) -> Tuple[List[Sequence], List[torch.Tensor]]: + """Pop finished sequences in the batch and a list of block tables of the finished sequences, + if free_block_table_fn is not provided. + + Args: + free_block_table_fn (Callable): The function to free the block table of a single sequence + Returns: + A tuple of: finished_seqs (List[Sequence]): The finished sequences, + and finished_block_tables (List[torch.Tensor]): block tables of the finished sequences. + """ + finished_seqs = [] + finished_block_tables = [] + for seq in self._sequences_dict.values(): + if seq.check_finish(): + finished_seqs.append(seq) + # Use `pop_seq_update_batch`` to update the batch status for just a few of finished seqs, + # otherwise, pop seqs directly and then call `_make_compact` to compress the batch. + # For now, the performance difference is not significant, so we use the frist method to pop seqs. + # Precise evaluations to be done. + for seq in finished_seqs: + _, block_table = self.pop_seq_update_batch(seq.request_id, free_block_table_fn) + if block_table is not None: + finished_block_tables.append(block_table) + + return finished_seqs, finished_block_tables + + # TODO arg type not support beam search sampling yet + def append_batch_tokens(self, tokens: torch.Tensor) -> None: + """Append a batch of tokens to the sequences in the batch""" + assert self.current_batch_size == tokens.size(0), "Batch size mismatch" + + if self.current_batch_size > 0: + tokens = tokens.tolist() + for seq_id, seq in self._sequences_dict.items(): + index_in_b = self._sequences_indexes[seq_id] + curr_tokens = tokens[index_in_b] + if not isinstance(curr_tokens, list): + curr_tokens = [curr_tokens] + seq.output_token_id += curr_tokens + seq.check_finish() + self._sequence_lengths[: self.current_batch_size] += 1 + + def clear(self, free_block_tables_fn: Optional[Callable[[torch.Tensor], None]]) -> List[int]: + """Clear all the sequences in the batch. + + free_block_tables_fn (Optional[Callable]): The function to free the block tables of all the sequences in a batch + """ + seqs = list(self._sequences_dict.values()) + self._sequences_dict.clear() + self._sequences_indexes.clear() + if free_block_tables_fn: + free_block_tables_fn(self.block_tables, self._current_batch_size) + self._block_tables.fill_(-1) + self._sequence_lengths.fill_(0) + self._current_batch_size = 0 + return seqs + + def merge(self, other: "BatchBucket") -> List[int]: + """Merge the sequences in the other batch into the current batch. 
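A sketch (not part of the patch) of one decoding step at the bucket level; `sample_tokens` is a dummy stand-in for the sampler output, one token id per live sequence.

sample_tokens = torch.full((bb.current_batch_size,), 99, dtype=torch.long)
bb.append_batch_tokens(sample_tokens)    # every sequence and its recorded length grow by one
finished, _ = bb.pop_finished(free_block_table_fn=cache_manager.free_block_table)
for seq in finished:
    print(f"request {seq.request_id} finished after generating {seq.output_len} token(s)")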
+ Merge as possible as the current batch can, if it does not have available spaces + holding all the sequences in the other batch + + Usage: + > New incoming sequence added to prefil batch + prefill bb curr batch size < prefil_ratio * prefill bb max batch size + > New incoming sequence added to prefil batch + prefill bb curr batch size == prefil_ratio * prefill bb max batch size + > Pause Decoding + > Prefill + > Move sequences in prefill bb => decoding bb + > Put back the out-of-volume sequences into the running pool + + Returns: + unmerged_ids (List[int]): a list of sequence uids that are not merged into the current batch + """ + unmerged_ids = [] + num_seqs_to_merge = min(self.available_batch_size, other.current_batch_size) + if num_seqs_to_merge > 0: + seqs, block_tables_li = other.pop_n_seqs(num_seqs_to_merge) + block_tables = torch.stack(block_tables_li) + self.add_seqs(seqs, alloc_block_tables=block_tables) + unmerged_ids = other.seqs_ids + return unmerged_ids + + ########## The following methods are expected to be used in modeling ########### + + # For compatibility. + # NOTE: This is an assumption way to determine the stage of the batch. + @property + def is_prompts(self) -> bool: + assert len(self._sequences_dict) > 0, "No sequence in the batch" + first_seq = next(iter(self._sequences_dict.values())) + if first_seq.output_len == 0: + return True + return False + + # For compatibility + def get_1D_inputs(self) -> torch.Tensor: + assert len(self._sequences_dict) > 0, "No sequence in the batch" + first_seq = next(iter(self._sequences_dict.values())) # not exactly the first sequence + if first_seq.output_len == 0: + # Assume prefill stage + assert all( + seq.output_len == 0 for seq in self._sequences_dict.values() + ), "Sequence stage (Prefill/Decoding) must be the same in the batch" + out_li = [] + num_tokens = torch.sum(self._sequence_lengths) + out = torch.empty([num_tokens], dtype=torch.long) + seq_ids = sorted(self._sequences_indexes.keys(), key=lambda x: self._sequences_indexes[x]) + for seq_id in seq_ids: + seq: Sequence = self._sequences_dict[seq_id] + out_li.extend(seq.input_token_id) + return torch.tensor(out_li, dtype=torch.long, device=self.device) + else: + # Assume decoding stage + assert all( + seq.output_len > 0 for seq in self._sequences_dict.values() + ), "Sequence stage (Prefill/Decoding) must be the same in the batch" + assert self.is_compact, "BatchBucket is not compact" + out = torch.empty([self.current_batch_size], dtype=torch.long) + for seq_id, index_in_b in self._sequences_indexes.items(): + seq: Sequence = self._sequences_dict[seq_id] + out[index_in_b] = seq.output_token_id[-1] + return out.to(device=self.device) + + # For compatibility + def get_block_table_tensor(self) -> torch.Tensor: + assert self.is_compact # Debug usage + block_table = self.block_tables[: self.current_batch_size] + return block_table.to(device=self.device) + + # For compatibility + def get_sequence_lengths(self) -> torch.Tensor: + assert self.is_compact # Debug usage + sequence_lengths = self.seq_lengths[: self.current_batch_size] + return sequence_lengths.to(device=self.device) + + # For compatibility + @property + def fd_inter_tensor(self) -> None: + assert self.fd_interm_tensor is not None, "fd_interm_tensor is not provided" + return self.fd_interm_tensor diff --git a/colossalai/inference/config.py b/colossalai/inference/config.py index a210fbf64..7ce4719e7 100644 --- a/colossalai/inference/config.py +++ b/colossalai/inference/config.py @@ -109,7 +109,7 @@ class InferenceConfig: ), 
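A sketch (not part of the patch) of the BatchBucket accessors the model forward relies on: during prefill `get_1D_inputs` returns all prompt tokens concatenated, while during decoding it returns exactly one token id per sequence.

input_ids = bb.get_1D_inputs()
if bb.is_prompts:    # no sequence has generated output yet
    assert input_ids.numel() == int(bb.seq_lengths[: bb.current_batch_size].sum())
else:
    assert input_ids.numel() == bb.current_batch_size
block_tables = bb.get_block_table_tensor()    # [current_batch_size, max_blocks_per_seq]
seq_lengths = bb.get_sequence_lengths()       # [current_batch_size]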
f"Expected dtype to be in {_ALLOWED_DTYPES} but found an unknown dtype: {self.dtype}" # check distributed - assert ( + assert (not torch.distributed.is_initialized() and self.tp_size * self.pp_size == 1) or ( self.tp_size * self.pp_size == dist.get_world_size() ), f"TP size({self.tp_size}) * PP size({self.pp_size}) should be equal to the global world size ({dist.get_world_size()})" # check prompt template diff --git a/colossalai/inference/core/engine.py b/colossalai/inference/core/engine.py index bd078dbd5..ea2e341d4 100644 --- a/colossalai/inference/core/engine.py +++ b/colossalai/inference/core/engine.py @@ -42,7 +42,7 @@ class InferenceEngine: def __init__( self, model: nn.Module, - tokenizer: [Union[PreTrainedTokenizer, PreTrainedTokenizerFast]], + tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], inference_config: InferenceConfig, verbose: bool = False, model_policy: Policy = None, @@ -254,20 +254,12 @@ class InferenceEngine: else: prompt = prompts[i] - max_blocks_per_sequence = ( - self.inference_config.max_input_len - + self.inference_config.max_output_len - + self.inference_config.block_size - - 1 - ) // self.inference_config.block_size - block_table = torch.full([max_blocks_per_sequence], -1, device=self.device) sequence = Sequence( request_id, prompt, prompts_token_ids[i], block_size, None, - block_table, self.tokenizer.eos_token_id, self.tokenizer.pad_token_id, self.inference_config.max_output_len, diff --git a/colossalai/inference/core/request_handler.py b/colossalai/inference/core/request_handler.py index 7e66cfe31..a331e9cf8 100644 --- a/colossalai/inference/core/request_handler.py +++ b/colossalai/inference/core/request_handler.py @@ -1,15 +1,16 @@ -from typing import List +from typing import Dict, List, Union import torch from transformers.configuration_utils import PretrainedConfig from transformers.generation import GenerationConfig +from colossalai.inference.batch_bucket import BatchBucket from colossalai.inference.config import InferenceConfig from colossalai.inference.flash_decoding_utils import FDIntermTensors from colossalai.inference.kv_cache import KVCacheManager from colossalai.inference.logit_processors import logit_processor from colossalai.inference.sampler import * -from colossalai.inference.struct import BatchInfo, RequestStatus, Sequence +from colossalai.inference.struct import RequestStatus, Sequence from colossalai.logging import get_dist_logger __all__ = ["RunningList", "RequestHandler"] @@ -24,45 +25,79 @@ class RunningList: Args: prefill_ratio: (float) A ratio for determing whether to perform prefill or not. - prefill: (List) List that contains default inputs, defaults to []. + _prefill (OrderedDict[Sequence]): Mapping of sequence uid -> Sequence. + _decoding (OrderedDict[Sequence]): Mapping of sequence uid -> Sequence. """ - def __init__(self, prefill_ratio: str, prefill: List[Sequence] = None): + def __init__(self, prefill_ratio: int, prefill: List[Sequence] = None) -> None: self.prefill_ratio = prefill_ratio - self.decoding: List[Sequence] = [] - self.prefill: List[Sequence] = prefill if prefill is not None else [] + self._decoding: Dict[int, Sequence] = dict() + self._prefill: Dict[int, Sequence] = ( + dict({seq.request_id: seq for seq in self._prefill}) if prefill is not None else dict() + ) - def append(self, seq: Sequence): - # add seq to prefilling list first. 
- self.prefill.append(seq) - - def find_seq(self, request_id): - for seq in self.decoding: - if request_id == seq.request_id: - return seq - for seq in self.prefill: - if request_id == seq.request_id: - return seq - return None + @property + def decoding(self): + return list(self._decoding.values()) + + @property + def prefill(self): + return list(self._prefill.values()) + + @property + def prefill_seq_num(self): + return len(self._prefill) + + @property + def decoding_seq_num(self): + return len(self._decoding) + + @property + def total_seq_num(self): + return self.prefill_seq_num + self.decoding_seq_num - def remove(self, seq: Sequence): - if seq in self.decoding: - self.decoding.remove(seq) - elif seq in self.prefill: - self.prefill.remove(seq) + def append(self, seq: Sequence): + assert (seq.request_id not in self._prefill) and ( + seq.request_id not in self._decoding + ), f"Sequence uid {seq.request_id} already exists." + self._prefill[seq.request_id] = seq + + def extend(self, seqs: List[Sequence]): + for seq in seqs: + self._prefill[seq.request_id] = seq + + def find_seq(self, request_id) -> Union[Sequence, None]: + seq = None + if request_id in self._decoding: + seq = self._decoding[request_id] + elif request_id in self._prefill: + seq = self._prefill[request_id] + return seq + + def remove(self, seq: Sequence) -> None: + if seq.request_id in self._decoding: + self._decoding.pop(seq.request_id) + elif seq.request_id in self._prefill: + self._prefill.pop(seq.request_id) else: - raise ValueError(f"sequence {seq.request_id} is not in running list") + raise ValueError(f"Sequence {seq.request_id} is not in running list") def ready_for_prefill(self): - if not self.decoding: - return len(self.prefill) > 0 - return len(self.prefill) / len(self.decoding) >= self.prefill_ratio + if not self._decoding: + return len(self._prefill) > 0 + return len(self._prefill) / len(self._decoding) >= self.prefill_ratio def is_empty(self): - return not self.decoding and not self.prefill + return not self._decoding and not self._prefill - def total_seq_num(self): - return len(self.decoding) + len(self.prefill) + def mark_prefill_running(self) -> None: + for seq_id in self._prefill: + self._prefill[seq_id].mark_running() + + def move_prefill_to_decoding(self, seq_ids: List[int]) -> None: + for seq_id in seq_ids: + assert seq_id in self._prefill, f"Sequence {seq_id} is not in prefill list" + self._decoding[seq_id] = self._prefill.pop(seq_id) class RequestHandler: @@ -110,25 +145,27 @@ class RequestHandler: # TODO In the continuous batching scenario, the batch size may be greater than max_batch_size, # which may cause bugs and this issue should be fixed later. 
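A sketch (not part of the patch) of the reworked RunningList lifecycle, with `seq_a` and `seq_b` as hypothetical Sequence objects: sequences enter through the prefill dict and are later moved to the decoding dict by uid, mirroring what the request handler does around a prefill step.

from colossalai.inference.core.request_handler import RunningList

rl = RunningList(prefill_ratio=1.2)
rl.extend([seq_a, seq_b])
assert rl.ready_for_prefill()        # nothing is decoding yet, so any prefill sequence suffices
rl.mark_prefill_running()            # flips the prefill sequences to RUNNING
rl.move_prefill_to_decoding([seq_a.request_id, seq_b.request_id])
assert rl.prefill_seq_num == 0 and rl.decoding_seq_num == 2
rl.remove(seq_a)
rl.remove(seq_b)
assert rl.is_empty()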
- self.running_batch = BatchInfo( - max_batch_size=self.max_batch_size, - kv_max_split_num=kv_max_split_num, + self.running_bb = BatchBucket( num_heads=model_config.num_attention_heads, head_dim=head_dim, - is_prompts=False, - device=device, - dtype=self.dtype, - fd_inter_tensor=fd_inter_tensor, - ) - self.prefill_batch = BatchInfo( max_batch_size=self.max_batch_size, + max_length=inference_config.max_input_len + inference_config.max_output_len, + block_size=inference_config.block_size, kv_max_split_num=kv_max_split_num, + fd_interm_tensor=fd_inter_tensor, + dtype=self.dtype, + device=device, + ) + self.prefill_bb = BatchBucket( num_heads=model_config.num_attention_heads, head_dim=head_dim, - is_prompts=True, - device=device, + max_batch_size=self.max_batch_size, + max_length=inference_config.max_input_len + inference_config.max_output_len, + block_size=inference_config.block_size, + kv_max_split_num=kv_max_split_num, + fd_interm_tensor=fd_inter_tensor, dtype=self.dtype, - fd_inter_tensor=fd_inter_tensor, + device=device, ) def _init_cache(self, model_config): @@ -159,40 +196,39 @@ class RequestHandler: remove_list.append(seq) break - # stop feeding new sequence into running list to assure - if self.cache_manager.num_available_blocks <= self.running_list.total_seq_num(): - break + num_seqs_to_add = min(len(lst), self.max_batch_size - self.running_list.total_seq_num) + remove_list.extend(lst[:num_seqs_to_add]) + self.running_list.extend(lst[:num_seqs_to_add]) - # Try to allocate cache blocks for the sequence. - if ( - self.cache_manager.check_allocation(seq) - and (len(self.running_list.prefill) + len(self.running_list.decoding)) - < self.max_batch_size # There some bugs in continous batching, so we disable it here. - ): - # If succeed, add the sequence to running list. - remove_list.append(seq) - self.running_list.append(seq) - self.cache_manager.allocate_context_from_block_table(seq.block_table, seq.sentence_len) for seq in remove_list: lst.remove(seq) if self.running_list.ready_for_prefill(): - for seq in self.running_list.prefill: - seq.mark_running() - self.prefill_batch.add_seqs(self.running_list.prefill) - return self.prefill_batch + num_seqs_to_add = min(self.running_list.prefill_seq_num, self.running_bb.available_batch_size) - if not self.running_batch.is_empty: - for seq in self.running_batch.sequences_set: - recycle = self.cache_manager.allocate_token_from_block_table(seq.block_table, seq.sentence_len) - if recycle: + for seq in self.running_list.prefill[:num_seqs_to_add]: + seq.mark_running() + # allocate blocks for the prefill batch + self.prefill_bb.add_seqs( + self.running_list.prefill[:num_seqs_to_add], + alloc_block_tables_fn=self.cache_manager.allocate_context_from_block_tables, + ) + + return self.prefill_bb + + if not self.running_bb.is_empty: + seqs_ids_to_recycle = self.cache_manager.allocate_tokens_from_block_tables( + self.running_bb.block_tables, self.running_bb.seq_lengths, self.running_bb.current_batch_size + ) + if seqs_ids_to_recycle: + seqs_to_recycle = self.running_bb.pop_seqs(seqs_ids_to_recycle) + for seq in seqs_to_recycle: seq.recycle() - self.running_batch.del_seq(seq) self.running_list.remove(seq) self.waiting_list[-1].append(seq) # the recycled sequences are handled with highest priority. 
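A condensed paraphrase (not part of the patch) of the scheduling decision above, not a drop-in implementation: a prefill bucket is returned whenever the running list says it is time to prefill; otherwise one more token slot is reserved per running sequence and decoding continues.

def schedule_sketch(handler):
    if handler.running_list.ready_for_prefill():
        n = min(handler.running_list.prefill_seq_num, handler.running_bb.available_batch_size)
        # (the real method also marks these sequences as RUNNING first)
        handler.prefill_bb.add_seqs(
            handler.running_list.prefill[:n],
            alloc_block_tables_fn=handler.cache_manager.allocate_context_from_block_tables,
        )
        return handler.prefill_bb
    to_recycle = handler.cache_manager.allocate_tokens_from_block_tables(
        handler.running_bb.block_tables,
        handler.running_bb.seq_lengths,
        handler.running_bb.current_batch_size,
    )
    # sequences in `to_recycle` could not get a new block and go back to the waiting list
    return handler.running_bb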
- return self.running_batch + return self.running_bb def add_sequence(self, req: Sequence): """ @@ -213,7 +249,7 @@ class RequestHandler: seq.mark_aborted() self.waiting_list[priority].remove(seq) elif seq.status.is_running(): - self.cache_manager.free_block_table(seq.block_table) + self.running_bb.pop_seq_update_batch(seq.request_id, self.cache_manager.free_block_table) self.running_list.remove(seq) else: try: @@ -242,7 +278,7 @@ class RequestHandler: else: sample_tokens = greedy_sample(generation_config, logprobs) else: - sample_tokens = beam_search_sample(generation_config, logprobs, is_prompt=not self.prefill_batch.is_empty) + sample_tokens = beam_search_sample(generation_config, logprobs, is_prompt=not self.prefill_bb.is_empty) return sample_tokens @@ -273,27 +309,25 @@ class RequestHandler: # sample the next tokens sample_tokens = self._sample(probs, logprobs, generation_config) - if not self.prefill_batch.is_empty: - self.prefill_batch.update_batch_tokens(sample_tokens) + if not self.prefill_bb.is_empty: + self.prefill_bb.append_batch_tokens(sample_tokens) else: - self.running_batch.update_batch_tokens(sample_tokens) + self.running_bb.append_batch_tokens(sample_tokens) def update(self): """ Update current running list and done list """ - if not self.prefill_batch.is_empty: - self.running_list.decoding.extend(self.running_list.prefill) - self.running_batch.add_seqs(self.running_list.prefill) - self.running_list.prefill.clear() - self.prefill_batch.clear_batch() - - finish_seqs = self.running_batch.fliter_batch() - - for seq in finish_seqs: + if not self.prefill_bb.is_empty: + self.running_list.move_prefill_to_decoding(self.prefill_bb.seqs_ids) + self.running_bb.merge(self.prefill_bb) + # clear the prefill batch without assigning a free_block_tables_fn + # since we want to reuse the memory recorded on the block tables + self.prefill_bb.clear(free_block_tables_fn=None) + + finished_seqs, _ = self.running_bb.pop_finished(self.cache_manager.free_block_table) + for seq in finished_seqs: self.running_list.remove(seq) - self.cache_manager.free_block_table(seq.block_table) - - self.done_list.extend(finish_seqs) + self.done_list.extend(finished_seqs) - return finish_seqs + return finished_seqs diff --git a/colossalai/inference/kv_cache/kvcache_manager.py b/colossalai/inference/kv_cache/kvcache_manager.py index d16ced8e9..7d435d59c 100644 --- a/colossalai/inference/kv_cache/kvcache_manager.py +++ b/colossalai/inference/kv_cache/kvcache_manager.py @@ -63,7 +63,6 @@ class KVCacheManager: self.dtype = config.dtype self.elem_size_in_bytes = torch.tensor([], dtype=self.dtype).element_size() self.num_layers = get_model_config_attr(model_config, "num_hidden_layers") - # For now we focus on MHA only, TODO add handling for MQA and GQA self.head_num = get_model_config_attr(model_config, "num_attention_heads") self.head_size = get_model_config_attr(model_config, "hidden_size") // self.head_num assert self.head_num % self.tp_size == 0, f"Cannot shard {self.head_num} heads with tp size {self.tp_size}" @@ -82,8 +81,8 @@ class KVCacheManager: # Physical cache allocation alloc_shape = (self.num_blocks, self.head_num, self.block_size, self.head_size) - if verbose: - self.logger.info(f"Allocating KV cache with shape: {alloc_shape} consisting of {self.num_blocks} blocks.") + # if verbose: + # self.logger.info(f"Allocating KV cache with shape: {alloc_shape} consisting of {self.num_blocks} blocks.") self._kv_caches = self._init_device_caches(alloc_shape) self.total_physical_cache_size_in_bytes = ( 
self.elem_size_in_bytes @@ -112,6 +111,9 @@ class KVCacheManager: """Get the number of available cache blocks.""" return self._available_blocks + def get_head_size(self): + return self.head_size + def get_kv_cache(self): """Get k_cache and v_cache""" return self._kv_caches @@ -148,7 +150,7 @@ class KVCacheManager: and updates the provided block table with the allocated block ids. Args: - block_table: A 1D tensor of shape [max_blocks_per_sequence] holded by a sequence, storing mapping of token_position_id -> block_id. + block_table: A 1D tensor of shape [max_blocks_per_sequence], mapping of token_position_id -> block_id. context_len: The length of the processing sequnece. """ assert block_table.dim() == 1 @@ -193,12 +195,85 @@ class KVCacheManager: else: self._allocate_on_block(block, block.block_size) + def allocate_context_from_block_tables(self, block_tables: torch.Tensor, context_lengths: torch.Tensor) -> None: + """Allocate logical cache blocks for a batch of sequences during prefill stage. + + Args: + block_tables (torch.Tensor): [bsz, max_blocks_per_sequence] + context_lengths (torch.Tensor): [bsz]] + """ + assert block_tables.dim() == 2 + assert block_tables.size(0) == context_lengths.size(0) + if not torch.all(block_tables < 0): + self.logger.error("Some slots on provided block table have been allocated.") + blocks_required = (context_lengths + self.block_size - 1) // self.block_size + num_blocks_required = torch.sum(blocks_required).item() + assert isinstance(num_blocks_required, int) + if num_blocks_required > self._available_blocks: + self.logger.warning( + f"Lacking blocks to allocate. Available blocks {self._available_blocks}; blocks asked {num_blocks_required}." + ) + return + + bsz = block_tables.size(0) + # Try contiguous allocation + torch.cumsum(self._block_states, dim=-1, out=self._block_states_cum[1:]) + torch.subtract( + self._block_states_cum[num_blocks_required:], + self._block_states_cum[:-num_blocks_required], + out=self._block_finder[num_blocks_required - 1 :], + ) + end_indexes = torch.nonzero(self._block_finder == num_blocks_required, as_tuple=False).view(-1) + if end_indexes.numel() > 0: + # contiguous cache exists + end_idx = end_indexes[0].item() + 1 # open interval + start_idx = end_idx - num_blocks_required # closed interval + alloc_block_ids = torch.arange(start_idx, end_idx) + for i in range(bsz): + curr_required = blocks_required[i] + block_tables[i, :curr_required] = torch.arange( + start_idx, start_idx + curr_required, device=block_tables.device + ) + start_idx += curr_required + else: + # non-contiguous cache + available_block_ids = torch.nonzero(self._block_states > 0).view(-1) + alloc_block_ids = available_block_ids[:num_blocks_required] + alloc_block_ids = alloc_block_ids.to(dtype=block_tables.dtype, device=block_tables.device) + start_idx = 0 + for i in range(bsz): + curr_required = blocks_required[i] + block_tables[i, :curr_required] = alloc_block_ids[start_idx, start_idx + curr_required] + start_idx += curr_required + + # Update cache blocks + self._block_states[alloc_block_ids] = 0 + self._available_blocks -= num_blocks_required + last_block_locs = torch.cumsum(blocks_required, dim=0) - 1 + last_block_locs = last_block_locs.to(device=alloc_block_ids.device) + + for i, block_id in enumerate(alloc_block_ids[last_block_locs]): + block: CacheBlock = self._cache_blocks[block_id] + block.add_ref() + self._allocate_on_block( + block, + block.block_size + if context_lengths[i] % block.block_size == 0 + else context_lengths[i].item() % block.block_size, 
+ ) + for block_id in alloc_block_ids: + if block_id in alloc_block_ids[last_block_locs]: + continue + block: CacheBlock = self._cache_blocks[block_id] + block.add_ref() + self._allocate_on_block(block, block.block_size) + def allocate_token_from_block_table(self, block_table: torch.Tensor, context_len: int) -> None: """Allocate the logical cache block for a single sequence during decoding stage, and updates the provided block table if a new cache block is needed. Args: - block_table: A 1D tensor of shape [max_blocks_per_sequence] holded by a sequence, storing mapping of token_position_id -> block_id. + block_table: A 1D tensor of shape [max_blocks_per_sequence], mapping of token_position_id -> block_id. context_len: The length of the processing sequnece (already-allocated length). """ assert block_table.dim() == 1 @@ -207,12 +282,79 @@ class KVCacheManager: alloc_local_block_idx = context_len // self.block_size return self.allocate_single_block(block_table, alloc_local_block_idx) + def allocate_tokens_from_block_tables( + self, block_tables: torch.Tensor, context_lens: torch.Tensor, bsz: int = None + ) -> List[int]: + """Allocate logical cache blocks for a batch of sequences during decoding stage. + + Usage: + allocate_context_from_block_tables + model forward (block tables & context lengths passed) + update context lengths + allocate_tokens_from_block_tables + model forward + update context lengths + allocate_tokens_from_block_tables + model forward + update context lengths + ... + + Args: + block_tables (torch.Tensor): [bsz, max_blocks_per_sequence] + context_lengths (torch.Tensor): [bsz] + + Returns: + List[int]: list of sequence uid to be recycled + """ + assert block_tables.dim() == 2 + assert context_lens.dim() == 1 + + bsz = block_tables.size(0) if bsz is None else bsz + + alloc_local_block_indexes = (context_lens[:bsz]) // self.block_size + block_global_ids = block_tables[torch.arange(0, bsz), alloc_local_block_indexes] + seqs_to_recycle = [] + new_blocks_required = torch.sum(block_global_ids < 0).item() + seqs_req_new_blocks = torch.nonzero(block_global_ids < 0).squeeze() + + if new_blocks_required > 0: + if new_blocks_required > self._available_blocks: + # TODO might want to revise the logic here + # Process the first (_available_blocks) sequences that require new blocks + # Put the rest of the sequences back to recycled + seqs_req_new_blocks, seqs_to_recycle = ( + seqs_req_new_blocks[: self._available_blocks], + seqs_req_new_blocks[self._available_blocks :], + ) + for seq_id in seqs_to_recycle: + self.free_block_table(block_tables[seq_id]) + new_blocks_required = self._available_blocks + + # NOTE might want to alloc contiguous logic + free_block_ids = torch.nonzero(self._block_states > 0).view(-1) + alloc_block_ids = free_block_ids[:new_blocks_required].to( + dtype=block_tables.dtype, device=block_tables.device + ) + + for block_id in alloc_block_ids: + block: CacheBlock = self._cache_blocks[block_id] + block.add_ref() + self._block_states[block_id] = 0 + self._available_blocks -= 1 + block_tables[seqs_req_new_blocks, alloc_local_block_indexes[seqs_req_new_blocks]] = alloc_block_ids + block_global_ids = block_tables[torch.arange(0, bsz), alloc_local_block_indexes] + + for block_id in block_global_ids: + self._allocate_on_block(self._cache_blocks[block_id], 1) + + return seqs_to_recycle + def allocate_single_block(self, block_table: torch.Tensor, block_local_idx: int) -> int: """Allocate space asked on a single block in the block table, specified by the provided position id, and 
updates the provided block table with the allocated block. Args: - block_table: A 1D tensor of shape [max_blocks_per_sequence] holded by a sequence, storing mapping of token_position_id -> block_id. + block_table: A 1D tensor of shape [max_blocks_per_sequence], mapping of token_position_id -> block_id. block_local_idx: The index of the block in the block table. space_asked: i.e. The number of tokens to be assigned space for. Returns: @@ -240,8 +382,7 @@ class KVCacheManager: def free_block_table(self, block_table: torch.Tensor) -> None: """Free the logical cache blocks for **a single sequence**.""" assert block_table.dim() == 1 - for i in range(block_table.numel()): - global_block_id = block_table[i].item() + for i, global_block_id in enumerate(block_table.tolist()): if global_block_id < 0: return block: CacheBlock = self._cache_blocks[global_block_id] @@ -253,6 +394,15 @@ class KVCacheManager: # reset the block id in the block table (if we maintain a 2D tensors as block tables in Engine) block_table[i] = -1 + def free_block_tables(self, block_tables: torch.Tensor, first_n: int = None) -> None: + """Release the logical cache blocks for a batch of sequences. + If `first_n` is provided, only the blocks for the first several sequences will be released. + """ + assert block_tables.dim() == 2 + first_n = block_tables.size(0) if first_n is None else first_n + for block_table in block_tables[:first_n]: + self.free_block_table(block_table) + def clear_all(self) -> None: """Clear all the references and allocations on all the cache blocks.""" for block in self._cache_blocks: diff --git a/colossalai/inference/modeling/models/nopadding_llama.py b/colossalai/inference/modeling/models/nopadding_llama.py index a1db4ecfa..6b6a5876b 100644 --- a/colossalai/inference/modeling/models/nopadding_llama.py +++ b/colossalai/inference/modeling/models/nopadding_llama.py @@ -12,8 +12,8 @@ from transformers.models.llama.modeling_llama import ( LlamaModel, ) +from colossalai.inference.batch_bucket import BatchBucket from colossalai.inference.flash_decoding_utils import FDIntermTensors -from colossalai.inference.struct import BatchInfo from colossalai.kernel.triton import ( context_attention_unpadded, copy_kv_to_blocked_cache, @@ -34,7 +34,7 @@ except ImportError: def llama_causal_lm_forward( self: LlamaForCausalLM, - batch: BatchInfo = None, + batch: BatchBucket = None, k_caches: List[torch.Tensor] = None, v_caches: List[torch.Tensor] = None, ): @@ -59,7 +59,7 @@ def llama_causal_lm_forward( def llama_model_forward( self: LlamaModel, - batch: BatchInfo = None, + batch: BatchBucket = None, k_caches: List[torch.Tensor] = None, v_caches: List[torch.Tensor] = None, ): @@ -73,7 +73,7 @@ def llama_model_forward( input_ids = batch.get_1D_inputs() block_tables = batch.get_block_table_tensor() sequence_lengths = batch.get_sequence_lengths() - batch_size = len(sequence_lengths) + batch_size = batch.current_batch_size kv_seq_len = sequence_lengths.max().item() hidden_states = self.embed_tokens(input_ids) diff --git a/colossalai/inference/struct.py b/colossalai/inference/struct.py index 766e54ab1..706304038 100644 --- a/colossalai/inference/struct.py +++ b/colossalai/inference/struct.py @@ -71,7 +71,6 @@ class Sequence: input_token_id: List[int] block_size: int sample_params: Any # SampleParams needs to be imported later. 
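A sketch (not part of the patch) of batched release with the `free_block_tables` helper added above, continuing the bucket from the earlier sketches; this is exactly how `BatchBucket.clear` drives it, freeing only the leading `current_batch_size` rows.

bb.clear(free_block_tables_fn=cache_manager.free_block_tables)
assert bb.current_batch_size == 0 and (bb.block_tables == -1).all()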
- block_table: torch.Tensor eos_token_id: int pad_token_id: int max_output_len: int = 256 @@ -158,7 +157,6 @@ class Sequence: f"prompt={self.prompt}, " f"status={self.status.name}, " f"sample_params={self.sample_params}, " - f"logical_block_number={self.block_table.shape[0]}," f"input_len={self.input_len})," f"output_len={self.output_len})" ) diff --git a/tests/test_infer/test_batch_bucket.py b/tests/test_infer/test_batch_bucket.py new file mode 100644 index 000000000..e2d5774f4 --- /dev/null +++ b/tests/test_infer/test_batch_bucket.py @@ -0,0 +1,140 @@ +import torch +from transformers.models.llama import LlamaConfig + +from colossalai.inference.batch_bucket import BatchBucket +from colossalai.inference.config import InferenceConfig +from colossalai.inference.kv_cache import KVCacheManager +from colossalai.inference.struct import Sequence +from colossalai.testing import parameterize + + +@parameterize( + "test_config", + [ + { + "hidden_size": 128, + "num_attention_heads": 4, + "num_layers": 2, + "block_size": 4, + "max_batch_size": 4, + "max_input_len": 32, + "max_output_len": 8, + "dtype": torch.float16, + "tp_size": 1, + } + ], +) +def test_bucket(test_config): + hidden_size = test_config.pop("hidden_size") + num_heads = test_config.pop("num_attention_heads") + num_layers = test_config.pop("num_layers") + model_config = LlamaConfig( + hidden_size=hidden_size, + num_hidden_layers=num_layers, + num_attention_heads=num_heads, + ) + inference_config = InferenceConfig(**test_config) + + # Just for testing usage. Don't create multiple cache_manager on the same device. + cache_manager = KVCacheManager(inference_config, model_config) + cache_manager_copy = KVCacheManager(inference_config, model_config) + + seq_lens = [19, 20, 27] + seq1 = Sequence( + request_id=0, + prompt="", # Dummy for testing usage + input_token_id=list(range(seq_lens[0])), + block_size=4, + sample_params=None, + eos_token_id=2, + pad_token_id=2, + max_output_len=10, + ) + seq2 = Sequence( + request_id=1, + prompt="", # Dummy for testing usage + input_token_id=list(range(seq_lens[1])), + block_size=4, + sample_params=None, + eos_token_id=2, + pad_token_id=2, + max_output_len=10, + ) + seq3 = Sequence( + request_id=2, + prompt="", # Dummy for testing usage + input_token_id=list(range(seq_lens[2])), + block_size=4, + sample_params=None, + eos_token_id=2, + pad_token_id=2, + max_output_len=10, + ) + + block_size = test_config["block_size"] + max_batch_size = test_config["max_batch_size"] + max_length = test_config["max_input_len"] + test_config["max_output_len"] + assert max_batch_size >= 2, "max_batch_size should be greater than 1" + + bb = BatchBucket( + num_heads, cache_manager.get_head_size(), max_batch_size, max_length, block_size, kv_max_split_num=2 + ) + bb_copy = BatchBucket( + num_heads, cache_manager.get_head_size(), max_batch_size, max_length, block_size, kv_max_split_num=2 + ) + block_tables = bb.add_seqs([seq1, seq2]) + assert block_tables.shape == (2, cache_manager.max_blocks_per_sequence) + assert torch.all(block_tables < 0), "Initialized block_tables should be negative values" + + cache_manager.allocate_context_from_block_tables(block_tables, bb.seq_lengths[: bb.current_batch_size]) + bb_copy.add_seqs( + [seq1, seq2], alloc_block_tables_fn=cache_manager_copy.allocate_context_from_block_tables + ) # This is just for testing usage. Don't add the same sequence to different buckets. 
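A small worked example (not part of the patch) of the block accounting behind the allocation above: with block_size = 4, context lengths 19 and 20 each need five blocks, so ten blocks are reserved in total; seq1's last block is only partially filled while seq2's last block is exactly full.

block_size = 4
context_lengths = torch.tensor([19, 20])
blocks_required = (context_lengths + block_size - 1) // block_size   # ceil division
assert blocks_required.tolist() == [5, 5]
assert int(blocks_required.sum()) == 10
# seq1 uses 4 * 4 + 3 = 19 slots, leaving one slot free in its last block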
+ + assert bb.seq_lengths.tolist() == [seq1.sentence_len, seq2.sentence_len] + [0] * ( + max_batch_size - bb.current_batch_size + ) + assert torch.equal(bb.block_tables, bb_copy.block_tables) + + bb.append_batch_tokens(torch.tensor([99, 99])) + assert bb.seq_lengths.tolist() == [seq1.sentence_len, seq2.sentence_len] + [0] * ( + max_batch_size - bb.current_batch_size + ) + + cache_manager.allocate_tokens_from_block_tables(bb.block_tables, bb.seq_lengths, bsz=bb.current_batch_size) + assert bb.seq_lengths.tolist() == [seq1.sentence_len, seq2.sentence_len] + [0] * ( + max_batch_size - bb.current_batch_size + ) + + bb.append_batch_tokens(torch.tensor([99, 99])) + + cache_manager.allocate_tokens_from_block_tables(bb.block_tables, bb.seq_lengths, bsz=bb.current_batch_size) + assert bb.seq_lengths.tolist() == [seq1.sentence_len, seq2.sentence_len] + [0] * ( + max_batch_size - bb.current_batch_size + ) + + bb.pop_seq_update_batch(0, free_block_table_fn=cache_manager.free_block_table) + assert bb.seq_lengths.tolist() == [bb.seqs_li[0].sentence_len] + [0] * (max_batch_size - bb.current_batch_size) + assert bb.is_compact + + bb2 = BatchBucket( + num_heads, cache_manager.get_head_size(), max_batch_size, max_length, block_size, kv_max_split_num=2 + ) + block_tables = bb2.add_seqs([seq3]) + cache_manager.allocate_context_from_block_tables(block_tables, bb2.seq_lengths[: bb2.current_batch_size]) + unmerged_ids = bb.merge(bb2) + assert not unmerged_ids + assert bb.is_compact + assert bb2.is_compact + assert bb.current_batch_size == 2 + assert bb2.current_batch_size == 0 + + bb.clear(cache_manager.free_block_tables) + assert bb.current_batch_size == 0 + assert bb.is_compact + assert bb.seq_lengths.tolist() == [0] * max_batch_size + assert torch.all(bb.block_tables < 0) + + +if __name__ == "__main__": + test_bucket() diff --git a/tests/test_infer/test_config_and_struct.py b/tests/test_infer/test_config_and_struct.py index 47d3839e4..046ee932d 100755 --- a/tests/test_infer/test_config_and_struct.py +++ b/tests/test_infer/test_config_and_struct.py @@ -15,7 +15,6 @@ def check_config_and_inference(): input_token_id=[1, 2, 3], block_size=16, sample_params=None, - block_table=None, eos_token_id=2, pad_token_id=2, max_output_len=256, @@ -27,7 +26,6 @@ def check_config_and_inference(): input_token_id=[4, 5, 6], block_size=16, sample_params=None, - block_table=None, eos_token_id=2, pad_token_id=2, max_output_len=256, @@ -39,7 +37,6 @@ def check_config_and_inference(): input_token_id=[7, 8, 9], block_size=16, sample_params=None, - block_table=None, eos_token_id=2, pad_token_id=2, max_output_len=256, diff --git a/tests/test_infer/test_kvcache_manager.py b/tests/test_infer/test_kvcache_manager.py index a2051f220..321047706 100755 --- a/tests/test_infer/test_kvcache_manager.py +++ b/tests/test_infer/test_kvcache_manager.py @@ -148,6 +148,20 @@ def check_cache_manager(test_config): cache_manager.clear_all() assert cache_manager.num_available_blocks == num_blocks + for cache_block in cache_manager._cache_blocks: + assert cache_block.available_space == block_size + + # Mock batch operations (Prefill/Decoding updates) + context_lengths = torch.tensor([max_input_length, max_input_length - 1]) + block_tables = torch.tensor( + [[-1 for _ in range(cache_manager.max_blocks_per_sequence)] for _ in range(2)], dtype=torch.int32 + ) + cache_manager.allocate_context_from_block_tables(block_tables, context_lengths) + cache_manager.allocate_tokens_from_block_tables(block_tables, context_lengths) + 
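A sketch (not part of the patch) of the call pattern described in the `allocate_tokens_from_block_tables` docstring earlier in the patch; `model_step` and `max_new_tokens` are hypothetical stand-ins for the model forward plus sampling and for the generation budget.

for _ in range(max_new_tokens):
    to_recycle = cache_manager.allocate_tokens_from_block_tables(block_tables, context_lengths)
    # a non-empty `to_recycle` lists sequences that could not get a block and must be re-queued
    model_step(block_tables, context_lengths)
    context_lengths += 1    # one generated token per sequence this step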
cache_manager.free_block_tables(block_tables) + for cache_block in cache_manager._cache_blocks: + assert cache_block.available_space == block_size + def run_dist(rank, world_size, port): colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost") diff --git a/tests/test_infer/test_request_handler.py b/tests/test_infer/test_request_handler.py index d589e9717..c7a35ebbe 100644 --- a/tests/test_infer/test_request_handler.py +++ b/tests/test_infer/test_request_handler.py @@ -1,5 +1,4 @@ import pytest -import torch from transformers.models.llama import LlamaConfig import colossalai @@ -22,17 +21,35 @@ def check_running_list(): eos_token_id=0, pad_token_id=0, sample_params=None, - block_table=1, ) - + seq2 = Sequence( + request_id=2, + prompt="abc", + input_token_id=[1, 2, 3], + block_size=16, + eos_token_id=0, + pad_token_id=0, + sample_params=None, + ) running_list.append(seq1) + running_list.append(seq2) assert running_list.ready_for_prefill() - assert running_list.decoding == [] and running_list.prefill[0] == seq1 + assert len(running_list.decoding) == 0 + assert len(running_list.prefill) > 0 and running_list.prefill[0] == seq1 seq = running_list.find_seq(seq1.request_id) assert seq == seq1 + running_list.mark_prefill_running() + for seq in running_list.prefill: + assert seq.status == RequestStatus.RUNNING + + running_list.move_prefill_to_decoding([seq1.request_id, seq2.request_id]) + assert len(running_list.prefill) == 0 + assert len(running_list.decoding) > 0 and running_list.decoding[0] == seq1 + running_list.remove(seq1) + running_list.remove(seq2) assert running_list.is_empty() @@ -59,7 +76,6 @@ def check_request_handler(): eos_token_id=0, pad_token_id=0, sample_params=None, - block_table=torch.tensor([-1, -1]), ) request_handler.add_sequence(seq1) # the priority should be 1