mirror of https://github.com/hpcaitech/ColossalAI
[Inference] Optimize and Refactor Inference Batching/Scheduling (#5367)
* add kvcache manager funcs for batching
* add batch bucket for batching
* revise RunningList struct in handler
* add kvcache/batch funcs for compatibility
* use new batching methods
* fix indexing bugs
* revise abort logic
* use cpu seq lengths/block tables
* rm unused attr in Sequence
* fix type conversion/default arg
* add and revise pytests
* revise pytests, rm unused tests
* rm unused statements
* fix pop finished indexing issue
* fix: use index in batch when retrieving inputs/update seqs
* use dict instead of odict in batch struct
* arg type hinting
* fix make compress
* refine comments
* fix: pop_n_seqs to pop the first n seqs
* add check in request handler
* remove redundant conversion
* fix test for request handler
* fix pop method in batch bucket
* fix prefill adding

pull/5399/head
parent 8c69debdc7
commit b21aac5bae
@ -0,0 +1,449 @@
from typing import Callable, List, Optional, Tuple, Union

import torch

from colossalai.inference.struct import Sequence
from colossalai.utils import get_current_device


class BatchBucket:
    """Container for a batch of Sequences, which is used to manage the batch of sequences.

    Attrs:
        _sequences_dict (Dict[int, Sequence]): Map sequence uid to sequence struct
            seq_uid -> Sequence
        _sequences_indexes (Dict[int, int]): Map sequence uid to index in the batch
            seq_uid -> index in the batch (indexing used in sequence_lengths and block_tables)
        _sequence_lengths (torch.Tensor): Length of each sequence in the batch.
            The size of the tensor is (max_batch_size,)
        _block_tables (torch.Tensor): Block table of each sequence in the batch
            The size of the tensor is (max_batch_size, max_blocks_per_seq)
    """

    def __init__(
        self,
        num_heads,
        head_dim,
        max_batch_size,
        max_length,
        block_size,
        kv_max_split_num,
        fd_interm_tensor=None,
        device=None,
        dtype=torch.float16,
    ):
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.max_batch_size = max_batch_size
        self.max_length = max_length  # in + out len
        self.block_size = block_size
        self.kv_max_split_num = kv_max_split_num  # Hint used for flash decoding
        self.fd_interm_tensor = fd_interm_tensor
        self.device = device or get_current_device()
        self.dtype = dtype

        self._current_batch_size = 0
        self._sequences_dict = dict()
        self._sequences_indexes = dict()  # deque(maxlen=self.max_batch_size)
        self._sequence_lengths = torch.zeros((self.max_batch_size,), dtype=torch.int32)
        self._sequence_lengths_helper = torch.zeros_like(self._sequence_lengths)
        max_blocks_per_seq = (self.max_length + block_size - 1) // block_size
        self._block_tables = torch.full((self.max_batch_size, max_blocks_per_seq), -1, dtype=torch.int32)
        self._block_tables_helper = torch.full_like(self._block_tables, -1)

    @property
    def is_empty(self):
        return self._current_batch_size == 0

    @property
    def current_batch_size(self):
        return self._current_batch_size

    @property
    def available_batch_size(self):
        return self.max_batch_size - self._current_batch_size

    @property
    def block_tables(self):
        return self._block_tables

    @property
    def seq_lengths(self):
        return self._sequence_lengths

    @property
    def seqs_ids(self):
        return list(self._sequences_dict.keys())

    @property
    def seqs_li(self):
        return list(self._sequences_dict.values())

    @property
    def is_compact(self):
        assert len(self._sequences_dict) == len(self._sequences_indexes), "BatchBucket indexing is not consistent"
        return (
            len(self._sequences_dict)
            == torch.nonzero(self._sequence_lengths).view(-1).numel()
            == torch.nonzero(self._block_tables[:, 0] >= 0).numel()
        )

    def _make_compact(self) -> None:
        # Clean and compress the batch based on its sequences dict.
        # Namely, compress sequences to the front and clean the seq lengths and block tables tensors.
        # NOTE Prevent calling this method multiple times in a single step
        if self.is_compact:
            return
        valid_seq_ids = self._sequences_dict.keys()
        valid_num = len(valid_seq_ids)
        valid_indexes = [self._sequences_indexes[seq_id] for seq_id in valid_seq_ids]
        assert valid_num == len(self._sequences_indexes), "BatchBucket indexing is not consistent"
        self._sequence_lengths_helper[:valid_num] = self._sequence_lengths[valid_indexes]
        self._sequence_lengths[:] = self._sequence_lengths_helper[:]
        self._block_tables_helper[:valid_num, :] = self.block_tables[valid_indexes]
        self.block_tables[:] = self._block_tables_helper[:]
        new_idx = 0
        for seq_id in valid_seq_ids:
            self._sequences_indexes[seq_id] = new_idx
            new_idx += 1
        self._sequence_lengths_helper.fill_(0)
        self._block_tables_helper.fill_(-1)
        self._current_batch_size = valid_num
    def add_seq(
        self,
        seq: Sequence,
        alloc_block_table: torch.Tensor = None,
        alloc_block_table_fn: Callable[[torch.Tensor, int], None] = None,
    ) -> Union[torch.Tensor, None]:
        """Add a single sequence to the batch.
        User could opt to provide either a block table or a function to allocate block tables.

        Args:
            seq (Sequence): The sequence to be added to the batch
            alloc_block_table (torch.Tensor): The block table to be copied and used for the sequence
            alloc_block_table_fn (Callable[[torch.Tensor, int], None]): The function to allocate blocks for the sequence,
                which is expected to reserve blocks and update status of kv-cache manager.

        Returns:
            block_table (torch.Tensor): The block table of the added sequence, used for block allocation in kv-cache manager.
                None if the sequence cannot be added.
        """
        block_table = None
        # TODO might consider sorting by length
        if self._current_batch_size < self.max_batch_size:
            self._sequences_dict[seq.request_id] = seq
            self._sequences_indexes[seq.request_id] = self._current_batch_size
            self._sequence_lengths[self._current_batch_size] = seq.sentence_len
            # NOTE the added seq still requires block table allocation by kvcache manager
            block_table = self._block_tables[self._current_batch_size - 1]
            if alloc_block_table is not None:
                # copy block ids from provided block tables
                self._block_tables[self._current_batch_size - 1] = alloc_block_table
            elif alloc_block_table_fn:
                alloc_block_table_fn(block_table, self._sequence_lengths[self._current_batch_size - 1].item())
            self._current_batch_size += 1
        return block_table

    def add_seqs(
        self,
        seqs: List[Sequence],
        alloc_block_tables: torch.Tensor = None,
        alloc_block_tables_fn: Callable[[torch.Tensor, torch.Tensor], None] = None,
    ) -> Union[torch.Tensor, None]:
        """Add a list of sequences to the batch.
        User could opt to provide either block tables or a function to allocate block tables.

        Args:
            seqs (List[Sequence]): The sequences to be added to the batch
            alloc_block_tables (torch.Tensor): The block tables to be copied and used for the sequences
            alloc_block_tables_fn (Callable[[torch.Tensor, torch.Tensor], None]): The function to allocate blocks for multiple sequences,
                which is expected to reserve blocks and update status of kv-cache manager.

        Returns:
            block_tables (torch.Tensor): The block tables of the added sequences, used for block allocation in kv-cache manager.
                None if the sequences cannot be added.
        """

        assert (
            alloc_block_tables is None or alloc_block_tables_fn is None
        ), "`alloc_block_tables` and `alloc_block_tables_fn` cannot be provided at the same time"

        num_seqs_to_add = min(self.max_batch_size - self._current_batch_size, len(seqs))
        block_tables = None
        if num_seqs_to_add > 0:
            for i, seq in enumerate(seqs[:num_seqs_to_add]):
                self._sequences_dict[seq.request_id] = seq
                self._sequences_indexes[seq.request_id] = self._current_batch_size + i
            # TODO external (rename): modify Sequence.sentence_len to seq_len
            self._sequence_lengths[
                self._current_batch_size : self._current_batch_size + num_seqs_to_add
            ] = torch.tensor([seq.sentence_len for seq in seqs[:num_seqs_to_add]], dtype=torch.int32)
            # NOTE block tables to be updated by kvcache manager
            block_tables = self._block_tables[self._current_batch_size : self._current_batch_size + num_seqs_to_add]
            if alloc_block_tables is not None:
                # copy block ids from provided block tables
                self._block_tables[
                    self._current_batch_size : self._current_batch_size + num_seqs_to_add
                ] = alloc_block_tables
            elif alloc_block_tables_fn:
                alloc_block_tables_fn(
                    block_tables,
                    self._sequence_lengths[self._current_batch_size : self._current_batch_size + num_seqs_to_add],
                )

            self._current_batch_size += num_seqs_to_add
            seqs[:] = seqs[num_seqs_to_add:]

        return block_tables
    def pop_seq_update_batch(
        self, request_id: int, free_block_table_fn: Callable[[torch.Tensor], None] = None
    ) -> Tuple[Sequence, Union[torch.Tensor, None]]:
        """Pop a single sequence by id from the batch, and update the batch bucket status.

        Args:
            request_id (int): The uid of the sequence
            free_block_table_fn (Callable): The function to free the block table of a sequence;
                if not provided, then we have to release the block table manually after calling this method

        Returns:
            A tuple of: seq (Sequence): The target sequence
            and block_table (torch.Tensor): block table of the target sequence indicating corresponding blocks,
                None if the sequence is not found or free_block_table_fn is provided.
        """
        seq: Sequence = self._sequences_dict.get(request_id)
        block_table = None
        if seq is not None:
            assert request_id in self._sequences_indexes, "Inconsistency in BatchBucket indexing"
            self._sequences_dict.pop(request_id)
            seq_b_idx = self._sequences_indexes.get(request_id)

            if self.current_batch_size > 1:
                # replace seq length of the target seq with that of the last seq in the batch
                last_seq_b_idx = self.current_batch_size - 1
                last_seq_id = next(
                    (uid for uid, index in self._sequences_indexes.items() if index == last_seq_b_idx),
                    None,
                )
                assert last_seq_id is not None
                self._sequences_indexes[last_seq_id] = seq_b_idx
                self._sequence_lengths[seq_b_idx] = self._sequence_lengths[last_seq_b_idx]
                self._sequence_lengths[last_seq_b_idx].fill_(0)
                # free the block table of the seq, or return a copy of the block table (to be processed outside)
                if free_block_table_fn:
                    free_block_table_fn(self._block_tables[seq_b_idx])
                else:
                    block_table = self._block_tables[seq_b_idx].detach().clone()
                # replace block table of the target seq with that of the last seq in the batch
                self._block_tables[seq_b_idx] = self._block_tables[last_seq_b_idx]
                self._block_tables[last_seq_b_idx].fill_(-1)
            else:
                if free_block_table_fn:
                    free_block_table_fn(self._block_tables[0])
                else:
                    block_table = self._block_tables[0].detach().clone()
                self._sequence_lengths[0].fill_(0)
                self._block_tables[0].fill_(-1)
            self._sequences_indexes.pop(request_id)
            self._current_batch_size -= 1

        return seq, block_table

    def pop_seqs(
        self, request_ids: List[int], free_block_table_fn: Callable[[torch.Tensor], None] = None
    ) -> Tuple[List[Sequence], List[torch.Tensor]]:
        """Iteratively pop a list of sequences by uid.

        Args:
            request_ids (List[int]): The uids of the sequences
            free_block_table_fn (Callable): The function to free the block table of a sequence;
                if not provided, then we have to release the block table manually after calling this method
        Returns:
            A tuple of: seqs (List[Sequence]): The target sequences
            and block_tables (List[torch.Tensor]): block tables of the target sequences indicating corresponding blocks
        """
        seqs = []
        block_tables = []
        for request_id in request_ids:
            seq, block_table = self.pop_seq_update_batch(request_id, free_block_table_fn)
            if seq is not None:
                seqs.append(seq)
            if block_table is not None:
                block_tables.append(block_table)
        return seqs, block_tables

    def pop_n_seqs(
        self, n: int, free_block_table_fn: Callable[[torch.Tensor], None] = None
    ) -> Tuple[List[Sequence], List[torch.Tensor]]:
        """Pop the first n sequences in the batch (FIFO).
        If n is greater than the current batch size, pop all the sequences in the batch.

        Args:
            n (int): The number of sequences to pop out
            free_block_table_fn (Callable): The function to free the block table of a single sequence
        Returns:
            A tuple of: seqs (List[Sequence]): The target sequences,
            and block_tables (List[torch.Tensor]): block tables of the target sequences indicating corresponding blocks
        """
        # NOTE Prevent calling this method multiple times in a single step
        seqs = []
        block_tables = []
        n = min(n, self.current_batch_size)
        seq_ids = list(self._sequences_dict.keys())[:n]
        for seq_id in seq_ids:
            seq = self._sequences_dict.pop(seq_id)
            seq_b_idx = self._sequences_indexes.pop(seq_id)
            if free_block_table_fn:
                free_block_table_fn(self.block_tables[seq_b_idx])
            else:
                block_tables.append(self.block_tables[seq_b_idx].detach().clone())
            seqs.append(seq)
        if not self.is_compact:
            self._make_compact()
        return seqs, block_tables

    def pop_finished(
        self, free_block_table_fn: Callable[[torch.Tensor], None] = None
    ) -> Tuple[List[Sequence], List[torch.Tensor]]:
        """Pop finished sequences in the batch and a list of block tables of the finished sequences,
        if free_block_table_fn is not provided.

        Args:
            free_block_table_fn (Callable): The function to free the block table of a single sequence
        Returns:
            A tuple of: finished_seqs (List[Sequence]): The finished sequences,
            and finished_block_tables (List[torch.Tensor]): block tables of the finished sequences.
        """
        finished_seqs = []
        finished_block_tables = []
        for seq in self._sequences_dict.values():
            if seq.check_finish():
                finished_seqs.append(seq)
        # Use `pop_seq_update_batch` to update the batch status for just a few finished seqs,
        # otherwise, pop seqs directly and then call `_make_compact` to compress the batch.
        # For now, the performance difference is not significant, so we use the first method to pop seqs.
        # Precise evaluations to be done.
        for seq in finished_seqs:
            _, block_table = self.pop_seq_update_batch(seq.request_id, free_block_table_fn)
            if block_table is not None:
                finished_block_tables.append(block_table)

        return finished_seqs, finished_block_tables
    # TODO arg type not support beam search sampling yet
    def append_batch_tokens(self, tokens: torch.Tensor) -> None:
        """Append a batch of tokens to the sequences in the batch"""
        assert self.current_batch_size == tokens.size(0), "Batch size mismatch"

        if self.current_batch_size > 0:
            tokens = tokens.tolist()
            for seq_id, seq in self._sequences_dict.items():
                index_in_b = self._sequences_indexes[seq_id]
                curr_tokens = tokens[index_in_b]
                if not isinstance(curr_tokens, list):
                    curr_tokens = [curr_tokens]
                seq.output_token_id += curr_tokens
                seq.check_finish()
            self._sequence_lengths[: self.current_batch_size] += 1

    def clear(self, free_block_tables_fn: Optional[Callable[[torch.Tensor], None]]) -> List[int]:
        """Clear all the sequences in the batch.

        free_block_tables_fn (Optional[Callable]): The function to free the block tables of all the sequences in a batch
        """
        seqs = list(self._sequences_dict.values())
        self._sequences_dict.clear()
        self._sequences_indexes.clear()
        if free_block_tables_fn:
            free_block_tables_fn(self.block_tables, self._current_batch_size)
        self._block_tables.fill_(-1)
        self._sequence_lengths.fill_(0)
        self._current_batch_size = 0
        return seqs

    def merge(self, other: "BatchBucket") -> List[int]:
        """Merge the sequences in the other batch into the current batch.
        Merge as many sequences as the current batch can hold; if there is not enough space
        for all the sequences in the other batch, the leftover ones stay in the other batch.

        Usage:
            > New incoming sequence added to prefill batch
                prefill bb curr batch size < prefill_ratio * prefill bb max batch size
            > New incoming sequence added to prefill batch
                prefill bb curr batch size == prefill_ratio * prefill bb max batch size
            > Pause Decoding
            > Prefill
            > Move sequences in prefill bb => decoding bb
            > Put back the out-of-volume sequences into the running pool

        Returns:
            unmerged_ids (List[int]): a list of sequence uids that are not merged into the current batch
        """
        unmerged_ids = []
        num_seqs_to_merge = min(self.available_batch_size, other.current_batch_size)
        if num_seqs_to_merge > 0:
            seqs, block_tables_li = other.pop_n_seqs(num_seqs_to_merge)
            block_tables = torch.stack(block_tables_li)
            self.add_seqs(seqs, alloc_block_tables=block_tables)
            unmerged_ids = other.seqs_ids
        return unmerged_ids

    ########## The following methods are expected to be used in modeling ###########

    # For compatibility.
    # NOTE: This is a heuristic way to determine the stage of the batch.
    @property
    def is_prompts(self) -> bool:
        assert len(self._sequences_dict) > 0, "No sequence in the batch"
        first_seq = next(iter(self._sequences_dict.values()))
        if first_seq.output_len == 0:
            return True
        return False

    # For compatibility
    def get_1D_inputs(self) -> torch.Tensor:
        assert len(self._sequences_dict) > 0, "No sequence in the batch"
        first_seq = next(iter(self._sequences_dict.values()))  # not exactly the first sequence
        if first_seq.output_len == 0:
            # Assume prefill stage
            assert all(
                seq.output_len == 0 for seq in self._sequences_dict.values()
            ), "Sequence stage (Prefill/Decoding) must be the same in the batch"
            out_li = []
            num_tokens = torch.sum(self._sequence_lengths)
            out = torch.empty([num_tokens], dtype=torch.long)
            seq_ids = sorted(self._sequences_indexes.keys(), key=lambda x: self._sequences_indexes[x])
            for seq_id in seq_ids:
                seq: Sequence = self._sequences_dict[seq_id]
                out_li.extend(seq.input_token_id)
            return torch.tensor(out_li, dtype=torch.long, device=self.device)
        else:
            # Assume decoding stage
            assert all(
                seq.output_len > 0 for seq in self._sequences_dict.values()
            ), "Sequence stage (Prefill/Decoding) must be the same in the batch"
            assert self.is_compact, "BatchBucket is not compact"
            out = torch.empty([self.current_batch_size], dtype=torch.long)
            for seq_id, index_in_b in self._sequences_indexes.items():
                seq: Sequence = self._sequences_dict[seq_id]
                out[index_in_b] = seq.output_token_id[-1]
            return out.to(device=self.device)

    # For compatibility
    def get_block_table_tensor(self) -> torch.Tensor:
        assert self.is_compact  # Debug usage
        block_table = self.block_tables[: self.current_batch_size]
        return block_table.to(device=self.device)

    # For compatibility
    def get_sequence_lengths(self) -> torch.Tensor:
        assert self.is_compact  # Debug usage
        sequence_lengths = self.seq_lengths[: self.current_batch_size]
        return sequence_lengths.to(device=self.device)

    # For compatibility
    @property
    def fd_inter_tensor(self) -> None:
        assert self.fd_interm_tensor is not None, "fd_interm_tensor is not provided"
        return self.fd_interm_tensor
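Taken together, the container above is driven in three phases: sequences are added (prefill), one token per sequence is appended on each decoding step, and finished sequences are popped out. A minimal usage sketch follows, loosely mirroring the unit test added at the bottom of this diff; it assumes a KVCacheManager instance named cache_manager and Sequence objects seq1/seq2 constructed as in that test, so the concrete values are illustrative only.

import torch

from colossalai.inference.batch_bucket import BatchBucket

# Sketch only: num_heads/head_dim/max_length are assumed example values.
bb = BatchBucket(num_heads=4, head_dim=32, max_batch_size=4, max_length=40, block_size=4, kv_max_split_num=2)

# Prefill: register the sequences and let the kv-cache manager fill their block tables in place.
bb.add_seqs([seq1, seq2], alloc_block_tables_fn=cache_manager.allocate_context_from_block_tables)

# Decoding: append one sampled token per sequence; sequence lengths grow by 1.
bb.append_batch_tokens(torch.tensor([101, 102]))

# Retire finished sequences and release their cache blocks through the manager.
finished_seqs, _ = bb.pop_finished(free_block_table_fn=cache_manager.free_block_table)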
@ -109,7 +109,7 @@ class InferenceConfig:
        ), f"Expected dtype to be in {_ALLOWED_DTYPES} but found an unknown dtype: {self.dtype}"

        # check distributed
        assert (
        assert (not torch.distributed.is_initialized() and self.tp_size * self.pp_size == 1) or (
            self.tp_size * self.pp_size == dist.get_world_size()
        ), f"TP size({self.tp_size}) * PP size({self.pp_size}) should be equal to the global world size ({dist.get_world_size()})"
        # check prompt template
@ -42,7 +42,7 @@ class InferenceEngine:
    def __init__(
        self,
        model: nn.Module,
        tokenizer: [Union[PreTrainedTokenizer, PreTrainedTokenizerFast]],
        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
        inference_config: InferenceConfig,
        verbose: bool = False,
        model_policy: Policy = None,
@ -254,20 +254,12 @@ class InferenceEngine:
            else:
                prompt = prompts[i]

            max_blocks_per_sequence = (
                self.inference_config.max_input_len
                + self.inference_config.max_output_len
                + self.inference_config.block_size
                - 1
            ) // self.inference_config.block_size
            block_table = torch.full([max_blocks_per_sequence], -1, device=self.device)
            sequence = Sequence(
                request_id,
                prompt,
                prompts_token_ids[i],
                block_size,
                None,
                block_table,
                self.tokenizer.eos_token_id,
                self.tokenizer.pad_token_id,
                self.inference_config.max_output_len,
@ -1,15 +1,16 @@
from typing import List
from typing import Dict, List, Union

import torch
from transformers.configuration_utils import PretrainedConfig
from transformers.generation import GenerationConfig

from colossalai.inference.batch_bucket import BatchBucket
from colossalai.inference.config import InferenceConfig
from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.inference.kv_cache import KVCacheManager
from colossalai.inference.logit_processors import logit_processor
from colossalai.inference.sampler import *
from colossalai.inference.struct import BatchInfo, RequestStatus, Sequence
from colossalai.inference.struct import RequestStatus, Sequence
from colossalai.logging import get_dist_logger

__all__ = ["RunningList", "RequestHandler"]
@ -24,45 +25,79 @@ class RunningList:
    Args:
        prefill_ratio: (float) A ratio for determining whether to perform prefill or not.
        prefill: (List) List that contains default inputs, defaults to [].
        _prefill (OrderedDict[Sequence]): Mapping of sequence uid -> Sequence.
        _decoding (OrderedDict[Sequence]): Mapping of sequence uid -> Sequence.
    """

    def __init__(self, prefill_ratio: str, prefill: List[Sequence] = None):
    def __init__(self, prefill_ratio: int, prefill: List[Sequence] = None) -> None:
        self.prefill_ratio = prefill_ratio
        self.decoding: List[Sequence] = []
        self.prefill: List[Sequence] = prefill if prefill is not None else []
        self._decoding: Dict[int, Sequence] = dict()
        self._prefill: Dict[int, Sequence] = (
            dict({seq.request_id: seq for seq in prefill}) if prefill is not None else dict()
        )

    @property
    def decoding(self):
        return list(self._decoding.values())

    @property
    def prefill(self):
        return list(self._prefill.values())

    @property
    def prefill_seq_num(self):
        return len(self._prefill)

    @property
    def decoding_seq_num(self):
        return len(self._decoding)

    @property
    def total_seq_num(self):
        return self.prefill_seq_num + self.decoding_seq_num

    def append(self, seq: Sequence):
        # add seq to prefilling list first.
        self.prefill.append(seq)
        assert (seq.request_id not in self._prefill) and (
            seq.request_id not in self._decoding
        ), f"Sequence uid {seq.request_id} already exists."
        self._prefill[seq.request_id] = seq

    def find_seq(self, request_id):
        for seq in self.decoding:
            if request_id == seq.request_id:
                return seq
        for seq in self.prefill:
            if request_id == seq.request_id:
                return seq
        return None
    def extend(self, seqs: List[Sequence]):
        for seq in seqs:
            self._prefill[seq.request_id] = seq

    def remove(self, seq: Sequence):
        if seq in self.decoding:
            self.decoding.remove(seq)
        elif seq in self.prefill:
            self.prefill.remove(seq)
    def find_seq(self, request_id) -> Union[Sequence, None]:
        seq = None
        if request_id in self._decoding:
            seq = self._decoding[request_id]
        elif request_id in self._prefill:
            seq = self._prefill[request_id]
        return seq

    def remove(self, seq: Sequence) -> None:
        if seq.request_id in self._decoding:
            self._decoding.pop(seq.request_id)
        elif seq.request_id in self._prefill:
            self._prefill.pop(seq.request_id)
        else:
            raise ValueError(f"sequence {seq.request_id} is not in running list")
            raise ValueError(f"Sequence {seq.request_id} is not in running list")

    def ready_for_prefill(self):
        if not self.decoding:
            return len(self.prefill) > 0
        return len(self.prefill) / len(self.decoding) >= self.prefill_ratio
        if not self._decoding:
            return len(self._prefill) > 0
        return len(self._prefill) / len(self._decoding) >= self.prefill_ratio

    def is_empty(self):
        return not self.decoding and not self.prefill
        return not self._decoding and not self._prefill

    def total_seq_num(self):
        return len(self.decoding) + len(self.prefill)
    def mark_prefill_running(self) -> None:
        for seq_id in self._prefill:
            self._prefill[seq_id].mark_running()

    def move_prefill_to_decoding(self, seq_ids: List[int]) -> None:
        for seq_id in seq_ids:
            assert seq_id in self._prefill, f"Sequence {seq_id} is not in prefill list"
            self._decoding[seq_id] = self._prefill.pop(seq_id)


class RequestHandler:
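The refactored RunningList above keeps prefill and decoding sequences in two uid-keyed dicts rather than plain lists, which makes lookup, removal, and the prefill-to-decoding promotion constant-time per sequence. A small sketch of the intended transitions; the Sequence object and the module path in the import are assumed to match the tests later in this diff.

from colossalai.inference.core.request_handler import RunningList  # path assumed

running_list = RunningList(prefill_ratio=1.2)
running_list.append(seq1)                  # new requests always land in the prefill dict first
if running_list.ready_for_prefill():       # prefill count vs. decoding count ratio check
    running_list.mark_prefill_running()    # flip the status of every queued prefill sequence
    # ... run the prefill forward pass, then promote the sequences to decoding ...
    running_list.move_prefill_to_decoding([seq1.request_id])
running_list.remove(seq1)                  # drop a finished or aborted sequence by uid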
@ -110,25 +145,27 @@ class RequestHandler:
        # TODO In the continuous batching scenario, the batch size may be greater than max_batch_size,
        # which may cause bugs and this issue should be fixed later.
        self.running_batch = BatchInfo(
            max_batch_size=self.max_batch_size,
            kv_max_split_num=kv_max_split_num,
        self.running_bb = BatchBucket(
            num_heads=model_config.num_attention_heads,
            head_dim=head_dim,
            is_prompts=False,
            device=device,
            max_batch_size=self.max_batch_size,
            max_length=inference_config.max_input_len + inference_config.max_output_len,
            block_size=inference_config.block_size,
            kv_max_split_num=kv_max_split_num,
            fd_interm_tensor=fd_inter_tensor,
            dtype=self.dtype,
            fd_inter_tensor=fd_inter_tensor,
            device=device,
        )
        self.prefill_batch = BatchInfo(
            max_batch_size=self.max_batch_size,
            kv_max_split_num=kv_max_split_num,
        self.prefill_bb = BatchBucket(
            num_heads=model_config.num_attention_heads,
            head_dim=head_dim,
            is_prompts=True,
            device=device,
            max_batch_size=self.max_batch_size,
            max_length=inference_config.max_input_len + inference_config.max_output_len,
            block_size=inference_config.block_size,
            kv_max_split_num=kv_max_split_num,
            fd_interm_tensor=fd_inter_tensor,
            dtype=self.dtype,
            fd_inter_tensor=fd_inter_tensor,
            device=device,
        )

    def _init_cache(self, model_config):
@ -159,40 +196,39 @@ class RequestHandler:
                        remove_list.append(seq)
                        break

                # stop feeding new sequence into running list to assure
                if self.cache_manager.num_available_blocks <= self.running_list.total_seq_num():
                    break
                num_seqs_to_add = min(len(lst), self.max_batch_size - self.running_list.total_seq_num)
                remove_list.extend(lst[:num_seqs_to_add])
                self.running_list.extend(lst[:num_seqs_to_add])

                # Try to allocate cache blocks for the sequence.
                if (
                    self.cache_manager.check_allocation(seq)
                    and (len(self.running_list.prefill) + len(self.running_list.decoding))
                    < self.max_batch_size  # There some bugs in continous batching, so we disable it here.
                ):
                    # If succeed, add the sequence to running list.
                    remove_list.append(seq)
                    self.running_list.append(seq)
                    self.cache_manager.allocate_context_from_block_table(seq.block_table, seq.sentence_len)
                for seq in remove_list:
                    lst.remove(seq)

        if self.running_list.ready_for_prefill():
            for seq in self.running_list.prefill:
                seq.mark_running()
            self.prefill_batch.add_seqs(self.running_list.prefill)
            return self.prefill_batch
            num_seqs_to_add = min(self.running_list.prefill_seq_num, self.running_bb.available_batch_size)

        if not self.running_batch.is_empty:
            for seq in self.running_batch.sequences_set:
                recycle = self.cache_manager.allocate_token_from_block_table(seq.block_table, seq.sentence_len)
                if recycle:
            for seq in self.running_list.prefill[:num_seqs_to_add]:
                seq.mark_running()
            # allocate blocks for the prefill batch
            self.prefill_bb.add_seqs(
                self.running_list.prefill[:num_seqs_to_add],
                alloc_block_tables_fn=self.cache_manager.allocate_context_from_block_tables,
            )

            return self.prefill_bb

        if not self.running_bb.is_empty:
            seqs_ids_to_recycle = self.cache_manager.allocate_tokens_from_block_tables(
                self.running_bb.block_tables, self.running_bb.seq_lengths, self.running_bb.current_batch_size
            )
            if seqs_ids_to_recycle:
                seqs_to_recycle = self.running_bb.pop_seqs(seqs_ids_to_recycle)
                for seq in seqs_to_recycle:
                    seq.recycle()
                    self.running_batch.del_seq(seq)
                    self.running_list.remove(seq)
                    self.waiting_list[-1].append(seq)
                    # the recycled sequences are handled with highest priority.

        return self.running_batch
        return self.running_bb

    def add_sequence(self, req: Sequence):
        """
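The reworked schedule() above gives prefill priority: when the running list is ready for prefill, up to available_batch_size sequences are marked running and pushed into prefill_bb with their context blocks allocated in a single batched call; otherwise the cache manager is asked for one more token slot per decoding sequence, and sequences that cannot get a block are recycled. A condensed, non-verbatim sketch of that control flow, assuming handler is a constructed RequestHandler:

def schedule_step(handler):
    # Condensed sketch of RequestHandler.schedule() after this PR (not the verbatim method).
    if handler.running_list.ready_for_prefill():
        n = min(handler.running_list.prefill_seq_num, handler.running_bb.available_batch_size)
        for seq in handler.running_list.prefill[:n]:
            seq.mark_running()
        handler.prefill_bb.add_seqs(
            handler.running_list.prefill[:n],
            alloc_block_tables_fn=handler.cache_manager.allocate_context_from_block_tables,
        )
        return handler.prefill_bb

    if not handler.running_bb.is_empty:
        # Decoding step: every running sequence needs room for exactly one more token.
        recycle_ids = handler.cache_manager.allocate_tokens_from_block_tables(
            handler.running_bb.block_tables,
            handler.running_bb.seq_lengths,
            handler.running_bb.current_batch_size,
        )
        if recycle_ids:
            # Sequences that could not get a block go back to the highest-priority waiting pool.
            seqs, _ = handler.running_bb.pop_seqs(recycle_ids)
            for seq in seqs:
                seq.recycle()
                handler.running_list.remove(seq)
                handler.waiting_list[-1].append(seq)
    return handler.running_bb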
@ -213,7 +249,7 @@ class RequestHandler:
            seq.mark_aborted()
            self.waiting_list[priority].remove(seq)
        elif seq.status.is_running():
            self.cache_manager.free_block_table(seq.block_table)
            self.running_bb.pop_seq_update_batch(seq.request_id, self.cache_manager.free_block_table)
            self.running_list.remove(seq)
        else:
            try:
@ -242,7 +278,7 @@ class RequestHandler:
            else:
                sample_tokens = greedy_sample(generation_config, logprobs)
        else:
            sample_tokens = beam_search_sample(generation_config, logprobs, is_prompt=not self.prefill_batch.is_empty)
            sample_tokens = beam_search_sample(generation_config, logprobs, is_prompt=not self.prefill_bb.is_empty)

        return sample_tokens
@ -273,27 +309,25 @@ class RequestHandler:
        # sample the next tokens
        sample_tokens = self._sample(probs, logprobs, generation_config)
        if not self.prefill_batch.is_empty:
            self.prefill_batch.update_batch_tokens(sample_tokens)
        if not self.prefill_bb.is_empty:
            self.prefill_bb.append_batch_tokens(sample_tokens)
        else:
            self.running_batch.update_batch_tokens(sample_tokens)
            self.running_bb.append_batch_tokens(sample_tokens)

    def update(self):
        """
        Update current running list and done list
        """
        if not self.prefill_batch.is_empty:
            self.running_list.decoding.extend(self.running_list.prefill)
            self.running_batch.add_seqs(self.running_list.prefill)
            self.running_list.prefill.clear()
            self.prefill_batch.clear_batch()
        if not self.prefill_bb.is_empty:
            self.running_list.move_prefill_to_decoding(self.prefill_bb.seqs_ids)
            self.running_bb.merge(self.prefill_bb)
            # clear the prefill batch without assigning a free_block_tables_fn
            # since we want to reuse the memory recorded on the block tables
            self.prefill_bb.clear(free_block_tables_fn=None)

        finish_seqs = self.running_batch.fliter_batch()

        for seq in finish_seqs:
        finished_seqs, _ = self.running_bb.pop_finished(self.cache_manager.free_block_table)
        for seq in finished_seqs:
            self.running_list.remove(seq)
            self.cache_manager.free_block_table(seq.block_table)
        self.done_list.extend(finished_seqs)

        self.done_list.extend(finish_seqs)

        return finish_seqs
        return finished_seqs
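Seen from the engine, the handler methods above run once per generation step: schedule a batch, run the model, sample and append tokens, then update. A rough sketch of that loop; model_forward is a stand-in for the real model call, and check_unfinished_seqs / search_tokens are the handler entry points assumed to wrap the hunks above.

# Rough per-step driver loop (a sketch; names outside this diff are assumptions).
while handler.check_unfinished_seqs():                # assumed handler helper
    batch_to_run = handler.schedule()                 # returns prefill_bb or running_bb
    logits = model_forward(batch_to_run)              # hypothetical model call, [bsz, vocab_size]
    handler.search_tokens(generation_config, logits)  # samples and appends tokens (see hunk above)
    finished_seqs = handler.update()                  # prefill -> decoding hand-off, pop finished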
@ -63,7 +63,6 @@ class KVCacheManager:
        self.dtype = config.dtype
        self.elem_size_in_bytes = torch.tensor([], dtype=self.dtype).element_size()
        self.num_layers = get_model_config_attr(model_config, "num_hidden_layers")
        # For now we focus on MHA only, TODO add handling for MQA and GQA
        self.head_num = get_model_config_attr(model_config, "num_attention_heads")
        self.head_size = get_model_config_attr(model_config, "hidden_size") // self.head_num
        assert self.head_num % self.tp_size == 0, f"Cannot shard {self.head_num} heads with tp size {self.tp_size}"
@ -82,8 +81,8 @@ class KVCacheManager:
        # Physical cache allocation
        alloc_shape = (self.num_blocks, self.head_num, self.block_size, self.head_size)
        if verbose:
            self.logger.info(f"Allocating KV cache with shape: {alloc_shape} consisting of {self.num_blocks} blocks.")
        # if verbose:
        #     self.logger.info(f"Allocating KV cache with shape: {alloc_shape} consisting of {self.num_blocks} blocks.")
        self._kv_caches = self._init_device_caches(alloc_shape)
        self.total_physical_cache_size_in_bytes = (
            self.elem_size_in_bytes
@ -112,6 +111,9 @@ class KVCacheManager:
        """Get the number of available cache blocks."""
        return self._available_blocks

    def get_head_size(self):
        return self.head_size

    def get_kv_cache(self):
        """Get k_cache and v_cache"""
        return self._kv_caches
@ -148,7 +150,7 @@ class KVCacheManager:
        and updates the provided block table with the allocated block ids.

        Args:
            block_table: A 1D tensor of shape [max_blocks_per_sequence] holded by a sequence, storing mapping of token_position_id -> block_id.
            block_table: A 1D tensor of shape [max_blocks_per_sequence], mapping of token_position_id -> block_id.
            context_len: The length of the processing sequence.
        """
        assert block_table.dim() == 1
@ -193,12 +195,85 @@ class KVCacheManager:
        else:
            self._allocate_on_block(block, block.block_size)

    def allocate_context_from_block_tables(self, block_tables: torch.Tensor, context_lengths: torch.Tensor) -> None:
        """Allocate logical cache blocks for a batch of sequences during prefill stage.

        Args:
            block_tables (torch.Tensor): [bsz, max_blocks_per_sequence]
            context_lengths (torch.Tensor): [bsz]
        """
        assert block_tables.dim() == 2
        assert block_tables.size(0) == context_lengths.size(0)
        if not torch.all(block_tables < 0):
            self.logger.error("Some slots on provided block table have been allocated.")
        blocks_required = (context_lengths + self.block_size - 1) // self.block_size
        num_blocks_required = torch.sum(blocks_required).item()
        assert isinstance(num_blocks_required, int)
        if num_blocks_required > self._available_blocks:
            self.logger.warning(
                f"Lacking blocks to allocate. Available blocks {self._available_blocks}; blocks asked {num_blocks_required}."
            )
            return

        bsz = block_tables.size(0)
        # Try contiguous allocation
        torch.cumsum(self._block_states, dim=-1, out=self._block_states_cum[1:])
        torch.subtract(
            self._block_states_cum[num_blocks_required:],
            self._block_states_cum[:-num_blocks_required],
            out=self._block_finder[num_blocks_required - 1 :],
        )
        end_indexes = torch.nonzero(self._block_finder == num_blocks_required, as_tuple=False).view(-1)
        if end_indexes.numel() > 0:
            # contiguous cache exists
            end_idx = end_indexes[0].item() + 1  # open interval
            start_idx = end_idx - num_blocks_required  # closed interval
            alloc_block_ids = torch.arange(start_idx, end_idx)
            for i in range(bsz):
                curr_required = blocks_required[i]
                block_tables[i, :curr_required] = torch.arange(
                    start_idx, start_idx + curr_required, device=block_tables.device
                )
                start_idx += curr_required
        else:
            # non-contiguous cache
            available_block_ids = torch.nonzero(self._block_states > 0).view(-1)
            alloc_block_ids = available_block_ids[:num_blocks_required]
            alloc_block_ids = alloc_block_ids.to(dtype=block_tables.dtype, device=block_tables.device)
            start_idx = 0
            for i in range(bsz):
                curr_required = blocks_required[i]
                block_tables[i, :curr_required] = alloc_block_ids[start_idx : start_idx + curr_required]
                start_idx += curr_required

        # Update cache blocks
        self._block_states[alloc_block_ids] = 0
        self._available_blocks -= num_blocks_required
        last_block_locs = torch.cumsum(blocks_required, dim=0) - 1
        last_block_locs = last_block_locs.to(device=alloc_block_ids.device)

        for i, block_id in enumerate(alloc_block_ids[last_block_locs]):
            block: CacheBlock = self._cache_blocks[block_id]
            block.add_ref()
            self._allocate_on_block(
                block,
                block.block_size
                if context_lengths[i] % block.block_size == 0
                else context_lengths[i].item() % block.block_size,
            )
        for block_id in alloc_block_ids:
            if block_id in alloc_block_ids[last_block_locs]:
                continue
            block: CacheBlock = self._cache_blocks[block_id]
            block.add_ref()
            self._allocate_on_block(block, block.block_size)

    def allocate_token_from_block_table(self, block_table: torch.Tensor, context_len: int) -> None:
        """Allocate the logical cache block for a single sequence during decoding stage,
        and updates the provided block table if a new cache block is needed.

        Args:
            block_table: A 1D tensor of shape [max_blocks_per_sequence] holded by a sequence, storing mapping of token_position_id -> block_id.
            block_table: A 1D tensor of shape [max_blocks_per_sequence], mapping of token_position_id -> block_id.
            context_len: The length of the processing sequence (already-allocated length).
        """
        assert block_table.dim() == 1
@ -207,12 +282,79 @@ class KVCacheManager:
        alloc_local_block_idx = context_len // self.block_size
        return self.allocate_single_block(block_table, alloc_local_block_idx)

    def allocate_tokens_from_block_tables(
        self, block_tables: torch.Tensor, context_lens: torch.Tensor, bsz: int = None
    ) -> List[int]:
        """Allocate logical cache blocks for a batch of sequences during decoding stage.

        Usage:
            allocate_context_from_block_tables
            model forward (block tables & context lengths passed)
            update context lengths
            allocate_tokens_from_block_tables
            model forward
            update context lengths
            allocate_tokens_from_block_tables
            model forward
            update context lengths
            ...

        Args:
            block_tables (torch.Tensor): [bsz, max_blocks_per_sequence]
            context_lens (torch.Tensor): [bsz]

        Returns:
            List[int]: list of sequence uids to be recycled
        """
        assert block_tables.dim() == 2
        assert context_lens.dim() == 1

        bsz = block_tables.size(0) if bsz is None else bsz

        alloc_local_block_indexes = (context_lens[:bsz]) // self.block_size
        block_global_ids = block_tables[torch.arange(0, bsz), alloc_local_block_indexes]
        seqs_to_recycle = []
        new_blocks_required = torch.sum(block_global_ids < 0).item()
        seqs_req_new_blocks = torch.nonzero(block_global_ids < 0).squeeze()

        if new_blocks_required > 0:
            if new_blocks_required > self._available_blocks:
                # TODO might want to revise the logic here
                # Process the first (_available_blocks) sequences that require new blocks
                # Put the rest of the sequences back to recycled
                seqs_req_new_blocks, seqs_to_recycle = (
                    seqs_req_new_blocks[: self._available_blocks],
                    seqs_req_new_blocks[self._available_blocks :],
                )
                for seq_id in seqs_to_recycle:
                    self.free_block_table(block_tables[seq_id])
                new_blocks_required = self._available_blocks

            # NOTE might want to add contiguous allocation logic here
            free_block_ids = torch.nonzero(self._block_states > 0).view(-1)
            alloc_block_ids = free_block_ids[:new_blocks_required].to(
                dtype=block_tables.dtype, device=block_tables.device
            )

            for block_id in alloc_block_ids:
                block: CacheBlock = self._cache_blocks[block_id]
                block.add_ref()
                self._block_states[block_id] = 0
                self._available_blocks -= 1
            block_tables[seqs_req_new_blocks, alloc_local_block_indexes[seqs_req_new_blocks]] = alloc_block_ids
            block_global_ids = block_tables[torch.arange(0, bsz), alloc_local_block_indexes]

        for block_id in block_global_ids:
            self._allocate_on_block(self._cache_blocks[block_id], 1)

        return seqs_to_recycle

    def allocate_single_block(self, block_table: torch.Tensor, block_local_idx: int) -> int:
        """Allocate space asked on a single block in the block table, specified by the provided position id,
        and updates the provided block table with the allocated block.

        Args:
            block_table: A 1D tensor of shape [max_blocks_per_sequence] holded by a sequence, storing mapping of token_position_id -> block_id.
            block_table: A 1D tensor of shape [max_blocks_per_sequence], mapping of token_position_id -> block_id.
            block_local_idx: The index of the block in the block table.
            space_asked: i.e. The number of tokens to be assigned space for.
        Returns:
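The Usage note in allocate_tokens_from_block_tables above fixes the call order between the cache manager and the model, and the extended check_cache_manager test later in this diff exercises the same pattern. A standalone sketch, assuming an initialized KVCacheManager named cache_manager and two sequences with prompt lengths 19 and 20:

import torch

context_lengths = torch.tensor([19, 20], dtype=torch.int32)
block_tables = torch.full((2, cache_manager.max_blocks_per_sequence), -1, dtype=torch.int32)

# Prefill: reserve enough blocks to cover the full prompt of each sequence.
cache_manager.allocate_context_from_block_tables(block_tables, context_lengths)

# ... model forward over the prompts, one token sampled and appended per sequence ...
context_lengths += 1

# Decoding: reserve room for exactly one more token per sequence, once per step.
to_recycle = cache_manager.allocate_tokens_from_block_tables(block_tables, context_lengths)
# `to_recycle` lists the sequences that could not get a block (empty while blocks remain).

# When the sequences are done, release all of their blocks in one call.
cache_manager.free_block_tables(block_tables)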
@ -240,8 +382,7 @@ class KVCacheManager:
    def free_block_table(self, block_table: torch.Tensor) -> None:
        """Free the logical cache blocks for **a single sequence**."""
        assert block_table.dim() == 1
        for i in range(block_table.numel()):
            global_block_id = block_table[i].item()
        for i, global_block_id in enumerate(block_table.tolist()):
            if global_block_id < 0:
                return
            block: CacheBlock = self._cache_blocks[global_block_id]
@ -253,6 +394,15 @@ class KVCacheManager:
            # reset the block id in the block table (if we maintain a 2D tensors as block tables in Engine)
            block_table[i] = -1

    def free_block_tables(self, block_tables: torch.Tensor, first_n: int = None) -> None:
        """Release the logical cache blocks for a batch of sequences.
        If `first_n` is provided, only the blocks for the first several sequences will be released.
        """
        assert block_tables.dim() == 2
        first_n = block_tables.size(0) if first_n is None else first_n
        for block_table in block_tables[:first_n]:
            self.free_block_table(block_table)

    def clear_all(self) -> None:
        """Clear all the references and allocations on all the cache blocks."""
        for block in self._cache_blocks:
@ -12,8 +12,8 @@ from transformers.models.llama.modeling_llama import (
    LlamaModel,
)

from colossalai.inference.batch_bucket import BatchBucket
from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.inference.struct import BatchInfo
from colossalai.kernel.triton import (
    context_attention_unpadded,
    copy_kv_to_blocked_cache,
@ -34,7 +34,7 @@ except ImportError:
def llama_causal_lm_forward(
    self: LlamaForCausalLM,
    batch: BatchInfo = None,
    batch: BatchBucket = None,
    k_caches: List[torch.Tensor] = None,
    v_caches: List[torch.Tensor] = None,
):
@ -59,7 +59,7 @@ def llama_causal_lm_forward(
def llama_model_forward(
    self: LlamaModel,
    batch: BatchInfo = None,
    batch: BatchBucket = None,
    k_caches: List[torch.Tensor] = None,
    v_caches: List[torch.Tensor] = None,
):
@ -73,7 +73,7 @@ def llama_model_forward(
    input_ids = batch.get_1D_inputs()
    block_tables = batch.get_block_table_tensor()
    sequence_lengths = batch.get_sequence_lengths()
    batch_size = len(sequence_lengths)
    batch_size = batch.current_batch_size
    kv_seq_len = sequence_lengths.max().item()

    hidden_states = self.embed_tokens(input_ids)
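With batch now typed as a BatchBucket (see the signature changes above), the per-step model inputs are pulled from the bucket through the compatibility getters defined in batch_bucket.py. A sketch of the consuming side, with the attention kernels elided:

# Inside a model forward that receives `batch: BatchBucket` (kernel calls omitted).
input_ids = batch.get_1D_inputs()               # flat prompt tokens (prefill) or last tokens (decoding)
block_tables = batch.get_block_table_tensor()   # [current_batch_size, max_blocks_per_seq], moved to device
sequence_lengths = batch.get_sequence_lengths() # [current_batch_size], moved to device
batch_size = batch.current_batch_size
kv_seq_len = sequence_lengths.max().item()
is_prefill = batch.is_prompts                   # heuristic: True until the first tokens are generated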
@ -71,7 +71,6 @@ class Sequence:
    input_token_id: List[int]
    block_size: int
    sample_params: Any  # SampleParams needs to be imported later.
    block_table: torch.Tensor
    eos_token_id: int
    pad_token_id: int
    max_output_len: int = 256
@ -158,7 +157,6 @@ class Sequence:
            f"prompt={self.prompt}, "
            f"status={self.status.name}, "
            f"sample_params={self.sample_params}, "
            f"logical_block_number={self.block_table.shape[0]},"
            f"input_len={self.input_len}),"
            f"output_len={self.output_len})"
        )
@ -0,0 +1,140 @@
import torch
from transformers.models.llama import LlamaConfig

from colossalai.inference.batch_bucket import BatchBucket
from colossalai.inference.config import InferenceConfig
from colossalai.inference.kv_cache import KVCacheManager
from colossalai.inference.struct import Sequence
from colossalai.testing import parameterize


@parameterize(
    "test_config",
    [
        {
            "hidden_size": 128,
            "num_attention_heads": 4,
            "num_layers": 2,
            "block_size": 4,
            "max_batch_size": 4,
            "max_input_len": 32,
            "max_output_len": 8,
            "dtype": torch.float16,
            "tp_size": 1,
        }
    ],
)
def test_bucket(test_config):
    hidden_size = test_config.pop("hidden_size")
    num_heads = test_config.pop("num_attention_heads")
    num_layers = test_config.pop("num_layers")
    model_config = LlamaConfig(
        hidden_size=hidden_size,
        num_hidden_layers=num_layers,
        num_attention_heads=num_heads,
    )
    inference_config = InferenceConfig(**test_config)

    # Just for testing usage. Don't create multiple cache_manager on the same device.
    cache_manager = KVCacheManager(inference_config, model_config)
    cache_manager_copy = KVCacheManager(inference_config, model_config)

    seq_lens = [19, 20, 27]
    seq1 = Sequence(
        request_id=0,
        prompt="",  # Dummy for testing usage
        input_token_id=list(range(seq_lens[0])),
        block_size=4,
        sample_params=None,
        eos_token_id=2,
        pad_token_id=2,
        max_output_len=10,
    )
    seq2 = Sequence(
        request_id=1,
        prompt="",  # Dummy for testing usage
        input_token_id=list(range(seq_lens[1])),
        block_size=4,
        sample_params=None,
        eos_token_id=2,
        pad_token_id=2,
        max_output_len=10,
    )
    seq3 = Sequence(
        request_id=2,
        prompt="",  # Dummy for testing usage
        input_token_id=list(range(seq_lens[2])),
        block_size=4,
        sample_params=None,
        eos_token_id=2,
        pad_token_id=2,
        max_output_len=10,
    )

    block_size = test_config["block_size"]
    max_batch_size = test_config["max_batch_size"]
    max_length = test_config["max_input_len"] + test_config["max_output_len"]
    assert max_batch_size >= 2, "max_batch_size should be greater than 1"

    bb = BatchBucket(
        num_heads, cache_manager.get_head_size(), max_batch_size, max_length, block_size, kv_max_split_num=2
    )
    bb_copy = BatchBucket(
        num_heads, cache_manager.get_head_size(), max_batch_size, max_length, block_size, kv_max_split_num=2
    )
    block_tables = bb.add_seqs([seq1, seq2])
    assert block_tables.shape == (2, cache_manager.max_blocks_per_sequence)
    assert torch.all(block_tables < 0), "Initialized block_tables should be negative values"

    cache_manager.allocate_context_from_block_tables(block_tables, bb.seq_lengths[: bb.current_batch_size])
    bb_copy.add_seqs(
        [seq1, seq2], alloc_block_tables_fn=cache_manager_copy.allocate_context_from_block_tables
    )  # This is just for testing usage. Don't add the same sequence to different buckets.

    assert bb.seq_lengths.tolist() == [seq1.sentence_len, seq2.sentence_len] + [0] * (
        max_batch_size - bb.current_batch_size
    )
    assert torch.equal(bb.block_tables, bb_copy.block_tables)

    bb.append_batch_tokens(torch.tensor([99, 99]))
    assert bb.seq_lengths.tolist() == [seq1.sentence_len, seq2.sentence_len] + [0] * (
        max_batch_size - bb.current_batch_size
    )

    cache_manager.allocate_tokens_from_block_tables(bb.block_tables, bb.seq_lengths, bsz=bb.current_batch_size)
    assert bb.seq_lengths.tolist() == [seq1.sentence_len, seq2.sentence_len] + [0] * (
        max_batch_size - bb.current_batch_size
    )

    bb.append_batch_tokens(torch.tensor([99, 99]))

    cache_manager.allocate_tokens_from_block_tables(bb.block_tables, bb.seq_lengths, bsz=bb.current_batch_size)
    assert bb.seq_lengths.tolist() == [seq1.sentence_len, seq2.sentence_len] + [0] * (
        max_batch_size - bb.current_batch_size
    )

    bb.pop_seq_update_batch(0, free_block_table_fn=cache_manager.free_block_table)
    assert bb.seq_lengths.tolist() == [bb.seqs_li[0].sentence_len] + [0] * (max_batch_size - bb.current_batch_size)
    assert bb.is_compact

    bb2 = BatchBucket(
        num_heads, cache_manager.get_head_size(), max_batch_size, max_length, block_size, kv_max_split_num=2
    )
    block_tables = bb2.add_seqs([seq3])
    cache_manager.allocate_context_from_block_tables(block_tables, bb2.seq_lengths[: bb2.current_batch_size])
    unmerged_ids = bb.merge(bb2)
    assert not unmerged_ids
    assert bb.is_compact
    assert bb2.is_compact
    assert bb.current_batch_size == 2
    assert bb2.current_batch_size == 0

    bb.clear(cache_manager.free_block_tables)
    assert bb.current_batch_size == 0
    assert bb.is_compact
    assert bb.seq_lengths.tolist() == [0] * max_batch_size
    assert torch.all(bb.block_tables < 0)


if __name__ == "__main__":
    test_bucket()
@ -15,7 +15,6 @@ def check_config_and_inference():
        input_token_id=[1, 2, 3],
        block_size=16,
        sample_params=None,
        block_table=None,
        eos_token_id=2,
        pad_token_id=2,
        max_output_len=256,
@ -27,7 +26,6 @@ def check_config_and_inference():
        input_token_id=[4, 5, 6],
        block_size=16,
        sample_params=None,
        block_table=None,
        eos_token_id=2,
        pad_token_id=2,
        max_output_len=256,
@ -39,7 +37,6 @@ def check_config_and_inference():
        input_token_id=[7, 8, 9],
        block_size=16,
        sample_params=None,
        block_table=None,
        eos_token_id=2,
        pad_token_id=2,
        max_output_len=256,
@ -148,6 +148,20 @@ def check_cache_manager(test_config):
    cache_manager.clear_all()
    assert cache_manager.num_available_blocks == num_blocks

    for cache_block in cache_manager._cache_blocks:
        assert cache_block.available_space == block_size

    # Mock batch operations (Prefill/Decoding updates)
    context_lengths = torch.tensor([max_input_length, max_input_length - 1])
    block_tables = torch.tensor(
        [[-1 for _ in range(cache_manager.max_blocks_per_sequence)] for _ in range(2)], dtype=torch.int32
    )
    cache_manager.allocate_context_from_block_tables(block_tables, context_lengths)
    cache_manager.allocate_tokens_from_block_tables(block_tables, context_lengths)
    cache_manager.free_block_tables(block_tables)
    for cache_block in cache_manager._cache_blocks:
        assert cache_block.available_space == block_size


def run_dist(rank, world_size, port):
    colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
@ -1,5 +1,4 @@
import pytest
import torch
from transformers.models.llama import LlamaConfig

import colossalai
@ -22,17 +21,35 @@ def check_running_list():
        eos_token_id=0,
        pad_token_id=0,
        sample_params=None,
        block_table=1,
    )

    seq2 = Sequence(
        request_id=2,
        prompt="abc",
        input_token_id=[1, 2, 3],
        block_size=16,
        eos_token_id=0,
        pad_token_id=0,
        sample_params=None,
    )
    running_list.append(seq1)
    running_list.append(seq2)
    assert running_list.ready_for_prefill()
    assert running_list.decoding == [] and running_list.prefill[0] == seq1
    assert len(running_list.decoding) == 0
    assert len(running_list.prefill) > 0 and running_list.prefill[0] == seq1

    seq = running_list.find_seq(seq1.request_id)
    assert seq == seq1

    running_list.mark_prefill_running()
    for seq in running_list.prefill:
        assert seq.status == RequestStatus.RUNNING

    running_list.move_prefill_to_decoding([seq1.request_id, seq2.request_id])
    assert len(running_list.prefill) == 0
    assert len(running_list.decoding) > 0 and running_list.decoding[0] == seq1

    running_list.remove(seq1)
    running_list.remove(seq2)
    assert running_list.is_empty()
@ -59,7 +76,6 @@ def check_request_handler():
        eos_token_id=0,
        pad_token_id=0,
        sample_params=None,
        block_table=torch.tensor([-1, -1]),
    )
    request_handler.add_sequence(seq1)
    # the priority should be 1