You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
ColossalAI/colossalai/inference/batch_bucket.py

561 lines
25 KiB

from typing import Callable, List, Optional, Tuple, Union
import torch
from colossalai.inference.struct import Sequence
from colossalai.utils import get_current_device
class BatchBucket:
    """Container for a batch of Sequences, used to manage their lengths and block tables.

    Attrs:
        _sequences_dict (Dict[int, Sequence]): Map sequence uid to sequence struct
            seq_uid -> Sequence
        _sequences_indexes (Dict[int, int]): Map sequence uid to index in the batch
            seq_uid -> index in the batch (indexing used in sequence_lengths and block_tables)
        _sequence_lengths (torch.Tensor): Length of each sequence in the batch.
            The size of the tensor is (max_batch_size,); unused slots hold 0.
        _block_tables (torch.Tensor): Block table of each sequence in the batch
            The size of the tensor is (max_batch_size, max_blocks_per_seq); unused entries hold -1.
    """

    def __init__(
        self,
        num_heads,
        head_dim,
        max_batch_size,
        max_length,
        block_size,
        kv_max_split_num,
        fd_interm_tensor=None,
        device=None,
        dtype=torch.float16,
        enable_streamingllm: bool = False,
        start_token_size: int = 4,
        generated_token_size: int = 512,
    ):
        """Create an empty bucket able to hold up to `max_batch_size` sequences.

        Args:
            num_heads: Stored as-is on the bucket.
            head_dim: Stored as-is on the bucket.
            max_batch_size: Number of rows in the sequence-length and block-table tensors.
            max_length: Maximum sequence length (input + output), used to size the block tables.
            block_size: Number of tokens per cache block.
            kv_max_split_num: Hint used for flash decoding.
            fd_interm_tensor: Optional object exposed through `fd_inter_tensor`.
            device: Target device of the tensors returned by the `get_*` methods;
                falls back to `get_current_device()` when falsy.
            dtype: Stored as-is on the bucket.
            enable_streamingllm: When True, block tables are sized from
                `start_token_size + generated_token_size` instead of `max_length`.
            start_token_size: Number of start tokens kept when StreamingLLM is enabled.
            generated_token_size: Number of generated tokens kept when StreamingLLM is enabled.
        """
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.max_batch_size = max_batch_size
        self.max_length = max_length  # in + out len
        self.block_size = block_size
        self.kv_max_split_num = kv_max_split_num  # Hint used for flash decoding
        self.fd_interm_tensor = fd_interm_tensor
        self.device = device or get_current_device()
        self.dtype = dtype

        self._use_spec_dec = False
        self._num_tokens_to_verify = None

        self.enable_streamingllm = enable_streamingllm
        self.start_token_size = start_token_size
        self.generated_token_size = generated_token_size

        self._current_batch_size = 0
        self._sequences_dict = dict()
        self._sequences_indexes = dict()

        # NOTE the bookkeeping tensors are created without an explicit device;
        # they are moved to `self.device` only by the `get_*` accessors.
        self._sequence_lengths = torch.zeros((self.max_batch_size,), dtype=torch.int32)
        self._sequence_lengths_helper = torch.zeros_like(self._sequence_lengths)

        if enable_streamingllm:
            # One extra block on top of the blocks needed for the kept tokens.
            max_blocks_per_seq = (start_token_size + generated_token_size + block_size - 1) // block_size + 1
        else:
            max_blocks_per_seq = (self.max_length + block_size - 1) // block_size
        self._block_tables = torch.full((self.max_batch_size, max_blocks_per_seq), -1, dtype=torch.int32)
        self._block_tables_helper = torch.full_like(self._block_tables, -1)

    @property
    def is_empty(self):
        """Whether the bucket currently holds no sequence."""
        return self._current_batch_size == 0

    @property
    def current_batch_size(self):
        """Number of sequences currently tracked by the bucket."""
        return self._current_batch_size

    def __len__(self):
        return self._current_batch_size

    @property
    def available_batch_size(self):
        """Number of free slots left in the bucket."""
        return self.max_batch_size - self._current_batch_size

    @property
    def block_tables(self):
        """The full block-table tensor (all `max_batch_size` rows)."""
        return self._block_tables

    @property
    def seq_lengths(self):
        """The full sequence-length tensor (all `max_batch_size` entries)."""
        return self._sequence_lengths

    @property
    def seqs_ids(self):
        """Uids of the sequences in the bucket, in insertion order of the sequences dict."""
        return list(self._sequences_dict.keys())

    @property
    def seqs_li(self):
        """Sequences in the bucket, in insertion order of the sequences dict."""
        return list(self._sequences_dict.values())

    @property
    def is_compact(self):
        """Whether the number of sequences matches the number of occupied tensor rows.

        A row counts as occupied when its sequence length is non-zero, respectively
        when the first entry of its block table is non-negative.
        """
        assert len(self._sequences_dict) == len(self._sequences_indexes), "BatchBucket indexing is not consistent"
        return (
            len(self._sequences_dict)
            == torch.nonzero(self._sequence_lengths).view(-1).numel()
            == torch.nonzero(self._block_tables[:, 0] >= 0).numel()
        )

    @property
    def use_spec_dec(self) -> bool:
        """Whether speculative decoding is enabled for this bucket."""
        return self._use_spec_dec

    @property
    def num_tokens_to_verify(self) -> Optional[int]:
        """Number of tokens to verify in speculative decoding, None when it is disabled."""
        return self._num_tokens_to_verify

    @property
    def batch_token_ids(self) -> List[List[int]]:
        """Input tokens followed by output tokens, one list per sequence."""
        return [seq.input_token_id + seq.output_token_id for seq in self.seqs_li]

    def streamingllm_update_batch(self, start_token_size: int, generated_token_size: int):
        """
        Update sequence_lengths and block_tables when it is necessary to swap out a block.

        For every sequence whose length equals
        `start_token_size + generated_token_size + block_size - 1`, the block at position 1
        of its block table is removed, the remaining blocks shift left by one, a -1 is
        appended, and its length is reset to `start_token_size + generated_token_size - 1`.

        Args:
            start_token_size (int): Number of start tokens that are always kept.
            generated_token_size (int): Number of generated tokens that are kept.
        Returns:
            List[int]: Ids of the blocks removed from the block tables, in batch order.
        """
        updated_block_ids = []
        if self._current_batch_size == 0:
            return updated_block_ids

        trigger_length = start_token_size + generated_token_size + self.block_size - 1
        kept_length = start_token_size + generated_token_size - 1
        lengths = self._sequence_lengths[: self._current_batch_size].tolist()
        for batch_id, length in enumerate(lengths):
            # We assume that the start token occupies the entire first block.
            if length != trigger_length:
                continue
            # Rows are edited in place so that both tensors keep their
            # (max_batch_size, ...) shape; rebuilding them from the occupied rows only
            # would shrink the block tables and break later insertions/compaction.
            row = self._block_tables[batch_id]
            updated_block_ids.append(int(row[1].item()))
            row[1:-1] = row[2:].clone()  # clone: source and destination overlap
            row[-1] = -1
            self._sequence_lengths[batch_id] = kept_length
        return updated_block_ids

    def set_use_spec_dec(self, num_tokens_to_verify: int = 5) -> None:
        """Set batch bucket to use speculative decoding.

        This will notify the modeling code to adjust the lengths of inputs,
        and let the main model verify tokens in parallel.
        """
        self._use_spec_dec = True
        self._num_tokens_to_verify = num_tokens_to_verify

    def reset_use_spec_dec(self) -> None:
        """Reset the usage of speculative decoding for the batch bucket"""
        self._use_spec_dec = False
        self._num_tokens_to_verify = None

    def _make_compact(self) -> None:
        """Compress the remaining sequences to the front of the tensors.

        Rows of sequences still present in the sequences dict are moved to the first
        rows (in dict order), every other row is reset (length 0, block ids -1), and
        the batch indexes and the current batch size are rebuilt accordingly.
        """
        # NOTE Prevent calling this method multiple times in a single step
        if self.is_compact:
            return
        valid_seq_ids = list(self._sequences_dict.keys())
        valid_num = len(valid_seq_ids)
        valid_indexes = [self._sequences_indexes[seq_id] for seq_id in valid_seq_ids]
        assert valid_num == len(self._sequences_indexes), "BatchBucket indexing is not consistent"

        # Gather valid rows into the (clean) helper tensors, then copy everything back.
        self._sequence_lengths_helper[:valid_num] = self._sequence_lengths[valid_indexes]
        self._sequence_lengths[:] = self._sequence_lengths_helper[:]
        self._block_tables_helper[:valid_num, :] = self.block_tables[valid_indexes]
        self.block_tables[:] = self._block_tables_helper[:]

        for new_idx, seq_id in enumerate(valid_seq_ids):
            self._sequences_indexes[seq_id] = new_idx

        # Leave the helpers clean for the next compaction.
        self._sequence_lengths_helper.fill_(0)
        self._block_tables_helper.fill_(-1)
        self._current_batch_size = valid_num

    def add_seq(
        self,
        seq: "Sequence",
        alloc_block_table: torch.Tensor = None,
        alloc_block_table_fn: Callable[[torch.Tensor, int], None] = None,
    ) -> Union[torch.Tensor, None]:
        """Add a single sequence to the batch.

        User could opt to provide either a block table or a function to allocate block tables.
        If both are given, `alloc_block_table` takes precedence.

        Args:
            seq (Sequence): The sequence to be added to the batch
            alloc_block_table (torch.Tensor): The block tables to be copied and used for the sequence
            alloc_block_table_fn (Callable[[torch.Tensor, int], None]): The function to allocate blocks for the sequence,
                which is expected to reserve blocks and update status of kv-cache manager.
                It is called with the block-table row of the sequence and the sequence length.
        Returns:
            block_table (torch.Tensor): The block table of the added sequence, used for block allocation in kv-cache manager.
                None if the sequence cannot be added.
        """
        # TODO might consider sorting by length
        if self._current_batch_size >= self.max_batch_size:
            return None

        # The new sequence occupies the first free row. The same index must be used
        # for the length and the block table so the two tensors stay aligned.
        batch_idx = self._current_batch_size
        self._sequences_dict[seq.request_id] = seq
        self._sequences_indexes[seq.request_id] = batch_idx
        self._sequence_lengths[batch_idx] = seq.sentence_len
        # NOTE the added seq still require block table allocation by kvcache manager
        block_table = self._block_tables[batch_idx]
        if alloc_block_table is not None:
            # copy block ids from provided block tables
            self._block_tables[batch_idx] = alloc_block_table
        elif alloc_block_table_fn:
            alloc_block_table_fn(block_table, self._sequence_lengths[batch_idx].item())
        self._current_batch_size += 1
        return block_table

    def add_seqs(
        self,
        seqs: List["Sequence"],
        alloc_block_tables: torch.Tensor = None,
        alloc_block_tables_fn: Callable[[torch.Tensor, torch.Tensor], None] = None,
    ) -> Union[torch.Tensor, None]:
        """Add a list of sequences to the batch.

        User could opt to provide either block tables or a function to allocate block tables.
        At most `available_batch_size` sequences are added; the added sequences are
        removed from `seqs` in place, so the list holds the leftovers afterwards.

        Args:
            seqs (List[Sequence]): The sequences to be added to the batch
            alloc_block_tables (torch.Tensor): The block tables to be copied and used for the sequence
            alloc_block_tables_fn (Callable[[torch.Tensor, torch.Tensor], None]): The function to allocate blocks for multiple sequences,
                which is expected to reserve blocks and update status of kv-cache manager.
                It is called with the block-table rows and the lengths of the added sequences.
        Returns:
            block_tables (torch.Tensor): The block tables of the added sequences, used for block allocation in kv-cache manager.
                None if the sequences cannot be added.
        """
        assert (
            alloc_block_tables is None or alloc_block_tables_fn is None
        ), "`alloc_block_tables` and `alloc_block_tables_fn` cannot be provided at the same time"

        num_seqs_to_add = min(self.max_batch_size - self._current_batch_size, len(seqs))
        if num_seqs_to_add <= 0:
            return None

        start = self._current_batch_size
        end = start + num_seqs_to_add
        added = seqs[:num_seqs_to_add]
        for offset, seq in enumerate(added):
            self._sequences_dict[seq.request_id] = seq
            self._sequences_indexes[seq.request_id] = start + offset
        # TODO external (rename): modify Sequence.sentence_len to seq_len
        self._sequence_lengths[start:end] = torch.tensor([seq.sentence_len for seq in added], dtype=torch.int32)
        # NOTE block tables to be updated by kvcache manager
        block_tables = self._block_tables[start:end]
        if alloc_block_tables is not None:
            # copy block ids from provided block tables
            self._block_tables[start:end] = alloc_block_tables
        elif alloc_block_tables_fn:
            alloc_block_tables_fn(block_tables, self._sequence_lengths[start:end])
        self._current_batch_size = end
        # Leave only the sequences that did not fit in the caller's list.
        seqs[:] = seqs[num_seqs_to_add:]
        return block_tables

    def pop_seq_update_batch(
        self, request_id: int, free_block_table_fn: Callable[[torch.Tensor], None] = None
    ) -> Tuple[Optional["Sequence"], Union[torch.Tensor, None]]:
        """Pop a single sequence by id from the batch, and update the batch bucket status.

        The freed row is refilled with the data of the last row of the batch, so the
        occupied rows stay contiguous at the front of the tensors.

        Args:
            request_id (int): The uid of the sequence
            free_block_table_fn (Callable): The function to free the block table of a sequence,
                if not provided, then we have to release the block table manually after calling this method
        Returns:
            A tuple of: seq (Sequence): The target sequence, None if it is not in the batch,
            and block_table (torch.Tensor): block table of the target sequence indicating corresponding blocks,
                none if the sequence is not found or free_block_table_fn is provided.
        """
        seq = self._sequences_dict.get(request_id)
        if seq is None:
            return None, None
        assert request_id in self._sequences_indexes, "Inconsistency in BatchBucket indexing"

        self._sequences_dict.pop(request_id)
        # Use the recorded row of the sequence rather than assuming a fixed position.
        seq_b_idx = self._sequences_indexes.pop(request_id)
        last_seq_b_idx = self._current_batch_size - 1

        # free the block table of the seq, or return a copy of the block table (to be processed outside)
        block_table = None
        if free_block_table_fn:
            free_block_table_fn(self._block_tables[seq_b_idx])
        else:
            block_table = self._block_tables[seq_b_idx].detach().clone()

        if seq_b_idx != last_seq_b_idx:
            # move the last seq of the batch into the slot of the popped seq
            last_seq_id = next(
                (uid for uid, index in self._sequences_indexes.items() if index == last_seq_b_idx),
                None,
            )
            assert last_seq_id is not None
            self._sequences_indexes[last_seq_id] = seq_b_idx
            self._sequence_lengths[seq_b_idx] = self._sequence_lengths[last_seq_b_idx]
            self._block_tables[seq_b_idx] = self._block_tables[last_seq_b_idx]

        # the last row is now unused
        self._sequence_lengths[last_seq_b_idx] = 0
        self._block_tables[last_seq_b_idx].fill_(-1)
        self._current_batch_size -= 1
        return seq, block_table

    def pop_seqs(
        self, request_ids: List[int], free_block_table_fn: Callable[[torch.Tensor], None] = None
    ) -> Tuple[List["Sequence"], List[torch.Tensor]]:
        """Iteratively pop a list of sequences by uid.

        Uids that are not in the batch are skipped.

        Args:
            request_ids (List[int]): The uids of the sequences
            free_block_table_fn (Callable): The function to free the block table of a sequence,
                if not provided, then we have to release the block table manually after calling this method
        Returns:
            A tuple of: seqs (List[Sequence]): The target sequences
            and block_tables (List[torch.Tensor]): block tables of the target sequences indicating corresponding blocks
        """
        seqs = []
        block_tables = []
        for request_id in request_ids:
            seq, block_table = self.pop_seq_update_batch(request_id, free_block_table_fn)
            if seq is not None:
                seqs.append(seq)
            if block_table is not None:
                block_tables.append(block_table)
        return seqs, block_tables

    def pop_n_seqs(
        self, n: int, free_block_table_fn: Callable[[torch.Tensor], None] = None
    ) -> Tuple[List["Sequence"], List[torch.Tensor]]:
        """Pop the first n sequences in the batch (FIFO).

        If n is greater than the current batch size, pop all the sequences in the batch.
        The batch is compacted afterwards.

        Args:
            n (int): The number of sequences to pop out
            free_block_table_fn (Callable): The function to free the block table of a single sequence
        Returns:
            A tuple of: seqs (List[Sequence]): The target sequences,
            and block_tables (List[torch.Tensor]): block tables of the target sequences indicating corresponding blocks
        """
        # NOTE Prevent calling this method multiple times in a single step
        seqs = []
        block_tables = []
        n = min(n, self.current_batch_size)
        seq_ids = list(self._sequences_dict.keys())[:n]
        for seq_id in seq_ids:
            seq = self._sequences_dict.pop(seq_id)
            seq_b_idx = self._sequences_indexes.pop(seq_id)
            if free_block_table_fn:
                free_block_table_fn(self.block_tables[seq_b_idx])
            else:
                block_tables.append(self.block_tables[seq_b_idx].detach().clone())
            seqs.append(seq)
        # The tensors and the current batch size are only refreshed by the compaction.
        if not self.is_compact:
            self._make_compact()
        return seqs, block_tables

    def pop_finished(
        self, free_block_table_fn: Callable[[torch.Tensor], None] = None
    ) -> Tuple[List["Sequence"], List[torch.Tensor]]:
        """Pop finished sequences in the batch and a list of block tables of the finished sequences,
        if free_block_table_fn is not provided.

        Args:
            free_block_table_fn (Callable): The function to free the block table of a single sequence
        Returns:
            A tuple of: finished_seqs (List[Sequence]): The finished sequences,
            and finished_block_tables (List[torch.Tensor]): block tables of the finished sequences.
        """
        # Collect first: the sequences dict must not change while it is iterated.
        finished_seqs = [seq for seq in self._sequences_dict.values() if seq.check_finish()]
        finished_block_tables = []
        # Use `pop_seq_update_batch` to update the batch status for just a few of finished seqs,
        # otherwise, pop seqs directly and then call `_make_compact` to compress the batch.
        # For now, the performance difference is not significant, so we use the first method to pop seqs.
        # Precise evaluations to be done.
        for seq in finished_seqs:
            _, block_table = self.pop_seq_update_batch(seq.request_id, free_block_table_fn)
            if block_table is not None:
                finished_block_tables.append(block_table)
        return finished_seqs, finished_block_tables

    # TODO arg type not support beam search sampling yet
    def append_batch_tokens(self, tokens: torch.Tensor) -> None:
        """Append a batch of tokens to the sequences in the batch

        Row i of `tokens` is appended to the output tokens of the sequence at batch
        index i. The length counter of every sequence is increased by exactly one.
        """
        assert self.current_batch_size == tokens.size(0), "Batch size mismatch"
        if self.current_batch_size > 0:
            tokens = tokens.tolist()
            for seq_id, seq in self._sequences_dict.items():
                index_in_b = self._sequences_indexes[seq_id]
                curr_tokens = tokens[index_in_b]
                # a 1D `tokens` tensor yields plain ints
                if not isinstance(curr_tokens, list):
                    curr_tokens = [curr_tokens]
                seq.output_token_id += curr_tokens
                seq.check_finish()
            self._sequence_lengths[: self.current_batch_size] += 1

    def revoke_batch_tokens(self, n_tokens: int, n_seqs: int = 1) -> None:
        """Revoke the last n output tokens of the sequences in the batch

        Args:
            n_tokens (int): The number of output tokens to revoke from each sequence.
                It does not count in the context tokens (input tokens).
            n_seqs (int): The first n sequences to revoke tokens from. Defaults to 1.
                For now, speculative decoding only supports batch size 1.
        """
        if n_tokens >= 1:
            seqs_iter = iter(self._sequences_dict.items())
            for _ in range(n_seqs):
                seq_id, seq = next(seqs_iter)
                assert seq.output_len >= n_tokens, "Revoking len exceeds the current output len of the sequence"
                seq.output_token_id = seq.output_token_id[:-n_tokens]
                seq.revoke_finished_status()
                self._sequence_lengths[self._sequences_indexes[seq_id]] -= n_tokens

    def clear(
        self, free_block_tables_fn: Optional[Callable[[torch.Tensor, int], None]] = None
    ) -> List["Sequence"]:
        """Clear all the sequences in the batch.

        Args:
            free_block_tables_fn (Optional[Callable]): The function to free the block tables of all the sequences in a batch.
                It is called with the full block-table tensor and the current batch size.
        Returns:
            List[Sequence]: The sequences that were held by the batch.
        """
        seqs = list(self._sequences_dict.values())
        self._sequences_dict.clear()
        self._sequences_indexes.clear()
        if free_block_tables_fn:
            free_block_tables_fn(self.block_tables, self._current_batch_size)
        self._block_tables.fill_(-1)
        self._sequence_lengths.fill_(0)
        self._current_batch_size = 0
        return seqs

    def merge(self, other: "BatchBucket") -> List[int]:
        """Merge the sequences in the other batch into the current batch.

        Merge as many sequences as the current batch can hold, if it does not have
        enough available space for all the sequences in the other batch.

        Usage:
            > New incoming sequence added to prefill batch
                prefill bb curr batch size < prefill_ratio * prefill bb max batch size
            > New incoming sequence added to prefill batch
                prefill bb curr batch size == prefill_ratio * prefill bb max batch size
            > Pause Decoding
            > Prefill
            > Move sequences in prefill bb => decoding bb
            > Put back the out-of-volume sequences into the running pool

        Returns:
            unmerged_ids (List[int]): a list of sequence uids that are not merged into the current batch
        """
        num_seqs_to_merge = min(self.available_batch_size, other.current_batch_size)
        if num_seqs_to_merge > 0:
            # No free function is passed, so `other` hands over copies of the block tables.
            seqs, block_tables_li = other.pop_n_seqs(num_seqs_to_merge)
            block_tables = torch.stack(block_tables_li)
            self.add_seqs(seqs, alloc_block_tables=block_tables)
        return other.seqs_ids

    ########## The following methods are expected to be used in modeling ###########

    # For compatibility.
    # NOTE: This is an assumption way to determine the stage of the batch.
    @property
    def is_prompts(self) -> bool:
        """Whether the batch is in the prefill stage, judged from its first sequence only."""
        assert len(self._sequences_dict) > 0, "No sequence in the batch"
        first_seq = next(iter(self._sequences_dict.values()))
        return first_seq.output_len == 0

    def get_1D_inputs_spec_dec(self, n: int) -> torch.Tensor:
        """Return the last `n` output tokens of every sequence as a flat tensor.

        Used for main model verification in **Decoding Stage**: `n` is the number of
        tokens to be verified. Sequences are concatenated in batch-index order.
        """
        assert len(self._sequences_dict) > 0, "No sequence in the batch"
        assert all(
            seq.output_len >= n for seq in self._sequences_dict.values()
        ), "Sequence output tokens must be greater than or equal to the number of tokens to be verified."
        out_li = []
        seq_ids = sorted(self._sequences_indexes.keys(), key=lambda x: self._sequences_indexes[x])
        for seq_id in seq_ids:
            out_li.extend(self._sequences_dict[seq_id].output_token_id[-n:])
        return torch.tensor(out_li, dtype=torch.long, device=self.device)

    # For compatibility
    def get_1D_inputs(self) -> torch.Tensor:
        """Return the model input token ids of the batch as a 1D tensor on `self.device`.

        Prefill stage: the input tokens of all sequences, concatenated in batch-index order.
        Decoding stage: the last output token of every sequence, or the tokens required
        for verification when speculative decoding is enabled.
        """
        assert len(self._sequences_dict) > 0, "No sequence in the batch"
        first_seq = next(iter(self._sequences_dict.values()))  # not exactly the first sequence
        if first_seq.output_len == 0:
            # Assume prefill stage
            assert all(
                seq.output_len == 0 for seq in self._sequences_dict.values()
            ), "Sequence stage (Prefill/Decoding) must be the same in the batch"
            out_li = []
            seq_ids = sorted(self._sequences_indexes.keys(), key=lambda x: self._sequences_indexes[x])
            for seq_id in seq_ids:
                out_li.extend(self._sequences_dict[seq_id].input_token_id)
            return torch.tensor(out_li, dtype=torch.long, device=self.device)

        # Assume decoding stage
        if self.use_spec_dec:
            # For Speculative Decoding
            # the number of tokens to be verified in parallel plus the correct token in the last step
            return self.get_1D_inputs_spec_dec(self.num_tokens_to_verify + 1)
        assert all(
            seq.output_len > 0 for seq in self._sequences_dict.values()
        ), "Sequence stage (Prefill/Decoding) must be the same in the batch"
        assert self.is_compact, "BatchBucket is not compact"
        out = torch.empty([self.current_batch_size], dtype=torch.long)
        for seq_id, index_in_b in self._sequences_indexes.items():
            out[index_in_b] = self._sequences_dict[seq_id].output_token_id[-1]
        return out.to(device=self.device)

    # For compatibility
    def get_block_table_tensor(self) -> torch.Tensor:
        """Return the block tables of the sequences in the batch, on `self.device`."""
        assert self.is_compact  # Debug usage
        block_table = self.block_tables[: self.current_batch_size]
        return block_table.to(device=self.device)

    # For compatibility
    def get_sequence_lengths(self) -> torch.Tensor:
        """Return the lengths of the sequences in the batch, on `self.device`."""
        assert self.is_compact  # Debug usage
        sequence_lengths = self.seq_lengths[: self.current_batch_size]
        return sequence_lengths.to(device=self.device)

    # For compatibility
    @property
    def fd_inter_tensor(self):
        """The `fd_interm_tensor` given at construction; asserts that it was provided."""
        assert self.fd_interm_tensor is not None, "fd_interm_tensor is not provided"
        return self.fd_interm_tensor

    def __repr__(self) -> str:
        # `is_prompts` asserts a non-empty batch; repr must never raise, so an empty
        # bucket reports `is_prompts=None` instead.
        is_prompts = self.is_prompts if self._sequences_dict else None
        return f"(sequences_dict={self._sequences_dict}, is_prompts={is_prompts})"