mirror of https://github.com/hpcaitech/ColossalAI
* add kvcache manager funcs for batching
* add batch bucket for batching
* revise RunningList struct in handler
* add kvcache/batch funcs for compatibility
* use new batching methods
* fix indexing bugs
* revise abort logic
* use cpu seq lengths/block tables
* rm unused attr in Sequence
* fix type conversion/default arg
* add and revise pytests
* revise pytests, rm unused tests
* rm unused statements
* fix pop finished indexing issue
* fix: use index in batch when retrieving inputs/update seqs
* use dict instead of odict in batch struct
* arg type hinting
* fix make compress
* refine comments
* fix: pop_n_seqs to pop the first n seqs
* add check in request handler
* remove redundant conversion
* fix test for request handler
* fix pop method in batch bucket
* fix prefill adding

pull/5399/head
Yuanheng Zhao authored 9 months ago, committed by GitHub
11 changed files with 905 additions and 115 deletions
@@ -0,0 +1,449 @@
from typing import Callable, List, Optional, Tuple, Union

import torch

from colossalai.inference.struct import Sequence
from colossalai.utils import get_current_device

class BatchBucket:
    """Container for a batch of Sequence objects, used to manage sequences during batched inference.

    Attrs:
        _sequences_dict (Dict[int, Sequence]): Map sequence uid to sequence struct
            seq_uid -> Sequence
        _sequences_indexes (Dict[int, int]): Map sequence uid to its index in the batch
            seq_uid -> index in the batch (the index used in sequence_lengths and block_tables)
        _sequence_lengths (torch.Tensor): Length of each sequence in the batch.
            The size of the tensor is (max_batch_size,)
        _block_tables (torch.Tensor): Block table of each sequence in the batch.
            The size of the tensor is (max_batch_size, max_blocks_per_seq)
    """

    def __init__(
        self,
        num_heads,
        head_dim,
        max_batch_size,
        max_length,
        block_size,
        kv_max_split_num,
        fd_interm_tensor=None,
        device=None,
        dtype=torch.float16,
    ):
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.max_batch_size = max_batch_size
        self.max_length = max_length  # max length = input len + output len
        self.block_size = block_size
        self.kv_max_split_num = kv_max_split_num  # hint used for flash decoding
        self.fd_interm_tensor = fd_interm_tensor
        self.device = device or get_current_device()
        self.dtype = dtype

        self._current_batch_size = 0
        self._sequences_dict = dict()
        self._sequences_indexes = dict()  # deque(maxlen=self.max_batch_size)
        self._sequence_lengths = torch.zeros((self.max_batch_size,), dtype=torch.int32)
        self._sequence_lengths_helper = torch.zeros_like(self._sequence_lengths)
        max_blocks_per_seq = (self.max_length + block_size - 1) // block_size
        self._block_tables = torch.full((self.max_batch_size, max_blocks_per_seq), -1, dtype=torch.int32)
        self._block_tables_helper = torch.full_like(self._block_tables, -1)
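
    # A brief sketch of the internal layout (illustrative values only): with max_batch_size=4
    # and two sequences of lengths 19 and 20 added,
    #   _sequence_lengths -> tensor([19, 20,  0,  0], dtype=torch.int32)
    #   _block_tables     -> rows 0 and 1 hold allocated block ids, rows 2 and 3 stay filled with -1
    # Each occupied row belongs to exactly one sequence, looked up through _sequences_indexes.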

    @property
    def is_empty(self):
        return self._current_batch_size == 0

    @property
    def current_batch_size(self):
        return self._current_batch_size

    @property
    def available_batch_size(self):
        return self.max_batch_size - self._current_batch_size

    @property
    def block_tables(self):
        return self._block_tables

    @property
    def seq_lengths(self):
        return self._sequence_lengths

    @property
    def seqs_ids(self):
        return list(self._sequences_dict.keys())

    @property
    def seqs_li(self):
        return list(self._sequences_dict.values())

    @property
    def is_compact(self):
        assert len(self._sequences_dict) == len(self._sequences_indexes), "BatchBucket indexing is not consistent"
        return (
            len(self._sequences_dict)
            == torch.nonzero(self._sequence_lengths).view(-1).numel()
            == torch.nonzero(self._block_tables[:, 0] >= 0).numel()
        )

    def _make_compact(self) -> None:
        # Clean and compress the batch based on its sequences dict.
        # Namely, compress the remaining sequences to the front and reset the freed rows
        # of the seq lengths and block tables tensors.
        # NOTE Prevent calling this method multiple times in a single step
        if self.is_compact:
            return
        valid_seq_ids = self._sequences_dict.keys()
        valid_num = len(valid_seq_ids)
        valid_indexes = [self._sequences_indexes[seq_id] for seq_id in valid_seq_ids]
        assert valid_num == len(self._sequences_indexes), "BatchBucket indexing is not consistent"
        self._sequence_lengths_helper[:valid_num] = self._sequence_lengths[valid_indexes]
        self._sequence_lengths[:] = self._sequence_lengths_helper[:]
        self._block_tables_helper[:valid_num, :] = self.block_tables[valid_indexes]
        self.block_tables[:] = self._block_tables_helper[:]
        new_idx = 0
        for seq_id in valid_seq_ids:
            self._sequences_indexes[seq_id] = new_idx
            new_idx += 1
        self._sequence_lengths_helper.fill_(0)
        self._block_tables_helper.fill_(-1)
        self._current_batch_size = valid_num

    def add_seq(
        self,
        seq: Sequence,
        alloc_block_table: torch.Tensor = None,
        alloc_block_table_fn: Callable[[torch.Tensor, int], None] = None,
    ) -> Union[torch.Tensor, None]:
        """Add a single sequence to the batch.
        The user may provide either a block table or a function that allocates block tables.

        Args:
            seq (Sequence): The sequence to be added to the batch
            alloc_block_table (torch.Tensor): The block table to be copied and used for the sequence
            alloc_block_table_fn (Callable[[torch.Tensor, int], None]): The function to allocate blocks for the sequence,
                which is expected to reserve blocks and update the status of the kv-cache manager.

        Returns:
            block_table (torch.Tensor): The block table of the added sequence, used for block allocation in the kv-cache manager.
                None if the sequence cannot be added.
        """
        block_table = None
        # TODO might consider sorting by length
        if self._current_batch_size < self.max_batch_size:
            # The added sequence occupies row `self._current_batch_size` (incremented afterwards),
            # consistent with the index recorded in `self._sequences_indexes`.
            self._sequences_dict[seq.request_id] = seq
            self._sequences_indexes[seq.request_id] = self._current_batch_size
            self._sequence_lengths[self._current_batch_size] = seq.sentence_len
            # NOTE the added seq still requires block table allocation by the kv-cache manager
            block_table = self._block_tables[self._current_batch_size]
            if alloc_block_table is not None:
                # copy block ids from the provided block table
                self._block_tables[self._current_batch_size] = alloc_block_table
            elif alloc_block_table_fn:
                alloc_block_table_fn(block_table, self._sequence_lengths[self._current_batch_size].item())
            self._current_batch_size += 1
        return block_table
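
    # Usage sketch for add_seq (hedged; `alloc_fn` is a hypothetical name for any callable that
    # reserves blocks for a single sequence with signature (block_table, context_len) -> None):
    #
    #   block_table = bucket.add_seq(seq, alloc_block_table_fn=alloc_fn)
    #   if block_table is None:
    #       ...  # the batch is full; keep the sequence in the waiting pool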

    def add_seqs(
        self,
        seqs: List[Sequence],
        alloc_block_tables: torch.Tensor = None,
        alloc_block_tables_fn: Callable[[torch.Tensor, torch.Tensor], None] = None,
    ) -> Union[torch.Tensor, None]:
        """Add a list of sequences to the batch.
        The user may provide either block tables or a function that allocates block tables.

        Args:
            seqs (List[Sequence]): The sequences to be added to the batch
            alloc_block_tables (torch.Tensor): The block tables to be copied and used for the sequences
            alloc_block_tables_fn (Callable[[torch.Tensor, torch.Tensor], None]): The function to allocate blocks for multiple sequences,
                which is expected to reserve blocks and update the status of the kv-cache manager.

        Returns:
            block_tables (torch.Tensor): The block tables of the added sequences, used for block allocation in the kv-cache manager.
                None if the sequences cannot be added.
        """

        assert (
            alloc_block_tables is None or alloc_block_tables_fn is None
        ), "`alloc_block_tables` and `alloc_block_tables_fn` cannot be provided at the same time"

        num_seqs_to_add = min(self.max_batch_size - self._current_batch_size, len(seqs))
        block_tables = None
        if num_seqs_to_add > 0:
            for i, seq in enumerate(seqs[:num_seqs_to_add]):
                self._sequences_dict[seq.request_id] = seq
                self._sequences_indexes[seq.request_id] = self._current_batch_size + i
            # TODO external (rename): modify Sequence.sentence_len to seq_len
            self._sequence_lengths[
                self._current_batch_size : self._current_batch_size + num_seqs_to_add
            ] = torch.tensor([seq.sentence_len for seq in seqs[:num_seqs_to_add]], dtype=torch.int32)
            # NOTE block tables to be updated by kvcache manager
            block_tables = self._block_tables[self._current_batch_size : self._current_batch_size + num_seqs_to_add]
            if alloc_block_tables is not None:
                # copy block ids from provided block tables
                self._block_tables[
                    self._current_batch_size : self._current_batch_size + num_seqs_to_add
                ] = alloc_block_tables
            elif alloc_block_tables_fn:
                alloc_block_tables_fn(
                    block_tables,
                    self._sequence_lengths[self._current_batch_size : self._current_batch_size + num_seqs_to_add],
                )

            self._current_batch_size += num_seqs_to_add
            seqs[:] = seqs[num_seqs_to_add:]

        return block_tables
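
    # Usage sketch for add_seqs, mirroring the accompanying test: the kv-cache manager's
    # allocate_context_from_block_tables is passed so blocks are reserved as sequences enter the batch.
    #
    #   bucket.add_seqs(waiting_seqs, alloc_block_tables_fn=cache_manager.allocate_context_from_block_tables)
    #   # `waiting_seqs` is shrunk in place; sequences that did not fit remain in the list.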

    def pop_seq_update_batch(
        self, request_id: int, free_block_table_fn: Callable[[torch.Tensor], None] = None
    ) -> Tuple[Sequence, Union[torch.Tensor, None]]:
        """Pop a single sequence by uid from the batch, and update the batch bucket status.

        Args:
            request_id (int): The uid of the sequence
            free_block_table_fn (Callable): The function to free the block table of a sequence;
                if not provided, the block table must be released manually after calling this method

        Returns:
            A tuple of: seq (Sequence): The target sequence
            and block_table (torch.Tensor): block table of the target sequence indicating corresponding blocks,
            None if the sequence is not found or if free_block_table_fn is provided.
        """
        seq: Sequence = self._sequences_dict.get(request_id)
        block_table = None
        if seq is not None:
            assert request_id in self._sequences_indexes, "Inconsistency in BatchBucket indexing"
            self._sequences_dict.pop(request_id)
            seq_b_idx = self._sequences_indexes.get(request_id)

            if self.current_batch_size > 1:
                # replace seq length of the target seq with that of the last seq in the batch
                last_seq_b_idx = self.current_batch_size - 1
                last_seq_id = next(
                    (uid for uid, index in self._sequences_indexes.items() if index == last_seq_b_idx),
                    None,
                )
                assert last_seq_id is not None
                self._sequences_indexes[last_seq_id] = seq_b_idx
                self._sequence_lengths[seq_b_idx] = self._sequence_lengths[last_seq_b_idx]
                self._sequence_lengths[last_seq_b_idx].fill_(0)
                # free the block table of the seq, or return a copy of the block table (to be processed outside)
                if free_block_table_fn:
                    free_block_table_fn(self._block_tables[seq_b_idx])
                else:
                    block_table = self._block_tables[seq_b_idx].detach().clone()
                # replace block table of the target seq with that of the last seq in the batch
                self._block_tables[seq_b_idx] = self._block_tables[last_seq_b_idx]
                self._block_tables[last_seq_b_idx].fill_(-1)
            else:
                if free_block_table_fn:
                    free_block_table_fn(self._block_tables[0])
                else:
                    block_table = self._block_tables[0].detach().clone()
                self._sequence_lengths[0].fill_(0)
                self._block_tables[0].fill_(-1)
            self._sequences_indexes.pop(request_id)
            self._current_batch_size -= 1

        return seq, block_table
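
    # Pop-and-swap sketch (illustrative indices): popping the sequence at row 0 of a 3-sequence
    # batch moves row 2 (the last occupied row) into row 0, resets row 2 to 0 / -1, and decrements
    # the batch size, so the occupied rows stay contiguous without a full compaction pass.
    #
    #   seq, block_table = bucket.pop_seq_update_batch(request_id, free_block_table_fn=None)
    #   # `block_table` is a clone of the popped sequence's row; free its blocks manually afterwards.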

    def pop_seqs(
        self, request_ids: List[int], free_block_table_fn: Callable[[torch.Tensor], None] = None
    ) -> Tuple[List[Sequence], List[torch.Tensor]]:
        """Iteratively pop a list of sequences by uid.

        Args:
            request_ids (List[int]): The uids of the sequences
            free_block_table_fn (Callable): The function to free the block table of a sequence;
                if not provided, the block tables must be released manually after calling this method
        Returns:
            A tuple of: seqs (List[Sequence]): The target sequences
            and block_tables (List[torch.Tensor]): block tables of the target sequences indicating corresponding blocks
        """
        seqs = []
        block_tables = []
        for request_id in request_ids:
            seq, block_table = self.pop_seq_update_batch(request_id, free_block_table_fn)
            if seq is not None:
                seqs.append(seq)
                if block_table is not None:
                    block_tables.append(block_table)
        return seqs, block_tables

    def pop_n_seqs(
        self, n: int, free_block_table_fn: Callable[[torch.Tensor], None] = None
    ) -> Tuple[List[Sequence], List[torch.Tensor]]:
        """Pop the first n sequences in the batch (FIFO).
        If n is greater than the current batch size, pop all the sequences in the batch.

        Args:
            n (int): The number of sequences to pop out
            free_block_table_fn (Callable): The function to free the block table of a single sequence
        Returns:
            A tuple of: seqs (List[Sequence]): The target sequences,
            and block_tables (List[torch.Tensor]): block tables of the target sequences indicating corresponding blocks
        """
        # NOTE Prevent calling this method multiple times in a single step
        seqs = []
        block_tables = []
        n = min(n, self.current_batch_size)
        seq_ids = list(self._sequences_dict.keys())[:n]
        for seq_id in seq_ids:
            seq = self._sequences_dict.pop(seq_id)
            seq_b_idx = self._sequences_indexes.pop(seq_id)
            if free_block_table_fn:
                free_block_table_fn(self.block_tables[seq_b_idx])
            else:
                block_tables.append(self.block_tables[seq_b_idx].detach().clone())
            seqs.append(seq)
        if not self.is_compact:
            self._make_compact()
        return seqs, block_tables
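
    # Sketch: pop_n_seqs pops in insertion (FIFO) order, since dicts preserve insertion order;
    # the batch is re-compacted once afterwards instead of swapping rows per pop.
    #
    #   seqs, block_tables = bucket.pop_n_seqs(2)  # block tables are returned as clones, to be freed by the caller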

    def pop_finished(
        self, free_block_table_fn: Callable[[torch.Tensor], None] = None
    ) -> Tuple[List[Sequence], List[torch.Tensor]]:
        """Pop the finished sequences in the batch. Block tables of the finished sequences
        are returned only if free_block_table_fn is not provided.

        Args:
            free_block_table_fn (Callable): The function to free the block table of a single sequence
        Returns:
            A tuple of: finished_seqs (List[Sequence]): The finished sequences,
            and finished_block_tables (List[torch.Tensor]): block tables of the finished sequences.
        """
        finished_seqs = []
        finished_block_tables = []
        for seq in self._sequences_dict.values():
            if seq.check_finish():
                finished_seqs.append(seq)
        # Use `pop_seq_update_batch` to update the batch status when only a few seqs are finished;
        # otherwise, pop the seqs directly and then call `_make_compact` to compress the batch.
        # For now, the performance difference is not significant, so we use the first method to pop seqs.
        # Precise evaluations to be done.
        for seq in finished_seqs:
            _, block_table = self.pop_seq_update_batch(seq.request_id, free_block_table_fn)
            if block_table is not None:
                finished_block_tables.append(block_table)

        return finished_seqs, finished_block_tables

    # TODO the arg type does not support beam search sampling yet
    def append_batch_tokens(self, tokens: torch.Tensor) -> None:
        """Append a batch of tokens to the sequences in the batch"""
        assert self.current_batch_size == tokens.size(0), "Batch size mismatch"

        if self.current_batch_size > 0:
            tokens = tokens.tolist()
            for seq_id, seq in self._sequences_dict.items():
                index_in_b = self._sequences_indexes[seq_id]
                curr_tokens = tokens[index_in_b]
                if not isinstance(curr_tokens, list):
                    curr_tokens = [curr_tokens]
                seq.output_token_id += curr_tokens
                seq.check_finish()
            self._sequence_lengths[: self.current_batch_size] += 1
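
    # Decode-step sketch: after sampling one token per sequence, append them in batch order
    # (row i of the tensor corresponds to the sequence whose index in the batch is i).
    #
    #   # next_tokens: torch.Tensor of shape (current_batch_size,) from the sampler (assumed name)
    #   bucket.append_batch_tokens(next_tokens)  # also bumps _sequence_lengths by 1 per sequence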

    def clear(self, free_block_tables_fn: Optional[Callable[[torch.Tensor, int], None]]) -> List[Sequence]:
        """Clear all the sequences in the batch.

        Args:
            free_block_tables_fn (Optional[Callable]): The function to free the block tables
                of all the sequences in the batch
        """
        seqs = list(self._sequences_dict.values())
        self._sequences_dict.clear()
        self._sequences_indexes.clear()
        if free_block_tables_fn:
            free_block_tables_fn(self.block_tables, self._current_batch_size)
        self._block_tables.fill_(-1)
        self._sequence_lengths.fill_(0)
        self._current_batch_size = 0
        return seqs

    def merge(self, other: "BatchBucket") -> List[int]:
        """Merge the sequences in the other batch into the current batch.
        Merge as many sequences as the current batch can hold; if there is not enough available
        space for all the sequences in the other batch, the rest stay in the other batch.

        Usage:
            > New incoming sequence added to prefill batch
                prefill bb curr batch size < prefill_ratio * prefill bb max batch size
            > New incoming sequence added to prefill batch
                prefill bb curr batch size == prefill_ratio * prefill bb max batch size
            > Pause Decoding
            > Prefill
            > Move sequences in prefill bb => decoding bb
            > Put back the out-of-volume sequences into the running pool

        Returns:
            unmerged_ids (List[int]): a list of sequence uids that are not merged into the current batch
        """
        unmerged_ids = []
        num_seqs_to_merge = min(self.available_batch_size, other.current_batch_size)
        if num_seqs_to_merge > 0:
            seqs, block_tables_li = other.pop_n_seqs(num_seqs_to_merge)
            block_tables = torch.stack(block_tables_li)
            self.add_seqs(seqs, alloc_block_tables=block_tables)
            unmerged_ids = other.seqs_ids
        return unmerged_ids
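
    # Handoff sketch, following the Usage notes above (bucket names are illustrative): after a
    # prefill step, move the prefilled sequences into the decoding bucket and collect the leftovers.
    #
    #   leftover_ids = decoding_bucket.merge(prefill_bucket)
    #   # block tables travel with the sequences, so no new kv-cache allocation happens here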

    ########## The following methods are expected to be used in modeling ###########

    # For compatibility.
    # NOTE: This determines the stage of the batch by assumption:
    # all sequences in the batch are expected to be in the same stage.
    @property
    def is_prompts(self) -> bool:
        assert len(self._sequences_dict) > 0, "No sequence in the batch"
        first_seq = next(iter(self._sequences_dict.values()))
        return first_seq.output_len == 0

    # For compatibility
    def get_1D_inputs(self) -> torch.Tensor:
        assert len(self._sequences_dict) > 0, "No sequence in the batch"
        first_seq = next(iter(self._sequences_dict.values()))  # not exactly the first sequence
        if first_seq.output_len == 0:
            # Assume prefill stage
            assert all(
                seq.output_len == 0 for seq in self._sequences_dict.values()
            ), "Sequence stage (Prefill/Decoding) must be the same in the batch"
            out_li = []
            seq_ids = sorted(self._sequences_indexes.keys(), key=lambda x: self._sequences_indexes[x])
            for seq_id in seq_ids:
                seq: Sequence = self._sequences_dict[seq_id]
                out_li.extend(seq.input_token_id)
            return torch.tensor(out_li, dtype=torch.long, device=self.device)
        else:
            # Assume decoding stage
            assert all(
                seq.output_len > 0 for seq in self._sequences_dict.values()
            ), "Sequence stage (Prefill/Decoding) must be the same in the batch"
            assert self.is_compact, "BatchBucket is not compact"
            out = torch.empty([self.current_batch_size], dtype=torch.long)
            for seq_id, index_in_b in self._sequences_indexes.items():
                seq: Sequence = self._sequences_dict[seq_id]
                out[index_in_b] = seq.output_token_id[-1]
            return out.to(device=self.device)

    # For compatibility
    def get_block_table_tensor(self) -> torch.Tensor:
        assert self.is_compact  # Debug usage
        block_table = self.block_tables[: self.current_batch_size]
        return block_table.to(device=self.device)

    # For compatibility
    def get_sequence_lengths(self) -> torch.Tensor:
        assert self.is_compact  # Debug usage
        sequence_lengths = self.seq_lengths[: self.current_batch_size]
        return sequence_lengths.to(device=self.device)

    # For compatibility
    @property
    def fd_inter_tensor(self):
        assert self.fd_interm_tensor is not None, "fd_interm_tensor is not provided"
        return self.fd_interm_tensor
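
    # Sketch of how a model forward pass is expected to consume the batch (assumed call sites):
    #
    #   input_ids = bucket.get_1D_inputs()              # all prompt tokens (prefill) or last tokens (decoding)
    #   block_tables = bucket.get_block_table_tensor()  # (current_batch_size, max_blocks_per_seq), on device
    #   seq_lengths = bucket.get_sequence_lengths()     # (current_batch_size,), on device
    #   # bucket.is_prompts tells the caller whether to take the prefill or the decoding path.
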
@@ -0,0 +1,140 @@
import torch
from transformers.models.llama import LlamaConfig

from colossalai.inference.batch_bucket import BatchBucket
from colossalai.inference.config import InferenceConfig
from colossalai.inference.kv_cache import KVCacheManager
from colossalai.inference.struct import Sequence
from colossalai.testing import parameterize

@parameterize(
    "test_config",
    [
        {
            "hidden_size": 128,
            "num_attention_heads": 4,
            "num_layers": 2,
            "block_size": 4,
            "max_batch_size": 4,
            "max_input_len": 32,
            "max_output_len": 8,
            "dtype": torch.float16,
            "tp_size": 1,
        }
    ],
)
def test_bucket(test_config):
    hidden_size = test_config.pop("hidden_size")
    num_heads = test_config.pop("num_attention_heads")
    num_layers = test_config.pop("num_layers")
    model_config = LlamaConfig(
        hidden_size=hidden_size,
        num_hidden_layers=num_layers,
        num_attention_heads=num_heads,
    )
    inference_config = InferenceConfig(**test_config)

    # Just for testing usage. Don't create multiple cache_manager on the same device.
    cache_manager = KVCacheManager(inference_config, model_config)
    cache_manager_copy = KVCacheManager(inference_config, model_config)

    seq_lens = [19, 20, 27]
    seq1 = Sequence(
        request_id=0,
        prompt="",  # Dummy for testing usage
        input_token_id=list(range(seq_lens[0])),
        block_size=4,
        sample_params=None,
        eos_token_id=2,
        pad_token_id=2,
        max_output_len=10,
    )
    seq2 = Sequence(
        request_id=1,
        prompt="",  # Dummy for testing usage
        input_token_id=list(range(seq_lens[1])),
        block_size=4,
        sample_params=None,
        eos_token_id=2,
        pad_token_id=2,
        max_output_len=10,
    )
    seq3 = Sequence(
        request_id=2,
        prompt="",  # Dummy for testing usage
        input_token_id=list(range(seq_lens[2])),
        block_size=4,
        sample_params=None,
        eos_token_id=2,
        pad_token_id=2,
        max_output_len=10,
    )

    block_size = test_config["block_size"]
    max_batch_size = test_config["max_batch_size"]
    max_length = test_config["max_input_len"] + test_config["max_output_len"]
    assert max_batch_size >= 2, "max_batch_size should be greater than 1"

    bb = BatchBucket(
        num_heads, cache_manager.get_head_size(), max_batch_size, max_length, block_size, kv_max_split_num=2
    )
    bb_copy = BatchBucket(
        num_heads, cache_manager.get_head_size(), max_batch_size, max_length, block_size, kv_max_split_num=2
    )
    block_tables = bb.add_seqs([seq1, seq2])
    assert block_tables.shape == (2, cache_manager.max_blocks_per_sequence)
    assert torch.all(block_tables < 0), "Initialized block_tables should be negative values"

    cache_manager.allocate_context_from_block_tables(block_tables, bb.seq_lengths[: bb.current_batch_size])
    bb_copy.add_seqs(
        [seq1, seq2], alloc_block_tables_fn=cache_manager_copy.allocate_context_from_block_tables
    )  # This is just for testing usage. Don't add the same sequence to different buckets.

    assert bb.seq_lengths.tolist() == [seq1.sentence_len, seq2.sentence_len] + [0] * (
        max_batch_size - bb.current_batch_size
    )
    assert torch.equal(bb.block_tables, bb_copy.block_tables)

    bb.append_batch_tokens(torch.tensor([99, 99]))
    assert bb.seq_lengths.tolist() == [seq1.sentence_len, seq2.sentence_len] + [0] * (
        max_batch_size - bb.current_batch_size
    )

    cache_manager.allocate_tokens_from_block_tables(bb.block_tables, bb.seq_lengths, bsz=bb.current_batch_size)
    assert bb.seq_lengths.tolist() == [seq1.sentence_len, seq2.sentence_len] + [0] * (
        max_batch_size - bb.current_batch_size
    )

    bb.append_batch_tokens(torch.tensor([99, 99]))

    cache_manager.allocate_tokens_from_block_tables(bb.block_tables, bb.seq_lengths, bsz=bb.current_batch_size)
    assert bb.seq_lengths.tolist() == [seq1.sentence_len, seq2.sentence_len] + [0] * (
        max_batch_size - bb.current_batch_size
    )

    bb.pop_seq_update_batch(0, free_block_table_fn=cache_manager.free_block_table)
    assert bb.seq_lengths.tolist() == [bb.seqs_li[0].sentence_len] + [0] * (max_batch_size - bb.current_batch_size)
    assert bb.is_compact

    bb2 = BatchBucket(
        num_heads, cache_manager.get_head_size(), max_batch_size, max_length, block_size, kv_max_split_num=2
    )
    block_tables = bb2.add_seqs([seq3])
    cache_manager.allocate_context_from_block_tables(block_tables, bb2.seq_lengths[: bb2.current_batch_size])
    unmerged_ids = bb.merge(bb2)
    assert not unmerged_ids
    assert bb.is_compact
    assert bb2.is_compact
    assert bb.current_batch_size == 2
    assert bb2.current_batch_size == 0

    bb.clear(cache_manager.free_block_tables)
    assert bb.current_batch_size == 0
    assert bb.is_compact
    assert bb.seq_lengths.tolist() == [0] * max_batch_size
    assert torch.all(bb.block_tables < 0)


if __name__ == "__main__":
    test_bucket()