
[Inference] Optimize and Refactor Inference Batching/Scheduling (#5367)

* add kvcache manager funcs for batching

* add batch bucket for batching

* revise RunningList struct in handler

* add kvcache/batch funcs for compatibility

* use new batching methods

* fix indexing bugs

* revise abort logic

* use cpu seq lengths/block tables

* rm unused attr in Sequence

* fix type conversion/default arg

* add and revise pytests

* revise pytests, rm unused tests

* rm unused statements

* fix pop finished indexing issue

* fix: use index in batch when retrieving inputs/update seqs

* use dict instead of odict in batch struct

* arg type hinting

* fix make compress

* refine comments

* fix: pop_n_seqs to pop the first n seqs

* add check in request handler

* remove redundant conversion

* fix test for request handler

* fix pop method in batch bucket

* fix prefill adding
Author: Yuanheng Zhao (committed by GitHub)
Commit: b21aac5bae
Files changed:
  1. colossalai/inference/batch_bucket.py (449)
  2. colossalai/inference/config.py (2)
  3. colossalai/inference/core/engine.py (10)
  4. colossalai/inference/core/request_handler.py (200)
  5. colossalai/inference/kv_cache/kvcache_manager.py (166)
  6. colossalai/inference/modeling/models/nopadding_llama.py (8)
  7. colossalai/inference/struct.py (2)
  8. tests/test_infer/test_batch_bucket.py (140)
  9. tests/test_infer/test_config_and_struct.py (3)
  10. tests/test_infer/test_kvcache_manager.py (14)
  11. tests/test_infer/test_request_handler.py (26)

449
colossalai/inference/batch_bucket.py

@ -0,0 +1,449 @@
from typing import Callable, List, Optional, Tuple, Union
import torch
from colossalai.inference.struct import Sequence
from colossalai.utils import get_current_device
class BatchBucket:
"""Container for a batch of Sequences, which is used to manage the batch of sequences.
Attrs:
_sequences_dict (Dict[int, Sequence]): Map sequence uid to sequence struct
seq_uid -> Sequence
_sequences_indexes (Dict[int, int]): Map sequence uid to index in the batch
seq_uid -> index in the batch (indexing used in sequence_lengths and block_tables)
_sequence_lengths (torch.Tensor): Length of each sequence in the batch.
The size of the tensor is (max_batch_size,)
_block_tables (torch.Tensor): Block table of each sequence in the batch
The size of the tensor is (max_batch_size, max_blocks_per_seq)
"""
def __init__(
self,
num_heads,
head_dim,
max_batch_size,
max_length,
block_size,
kv_max_split_num,
fd_interm_tensor=None,
device=None,
dtype=torch.float16,
):
self.num_heads = num_heads
self.head_dim = head_dim
self.max_batch_size = max_batch_size
self.max_length = max_length # in + out len
self.block_size = block_size
self.kv_max_split_num = kv_max_split_num # Hint used for flash decoding
self.fd_interm_tensor = fd_interm_tensor
self.device = device or get_current_device()
self.dtype = dtype
self._current_batch_size = 0
self._sequences_dict = dict()
self._sequences_indexes = dict() # deque(maxlen=self.max_batch_size)
self._sequence_lengths = torch.zeros((self.max_batch_size,), dtype=torch.int32)
self._sequence_lengths_helper = torch.zeros_like(self._sequence_lengths)
max_blocks_per_seq = (self.max_length + block_size - 1) // block_size
self._block_tables = torch.full((self.max_batch_size, max_blocks_per_seq), -1, dtype=torch.int32)
self._block_tables_helper = torch.full_like(self._block_tables, -1)
@property
def is_empty(self):
return self._current_batch_size == 0
@property
def current_batch_size(self):
return self._current_batch_size
@property
def available_batch_size(self):
return self.max_batch_size - self._current_batch_size
@property
def block_tables(self):
return self._block_tables
@property
def seq_lengths(self):
return self._sequence_lengths
@property
def seqs_ids(self):
return list(self._sequences_dict.keys())
@property
def seqs_li(self):
return list(self._sequences_dict.values())
@property
def is_compact(self):
assert len(self._sequences_dict) == len(self._sequences_indexes), "BatchBucket indexing is not consistent"
return (
len(self._sequences_dict)
== torch.nonzero(self._sequence_lengths).view(-1).numel()
== torch.nonzero(self._block_tables[:, 0] >= 0).numel()
)
def _make_compact(self) -> None:
# Clean and Compress the batch based on its sequences dict.
# Namely, compress sequences to the front and clean up the seq lengths and block tables tensors.
# NOTE Prevent calling this method multiple times in a single step
if self.is_compact:
return
valid_seq_ids = self._sequences_dict.keys()
valid_num = len(valid_seq_ids)
valid_indexes = [self._sequences_indexes[seq_id] for seq_id in valid_seq_ids]
assert valid_num == len(self._sequences_indexes), "BatchBucket indexing is not consistent"
self._sequence_lengths_helper[:valid_num] = self._sequence_lengths[valid_indexes]
self._sequence_lengths[:] = self._sequence_lengths_helper[:]
self._block_tables_helper[:valid_num, :] = self.block_tables[valid_indexes]
self.block_tables[:] = self._block_tables_helper[:]
new_idx = 0
for seq_id in valid_seq_ids:
self._sequences_indexes[seq_id] = new_idx
new_idx += 1
self._sequence_lengths_helper.fill_(0)
self._block_tables_helper.fill_(-1)
self._current_batch_size = valid_num
def add_seq(
self,
seq: Sequence,
alloc_block_table: torch.Tensor = None,
alloc_block_table_fn: Callable[[torch.Tensor, int], None] = None,
) -> Union[torch.Tensor, None]:
"""Add a single sequence to the batch.
User could opt to provide either a block table or a function to allocate block tables.
Args:
seq (Sequence): The sequence to be added to the batch
alloc_block_table (torch.Tensor): The block tables to be copied and used for the sequence
alloc_block_table_fn (Callable[[torch.Tensor, int], None]): The function to allocate blocks for the sequence,
which is expected to reserve blocks and update status of kv-cache manager.
Returns:
block_table (torch.Tensor): The block table of the added sequence, used for block allocation in kv-cache manager.
None if the sequence cannot be added.
"""
block_table = None
# TODO might consider sorting by length
if self._current_batch_size < self.max_batch_size:
self._sequences_dict[seq.request_id] = seq
self._sequences_indexes[seq.request_id] = self._current_batch_size
self._sequence_lengths[self._current_batch_size] = seq.sentence_len
# NOTE the added seq still requires block table allocation by the kvcache manager
block_table = self._block_tables[self._current_batch_size]
if alloc_block_table is not None:
# copy block ids from the provided block table
self._block_tables[self._current_batch_size] = alloc_block_table
elif alloc_block_table_fn:
alloc_block_table_fn(block_table, self._sequence_lengths[self._current_batch_size].item())
self._current_batch_size += 1
return block_table
def add_seqs(
self,
seqs: List[Sequence],
alloc_block_tables: torch.Tensor = None,
alloc_block_tables_fn: Callable[[torch.Tensor, torch.Tensor], None] = None,
) -> Union[torch.Tensor, None]:
"""Add a list of sequences to the batch.
User could opt to provide either block tables or a function to allocate block tables.
Args:
seqs (List[Sequence]): The sequences to be added to the batch
alloc_block_tables (torch.Tensor): The block tables to be copied and used for the sequence
alloc_block_table_fn (Callable[[torch.Tensor, torch.Tensor], None]): The function to allocate blocks for multiple sequences,
which is expected to reserve blocks and update status of kv-cache manager.
Returns:
block_tables (torch.Tensor): The block tables of the added sequences, used for block allocation in kv-cache manager.
None if the sequences cannot be added.
"""
assert (
alloc_block_tables is None or alloc_block_tables_fn is None
), "`alloc_block_tables` and `alloc_block_tables_fn` cannot be provided at the same time"
num_seqs_to_add = min(self.max_batch_size - self._current_batch_size, len(seqs))
block_tables = None
if num_seqs_to_add > 0:
for i, seq in enumerate(seqs[:num_seqs_to_add]):
self._sequences_dict[seq.request_id] = seq
self._sequences_indexes[seq.request_id] = self._current_batch_size + i
# TODO external (rename): modify Sequence.sentence_len to seq_len
self._sequence_lengths[
self._current_batch_size : self._current_batch_size + num_seqs_to_add
] = torch.tensor([seq.sentence_len for seq in seqs[:num_seqs_to_add]], dtype=torch.int32)
# NOTE block tables to be updated by kvcache manager
block_tables = self._block_tables[self._current_batch_size : self._current_batch_size + num_seqs_to_add]
if alloc_block_tables is not None:
# copy block ids from provided block tables
self._block_tables[
self._current_batch_size : self._current_batch_size + num_seqs_to_add
] = alloc_block_tables
elif alloc_block_tables_fn:
alloc_block_tables_fn(
block_tables,
self._sequence_lengths[self._current_batch_size : self._current_batch_size + num_seqs_to_add],
)
self._current_batch_size += num_seqs_to_add
seqs[:] = seqs[num_seqs_to_add:]
return block_tables
def pop_seq_update_batch(
self, request_id: int, free_block_table_fn: Callable[[torch.Tensor], None] = None
) -> Tuple[Sequence, Union[torch.Tensor, None]]:
"""Pop a single sequence by id from the batch, and update the batch bucket status.
Args:
request_id (int): The uid of the sequence
free_block_table_fn (Callable): The function to free the block table of a sequence,
if not provided, then we have to release the block table manually after calling this method
Returns:
A tuple of: seq (Sequence): The target sequence
and block_table (torch.Tensor): block table of the target sequence indicating corresponding blocks,
none if the sequence is not found or free_block_table_fn is provided.
"""
seq: Sequence = self._sequences_dict.get(request_id)
block_table = None
if seq is not None:
assert request_id in self._sequences_indexes, "Inconsistency in BatchBucket indexing"
self._sequences_dict.pop(request_id)
seq_b_idx = self._sequences_indexes.get(request_id)
if self.current_batch_size > 1:
# replace seq length of the target seq with that of the last seq in the batch
last_seq_b_idx = self.current_batch_size - 1
last_seq_id = next(
(uid for uid, index in self._sequences_indexes.items() if index == last_seq_b_idx),
None,
)
assert last_seq_id is not None
self._sequences_indexes[last_seq_id] = seq_b_idx
self._sequence_lengths[seq_b_idx] = self._sequence_lengths[last_seq_b_idx]
self._sequence_lengths[last_seq_b_idx].fill_(0)
# free the block table of the seq, or return a copy of the block table (to be processed outside)
if free_block_table_fn:
free_block_table_fn(self._block_tables[seq_b_idx])
else:
block_table = self._block_tables[seq_b_idx].detach().clone()
# replace block table of the target seq with that of the last seq in the batch
self._block_tables[seq_b_idx] = self._block_tables[last_seq_b_idx]
self._block_tables[last_seq_b_idx].fill_(-1)
else:
if free_block_table_fn:
free_block_table_fn(self._block_tables[0])
else:
block_table = self._block_tables[0].detach().clone()
self._sequence_lengths[0].fill_(0)
self._block_tables[0].fill_(-1)
self._sequences_indexes.pop(request_id)
self._current_batch_size -= 1
return seq, block_table
def pop_seqs(
self, request_ids: List[int], free_block_table_fn: Callable[[torch.Tensor], None] = None
) -> Tuple[List[Sequence], List[torch.Tensor]]:
"""Iteratively pop a list of sequences by uid.
Args:
request_ids (List[int]): The uids of the sequences
free_block_table_fn (Callable): The function to free the block table of a sequence,
if not provided, then we have to release the block table manually after calling this method
Returns:
A tuple of: seqs (List[Sequence]): The target sequences
and block_tables (List[torch.Tensor]): block tables of the target sequences indicating corresponding blocks
"""
seqs = []
block_tables = []
for request_id in request_ids:
seq, block_table = self.pop_seq_update_batch(request_id, free_block_table_fn)
if seq is not None:
seqs.append(seq)
if block_table is not None:
block_tables.append(block_table)
return seqs, block_tables
def pop_n_seqs(
self, n: int, free_block_table_fn: Callable[[torch.Tensor], None] = None
) -> Tuple[List[Sequence], List[torch.Tensor]]:
"""Pop the first n sequences in the batch (FIFO).
If n is greater than the current batch size, pop all the sequences in the batch.
Args:
n (int): The number of sequences to pop out
free_block_table_fn (Callable): The function to free the block table of a single sequence
Returns:
A tuple of: seqs (List[Sequence]): The target sequences,
and block_tables (List[torch.Tensor]): block tables of the target sequences indicating corresponding blocks
"""
# NOTE Prevent calling this method multiple times in a single step
seqs = []
block_tables = []
n = min(n, self.current_batch_size)
seq_ids = list(self._sequences_dict.keys())[:n]
for seq_id in seq_ids:
seq = self._sequences_dict.pop(seq_id)
seq_b_idx = self._sequences_indexes.pop(seq_id)
if free_block_table_fn:
free_block_table_fn(self.block_tables[seq_b_idx])
else:
block_tables.append(self.block_tables[seq_b_idx].detach().clone())
seqs.append(seq)
if not self.is_compact:
self._make_compact()
return seqs, block_tables
def pop_finished(
self, free_block_table_fn: Callable[[torch.Tensor], None] = None
) -> Tuple[List[Sequence], List[torch.Tensor]]:
"""Pop finished sequences in the batch and a list of block tables of the finished sequences,
if free_block_table_fn is not provided.
Args:
free_block_table_fn (Callable): The function to free the block table of a single sequence
Returns:
A tuple of: finished_seqs (List[Sequence]): The finished sequences,
and finished_block_tables (List[torch.Tensor]): block tables of the finished sequences.
"""
finished_seqs = []
finished_block_tables = []
for seq in self._sequences_dict.values():
if seq.check_finish():
finished_seqs.append(seq)
# Use `pop_seq_update_batch` to update the batch status when only a few seqs are finished;
# otherwise, pop seqs directly and then call `_make_compact` to compress the batch.
# For now, the performance difference is not significant, so we use the first method to pop seqs.
# Precise evaluations to be done.
for seq in finished_seqs:
_, block_table = self.pop_seq_update_batch(seq.request_id, free_block_table_fn)
if block_table is not None:
finished_block_tables.append(block_table)
return finished_seqs, finished_block_tables
# TODO arg type not support beam search sampling yet
def append_batch_tokens(self, tokens: torch.Tensor) -> None:
"""Append a batch of tokens to the sequences in the batch"""
assert self.current_batch_size == tokens.size(0), "Batch size mismatch"
if self.current_batch_size > 0:
tokens = tokens.tolist()
for seq_id, seq in self._sequences_dict.items():
index_in_b = self._sequences_indexes[seq_id]
curr_tokens = tokens[index_in_b]
if not isinstance(curr_tokens, list):
curr_tokens = [curr_tokens]
seq.output_token_id += curr_tokens
seq.check_finish()
self._sequence_lengths[: self.current_batch_size] += 1
def clear(self, free_block_tables_fn: Optional[Callable[[torch.Tensor], None]]) -> List[int]:
"""Clear all the sequences in the batch.
free_block_tables_fn (Optional[Callable]): The function to free the block tables of all the sequences in a batch
"""
seqs = list(self._sequences_dict.values())
self._sequences_dict.clear()
self._sequences_indexes.clear()
if free_block_tables_fn:
free_block_tables_fn(self.block_tables, self._current_batch_size)
self._block_tables.fill_(-1)
self._sequence_lengths.fill_(0)
self._current_batch_size = 0
return seqs
def merge(self, other: "BatchBucket") -> List[int]:
"""Merge the sequences in the other batch into the current batch.
Merge as many sequences from the other batch as the current batch can hold;
sequences that do not fit remain in the other batch and their uids are returned.
Usage:
> New incoming sequence added to prefill batch
prefill bb curr batch size < prefill_ratio * prefill bb max batch size
> New incoming sequence added to prefill batch
prefill bb curr batch size == prefill_ratio * prefill bb max batch size
> Pause Decoding
> Prefill
> Move sequences in prefill bb => decoding bb
> Put back the out-of-volume sequences into the running pool
Returns:
unmerged_ids (List[int]): a list of sequence uids that are not merged into the current batch
"""
unmerged_ids = []
num_seqs_to_merge = min(self.available_batch_size, other.current_batch_size)
if num_seqs_to_merge > 0:
seqs, block_tables_li = other.pop_n_seqs(num_seqs_to_merge)
block_tables = torch.stack(block_tables_li)
self.add_seqs(seqs, alloc_block_tables=block_tables)
unmerged_ids = other.seqs_ids
return unmerged_ids
########## The following methods are expected to be used in modeling ###########
# For compatibility.
# NOTE: This is a heuristic for determining the stage of the batch.
@property
def is_prompts(self) -> bool:
assert len(self._sequences_dict) > 0, "No sequence in the batch"
first_seq = next(iter(self._sequences_dict.values()))
if first_seq.output_len == 0:
return True
return False
# For compatibility
def get_1D_inputs(self) -> torch.Tensor:
assert len(self._sequences_dict) > 0, "No sequence in the batch"
first_seq = next(iter(self._sequences_dict.values())) # not exactly the first sequence
if first_seq.output_len == 0:
# Assume prefill stage
assert all(
seq.output_len == 0 for seq in self._sequences_dict.values()
), "Sequence stage (Prefill/Decoding) must be the same in the batch"
out_li = []
seq_ids = sorted(self._sequences_indexes.keys(), key=lambda x: self._sequences_indexes[x])
for seq_id in seq_ids:
seq: Sequence = self._sequences_dict[seq_id]
out_li.extend(seq.input_token_id)
return torch.tensor(out_li, dtype=torch.long, device=self.device)
else:
# Assume decoding stage
assert all(
seq.output_len > 0 for seq in self._sequences_dict.values()
), "Sequence stage (Prefill/Decoding) must be the same in the batch"
assert self.is_compact, "BatchBucket is not compact"
out = torch.empty([self.current_batch_size], dtype=torch.long)
for seq_id, index_in_b in self._sequences_indexes.items():
seq: Sequence = self._sequences_dict[seq_id]
out[index_in_b] = seq.output_token_id[-1]
return out.to(device=self.device)
# For compatibility
def get_block_table_tensor(self) -> torch.Tensor:
assert self.is_compact # Debug usage
block_table = self.block_tables[: self.current_batch_size]
return block_table.to(device=self.device)
# For compatibility
def get_sequence_lengths(self) -> torch.Tensor:
assert self.is_compact # Debug usage
sequence_lengths = self.seq_lengths[: self.current_batch_size]
return sequence_lengths.to(device=self.device)
# For compatibility
@property
def fd_inter_tensor(self) -> None:
assert self.fd_interm_tensor is not None, "fd_interm_tensor is not provided"
return self.fd_interm_tensor

2
colossalai/inference/config.py

@ -109,7 +109,7 @@ class InferenceConfig:
), f"Expected dtype to be in {_ALLOWED_DTYPES} but found an unknown dtype: {self.dtype}"
# check distributed
assert (
assert (not torch.distributed.is_initialized() and self.tp_size * self.pp_size == 1) or (
self.tp_size * self.pp_size == dist.get_world_size()
), f"TP size({self.tp_size}) * PP size({self.pp_size}) should be equal to the global world size ({dist.get_world_size()})"
# check prompt template

10
colossalai/inference/core/engine.py

@ -42,7 +42,7 @@ class InferenceEngine:
def __init__(
self,
model: nn.Module,
tokenizer: [Union[PreTrainedTokenizer, PreTrainedTokenizerFast]],
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
inference_config: InferenceConfig,
verbose: bool = False,
model_policy: Policy = None,
@ -254,20 +254,12 @@ class InferenceEngine:
else:
prompt = prompts[i]
max_blocks_per_sequence = (
self.inference_config.max_input_len
+ self.inference_config.max_output_len
+ self.inference_config.block_size
- 1
) // self.inference_config.block_size
block_table = torch.full([max_blocks_per_sequence], -1, device=self.device)
sequence = Sequence(
request_id,
prompt,
prompts_token_ids[i],
block_size,
None,
block_table,
self.tokenizer.eos_token_id,
self.tokenizer.pad_token_id,
self.inference_config.max_output_len,
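With this refactor the engine no longer pre-computes `max_blocks_per_sequence` or allocates a per-sequence block table up front; block tables are owned by the batch bucket and the kv-cache manager. A minimal sketch of constructing a request under the slimmed-down `Sequence` struct (prompt, token ids, and special-token ids are placeholders, mirroring the updated tests):

from colossalai.inference.struct import Sequence

# Note: no `block_table` argument anymore; block tables live in BatchBucket/KVCacheManager.
sequence = Sequence(
    request_id=0,
    prompt="Hello",
    input_token_id=[1, 2, 3],
    block_size=16,
    sample_params=None,
    eos_token_id=2,
    pad_token_id=2,
    max_output_len=256,
)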

200
colossalai/inference/core/request_handler.py

@ -1,15 +1,16 @@
from typing import List
from typing import Dict, List, Union
import torch
from transformers.configuration_utils import PretrainedConfig
from transformers.generation import GenerationConfig
from colossalai.inference.batch_bucket import BatchBucket
from colossalai.inference.config import InferenceConfig
from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.inference.kv_cache import KVCacheManager
from colossalai.inference.logit_processors import logit_processor
from colossalai.inference.sampler import *
from colossalai.inference.struct import BatchInfo, RequestStatus, Sequence
from colossalai.inference.struct import RequestStatus, Sequence
from colossalai.logging import get_dist_logger
__all__ = ["RunningList", "RequestHandler"]
@ -24,45 +25,79 @@ class RunningList:
Args:
prefill_ratio: (float) A ratio for determining whether to perform prefill or not.
prefill: (List) List that contains default inputs, defaults to [].
_prefill (Dict[int, Sequence]): Mapping of sequence uid -> Sequence.
_decoding (Dict[int, Sequence]): Mapping of sequence uid -> Sequence.
"""
def __init__(self, prefill_ratio: str, prefill: List[Sequence] = None):
def __init__(self, prefill_ratio: int, prefill: List[Sequence] = None) -> None:
self.prefill_ratio = prefill_ratio
self.decoding: List[Sequence] = []
self.prefill: List[Sequence] = prefill if prefill is not None else []
self._decoding: Dict[int, Sequence] = dict()
self._prefill: Dict[int, Sequence] = (
dict({seq.request_id: seq for seq in prefill}) if prefill is not None else dict()
)
def append(self, seq: Sequence):
# add seq to prefilling list first.
self.prefill.append(seq)
def find_seq(self, request_id):
for seq in self.decoding:
if request_id == seq.request_id:
return seq
for seq in self.prefill:
if request_id == seq.request_id:
return seq
return None
@property
def decoding(self):
return list(self._decoding.values())
@property
def prefill(self):
return list(self._prefill.values())
@property
def prefill_seq_num(self):
return len(self._prefill)
@property
def decoding_seq_num(self):
return len(self._decoding)
@property
def total_seq_num(self):
return self.prefill_seq_num + self.decoding_seq_num
def remove(self, seq: Sequence):
if seq in self.decoding:
self.decoding.remove(seq)
elif seq in self.prefill:
self.prefill.remove(seq)
def append(self, seq: Sequence):
assert (seq.request_id not in self._prefill) and (
seq.request_id not in self._decoding
), f"Sequence uid {seq.request_id} already exists."
self._prefill[seq.request_id] = seq
def extend(self, seqs: List[Sequence]):
for seq in seqs:
self._prefill[seq.request_id] = seq
def find_seq(self, request_id) -> Union[Sequence, None]:
seq = None
if request_id in self._decoding:
seq = self._decoding[request_id]
elif request_id in self._prefill:
seq = self._prefill[request_id]
return seq
def remove(self, seq: Sequence) -> None:
if seq.request_id in self._decoding:
self._decoding.pop(seq.request_id)
elif seq.request_id in self._prefill:
self._prefill.pop(seq.request_id)
else:
raise ValueError(f"sequence {seq.request_id} is not in running list")
raise ValueError(f"Sequence {seq.request_id} is not in running list")
def ready_for_prefill(self):
if not self.decoding:
return len(self.prefill) > 0
return len(self.prefill) / len(self.decoding) >= self.prefill_ratio
if not self._decoding:
return len(self._prefill) > 0
return len(self._prefill) / len(self._decoding) >= self.prefill_ratio
def is_empty(self):
return not self.decoding and not self.prefill
return not self._decoding and not self._prefill
def total_seq_num(self):
return len(self.decoding) + len(self.prefill)
def mark_prefill_running(self) -> None:
for seq_id in self._prefill:
self._prefill[seq_id].mark_running()
def move_prefill_to_decoding(self, seq_ids: List[int]) -> None:
for seq_id in seq_ids:
assert seq_id in self._prefill, f"Sequence {seq_id} is not in prefill list"
self._decoding[seq_id] = self._prefill.pop(seq_id)
class RequestHandler:
@ -110,25 +145,27 @@ class RequestHandler:
# TODO In the continuous batching scenario, the batch size may be greater than max_batch_size,
# which may cause bugs and this issue should be fixed later.
self.running_batch = BatchInfo(
max_batch_size=self.max_batch_size,
kv_max_split_num=kv_max_split_num,
self.running_bb = BatchBucket(
num_heads=model_config.num_attention_heads,
head_dim=head_dim,
is_prompts=False,
device=device,
dtype=self.dtype,
fd_inter_tensor=fd_inter_tensor,
)
self.prefill_batch = BatchInfo(
max_batch_size=self.max_batch_size,
max_length=inference_config.max_input_len + inference_config.max_output_len,
block_size=inference_config.block_size,
kv_max_split_num=kv_max_split_num,
fd_interm_tensor=fd_inter_tensor,
dtype=self.dtype,
device=device,
)
self.prefill_bb = BatchBucket(
num_heads=model_config.num_attention_heads,
head_dim=head_dim,
is_prompts=True,
device=device,
max_batch_size=self.max_batch_size,
max_length=inference_config.max_input_len + inference_config.max_output_len,
block_size=inference_config.block_size,
kv_max_split_num=kv_max_split_num,
fd_interm_tensor=fd_inter_tensor,
dtype=self.dtype,
fd_inter_tensor=fd_inter_tensor,
device=device,
)
def _init_cache(self, model_config):
@ -159,40 +196,39 @@ class RequestHandler:
remove_list.append(seq)
break
# stop feeding new sequences into the running list to ensure enough cache blocks remain available
if self.cache_manager.num_available_blocks <= self.running_list.total_seq_num:
break
num_seqs_to_add = min(len(lst), self.max_batch_size - self.running_list.total_seq_num)
remove_list.extend(lst[:num_seqs_to_add])
self.running_list.extend(lst[:num_seqs_to_add])
# Try to allocate cache blocks for the sequence.
if (
self.cache_manager.check_allocation(seq)
and (len(self.running_list.prefill) + len(self.running_list.decoding))
< self.max_batch_size # There some bugs in continous batching, so we disable it here.
):
# If succeed, add the sequence to running list.
remove_list.append(seq)
self.running_list.append(seq)
self.cache_manager.allocate_context_from_block_table(seq.block_table, seq.sentence_len)
for seq in remove_list:
lst.remove(seq)
if self.running_list.ready_for_prefill():
for seq in self.running_list.prefill:
seq.mark_running()
self.prefill_batch.add_seqs(self.running_list.prefill)
return self.prefill_batch
num_seqs_to_add = min(self.running_list.prefill_seq_num, self.running_bb.available_batch_size)
if not self.running_batch.is_empty:
for seq in self.running_batch.sequences_set:
recycle = self.cache_manager.allocate_token_from_block_table(seq.block_table, seq.sentence_len)
if recycle:
for seq in self.running_list.prefill[:num_seqs_to_add]:
seq.mark_running()
# allocate blocks for the prefill batch
self.prefill_bb.add_seqs(
self.running_list.prefill[:num_seqs_to_add],
alloc_block_tables_fn=self.cache_manager.allocate_context_from_block_tables,
)
return self.prefill_bb
if not self.running_bb.is_empty:
seqs_ids_to_recycle = self.cache_manager.allocate_tokens_from_block_tables(
self.running_bb.block_tables, self.running_bb.seq_lengths, self.running_bb.current_batch_size
)
if seqs_ids_to_recycle:
seqs_to_recycle = self.running_bb.pop_seqs(seqs_ids_to_recycle)
for seq in seqs_to_recycle:
seq.recycle()
self.running_batch.del_seq(seq)
self.running_list.remove(seq)
self.waiting_list[-1].append(seq)
# the recycled sequences are handled with highest priority.
return self.running_batch
return self.running_bb
def add_sequence(self, req: Sequence):
"""
@ -213,7 +249,7 @@ class RequestHandler:
seq.mark_aborted()
self.waiting_list[priority].remove(seq)
elif seq.status.is_running():
self.cache_manager.free_block_table(seq.block_table)
self.running_bb.pop_seq_update_batch(seq.request_id, self.cache_manager.free_block_table)
self.running_list.remove(seq)
else:
try:
@ -242,7 +278,7 @@ class RequestHandler:
else:
sample_tokens = greedy_sample(generation_config, logprobs)
else:
sample_tokens = beam_search_sample(generation_config, logprobs, is_prompt=not self.prefill_batch.is_empty)
sample_tokens = beam_search_sample(generation_config, logprobs, is_prompt=not self.prefill_bb.is_empty)
return sample_tokens
@ -273,27 +309,25 @@ class RequestHandler:
# sample the next tokens
sample_tokens = self._sample(probs, logprobs, generation_config)
if not self.prefill_batch.is_empty:
self.prefill_batch.update_batch_tokens(sample_tokens)
if not self.prefill_bb.is_empty:
self.prefill_bb.append_batch_tokens(sample_tokens)
else:
self.running_batch.update_batch_tokens(sample_tokens)
self.running_bb.append_batch_tokens(sample_tokens)
def update(self):
"""
Update current running list and done list
"""
if not self.prefill_batch.is_empty:
self.running_list.decoding.extend(self.running_list.prefill)
self.running_batch.add_seqs(self.running_list.prefill)
self.running_list.prefill.clear()
self.prefill_batch.clear_batch()
finish_seqs = self.running_batch.fliter_batch()
for seq in finish_seqs:
if not self.prefill_bb.is_empty:
self.running_list.move_prefill_to_decoding(self.prefill_bb.seqs_ids)
self.running_bb.merge(self.prefill_bb)
# clear the prefill batch without assigning a free_block_tables_fn
# since we want to reuse the memory recorded on the block tables
self.prefill_bb.clear(free_block_tables_fn=None)
finished_seqs, _ = self.running_bb.pop_finished(self.cache_manager.free_block_table)
for seq in finished_seqs:
self.running_list.remove(seq)
self.cache_manager.free_block_table(seq.block_table)
self.done_list.extend(finish_seqs)
self.done_list.extend(finished_seqs)
return finish_seqs
return finished_seqs
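The reworked `RunningList` keeps uid-keyed dicts for the prefill and decoding stages instead of plain lists, and the request handler drives the stage transition explicitly. A minimal sketch of that flow, mirroring `tests/test_infer/test_request_handler.py` (the sequence fields and the prefill ratio are placeholders):

from colossalai.inference.core.request_handler import RunningList
from colossalai.inference.struct import RequestStatus, Sequence

running_list = RunningList(prefill_ratio=1.2)  # illustrative ratio

seq1 = Sequence(
    request_id=1, prompt="abc", input_token_id=[1, 2, 3], block_size=16,
    sample_params=None, eos_token_id=0, pad_token_id=0,
)
seq2 = Sequence(
    request_id=2, prompt="abc", input_token_id=[1, 2, 3], block_size=16,
    sample_params=None, eos_token_id=0, pad_token_id=0,
)

# New sequences always enter the prefill dict first.
running_list.append(seq1)
running_list.append(seq2)
assert running_list.ready_for_prefill()  # nothing is decoding yet, so prefill is ready

# The handler marks prefill seqs RUNNING before building the prefill BatchBucket ...
running_list.mark_prefill_running()
assert all(seq.status == RequestStatus.RUNNING for seq in running_list.prefill)

# ... and moves them into the decoding dict (keyed by uid) after the prefill step.
running_list.move_prefill_to_decoding([seq1.request_id, seq2.request_id])
assert running_list.prefill_seq_num == 0 and running_list.decoding_seq_num == 2

# Finished or aborted sequences are removed by uid lookup.
running_list.remove(seq1)
running_list.remove(seq2)
assert running_list.is_empty()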

166
colossalai/inference/kv_cache/kvcache_manager.py

@ -63,7 +63,6 @@ class KVCacheManager:
self.dtype = config.dtype
self.elem_size_in_bytes = torch.tensor([], dtype=self.dtype).element_size()
self.num_layers = get_model_config_attr(model_config, "num_hidden_layers")
# For now we focus on MHA only, TODO add handling for MQA and GQA
self.head_num = get_model_config_attr(model_config, "num_attention_heads")
self.head_size = get_model_config_attr(model_config, "hidden_size") // self.head_num
assert self.head_num % self.tp_size == 0, f"Cannot shard {self.head_num} heads with tp size {self.tp_size}"
@ -82,8 +81,8 @@ class KVCacheManager:
# Physical cache allocation
alloc_shape = (self.num_blocks, self.head_num, self.block_size, self.head_size)
if verbose:
self.logger.info(f"Allocating KV cache with shape: {alloc_shape} consisting of {self.num_blocks} blocks.")
# if verbose:
# self.logger.info(f"Allocating KV cache with shape: {alloc_shape} consisting of {self.num_blocks} blocks.")
self._kv_caches = self._init_device_caches(alloc_shape)
self.total_physical_cache_size_in_bytes = (
self.elem_size_in_bytes
@ -112,6 +111,9 @@ class KVCacheManager:
"""Get the number of available cache blocks."""
return self._available_blocks
def get_head_size(self):
return self.head_size
def get_kv_cache(self):
"""Get k_cache and v_cache"""
return self._kv_caches
@ -148,7 +150,7 @@ class KVCacheManager:
and updates the provided block table with the allocated block ids.
Args:
block_table: A 1D tensor of shape [max_blocks_per_sequence] holded by a sequence, storing mapping of token_position_id -> block_id.
block_table: A 1D tensor of shape [max_blocks_per_sequence], mapping of token_position_id -> block_id.
context_len: The length of the processing sequence.
"""
assert block_table.dim() == 1
@ -193,12 +195,85 @@ class KVCacheManager:
else:
self._allocate_on_block(block, block.block_size)
def allocate_context_from_block_tables(self, block_tables: torch.Tensor, context_lengths: torch.Tensor) -> None:
"""Allocate logical cache blocks for a batch of sequences during prefill stage.
Args:
block_tables (torch.Tensor): [bsz, max_blocks_per_sequence]
context_lengths (torch.Tensor): [bsz]
"""
assert block_tables.dim() == 2
assert block_tables.size(0) == context_lengths.size(0)
if not torch.all(block_tables < 0):
self.logger.error("Some slots on provided block table have been allocated.")
blocks_required = (context_lengths + self.block_size - 1) // self.block_size
num_blocks_required = torch.sum(blocks_required).item()
assert isinstance(num_blocks_required, int)
if num_blocks_required > self._available_blocks:
self.logger.warning(
f"Lacking blocks to allocate. Available blocks {self._available_blocks}; blocks asked {num_blocks_required}."
)
return
bsz = block_tables.size(0)
# Try contiguous allocation
torch.cumsum(self._block_states, dim=-1, out=self._block_states_cum[1:])
torch.subtract(
self._block_states_cum[num_blocks_required:],
self._block_states_cum[:-num_blocks_required],
out=self._block_finder[num_blocks_required - 1 :],
)
end_indexes = torch.nonzero(self._block_finder == num_blocks_required, as_tuple=False).view(-1)
if end_indexes.numel() > 0:
# contiguous cache exists
end_idx = end_indexes[0].item() + 1 # open interval
start_idx = end_idx - num_blocks_required # closed interval
alloc_block_ids = torch.arange(start_idx, end_idx)
for i in range(bsz):
curr_required = blocks_required[i]
block_tables[i, :curr_required] = torch.arange(
start_idx, start_idx + curr_required, device=block_tables.device
)
start_idx += curr_required
else:
# non-contiguous cache
available_block_ids = torch.nonzero(self._block_states > 0).view(-1)
alloc_block_ids = available_block_ids[:num_blocks_required]
alloc_block_ids = alloc_block_ids.to(dtype=block_tables.dtype, device=block_tables.device)
start_idx = 0
for i in range(bsz):
curr_required = blocks_required[i]
block_tables[i, :curr_required] = alloc_block_ids[start_idx : start_idx + curr_required]
start_idx += curr_required
# Update cache blocks
self._block_states[alloc_block_ids] = 0
self._available_blocks -= num_blocks_required
last_block_locs = torch.cumsum(blocks_required, dim=0) - 1
last_block_locs = last_block_locs.to(device=alloc_block_ids.device)
for i, block_id in enumerate(alloc_block_ids[last_block_locs]):
block: CacheBlock = self._cache_blocks[block_id]
block.add_ref()
self._allocate_on_block(
block,
block.block_size
if context_lengths[i] % block.block_size == 0
else context_lengths[i].item() % block.block_size,
)
for block_id in alloc_block_ids:
if block_id in alloc_block_ids[last_block_locs]:
continue
block: CacheBlock = self._cache_blocks[block_id]
block.add_ref()
self._allocate_on_block(block, block.block_size)
def allocate_token_from_block_table(self, block_table: torch.Tensor, context_len: int) -> None:
"""Allocate the logical cache block for a single sequence during decoding stage,
and updates the provided block table if a new cache block is needed.
Args:
block_table: A 1D tensor of shape [max_blocks_per_sequence] holded by a sequence, storing mapping of token_position_id -> block_id.
block_table: A 1D tensor of shape [max_blocks_per_sequence], mapping of token_position_id -> block_id.
context_len: The length of the processing sequence (already-allocated length).
"""
assert block_table.dim() == 1
@ -207,12 +282,79 @@ class KVCacheManager:
alloc_local_block_idx = context_len // self.block_size
return self.allocate_single_block(block_table, alloc_local_block_idx)
def allocate_tokens_from_block_tables(
self, block_tables: torch.Tensor, context_lens: torch.Tensor, bsz: int = None
) -> List[int]:
"""Allocate logical cache blocks for a batch of sequences during decoding stage.
Usage:
allocate_context_from_block_tables
model forward (block tables & context lengths passed)
update context lengths
allocate_tokens_from_block_tables
model forward
update context lengths
allocate_tokens_from_block_tables
model forward
update context lengths
...
Args:
block_tables (torch.Tensor): [bsz, max_blocks_per_sequence]
context_lens (torch.Tensor): [bsz]
Returns:
List[int]: list of sequence uid to be recycled
"""
assert block_tables.dim() == 2
assert context_lens.dim() == 1
bsz = block_tables.size(0) if bsz is None else bsz
alloc_local_block_indexes = (context_lens[:bsz]) // self.block_size
block_global_ids = block_tables[torch.arange(0, bsz), alloc_local_block_indexes]
seqs_to_recycle = []
new_blocks_required = torch.sum(block_global_ids < 0).item()
seqs_req_new_blocks = torch.nonzero(block_global_ids < 0).squeeze()
if new_blocks_required > 0:
if new_blocks_required > self._available_blocks:
# TODO might want to revise the logic here
# Process the first (_available_blocks) sequences that require new blocks
# Put the rest of the sequences back to recycled
seqs_req_new_blocks, seqs_to_recycle = (
seqs_req_new_blocks[: self._available_blocks],
seqs_req_new_blocks[self._available_blocks :],
)
for seq_id in seqs_to_recycle:
self.free_block_table(block_tables[seq_id])
new_blocks_required = self._available_blocks
# NOTE might want to add contiguous-allocation logic here as well
free_block_ids = torch.nonzero(self._block_states > 0).view(-1)
alloc_block_ids = free_block_ids[:new_blocks_required].to(
dtype=block_tables.dtype, device=block_tables.device
)
for block_id in alloc_block_ids:
block: CacheBlock = self._cache_blocks[block_id]
block.add_ref()
self._block_states[block_id] = 0
self._available_blocks -= 1
block_tables[seqs_req_new_blocks, alloc_local_block_indexes[seqs_req_new_blocks]] = alloc_block_ids
block_global_ids = block_tables[torch.arange(0, bsz), alloc_local_block_indexes]
for block_id in block_global_ids:
self._allocate_on_block(self._cache_blocks[block_id], 1)
return seqs_to_recycle
def allocate_single_block(self, block_table: torch.Tensor, block_local_idx: int) -> int:
"""Allocate space asked on a single block in the block table, specified by the provided position id,
and updates the provided block table with the allocated block.
Args:
block_table: A 1D tensor of shape [max_blocks_per_sequence] holded by a sequence, storing mapping of token_position_id -> block_id.
block_table: A 1D tensor of shape [max_blocks_per_sequence], mapping of token_position_id -> block_id.
block_local_idx: The index of the block in the block table.
space_asked: i.e. The number of tokens to be assigned space for.
Returns:
@ -240,8 +382,7 @@ class KVCacheManager:
def free_block_table(self, block_table: torch.Tensor) -> None:
"""Free the logical cache blocks for **a single sequence**."""
assert block_table.dim() == 1
for i in range(block_table.numel()):
global_block_id = block_table[i].item()
for i, global_block_id in enumerate(block_table.tolist()):
if global_block_id < 0:
return
block: CacheBlock = self._cache_blocks[global_block_id]
@ -253,6 +394,15 @@ class KVCacheManager:
# reset the block id in the block table (if we maintain a 2D tensors as block tables in Engine)
block_table[i] = -1
def free_block_tables(self, block_tables: torch.Tensor, first_n: int = None) -> None:
"""Release the logical cache blocks for a batch of sequences.
If `first_n` is provided, only the blocks of the first `first_n` sequences will be released.
"""
assert block_tables.dim() == 2
first_n = block_tables.size(0) if first_n is None else first_n
for block_table in block_tables[:first_n]:
self.free_block_table(block_table)
def clear_all(self) -> None:
"""Clear all the references and allocations on all the cache blocks."""
for block in self._cache_blocks:
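The contiguous-allocation path in `allocate_context_from_block_tables` relies on a prefix-sum trick: a sliding window sum over the 0/1 block-state vector equals the window length exactly where every block in the window is free. Below is a self-contained illustration of that trick under the same 1 = free / 0 = occupied convention; the function name and variables are local to this sketch, not attributes of `KVCacheManager`.

import torch

def find_contiguous_free_run(block_states: torch.Tensor, num_required: int) -> int:
    """Return the start index of a run of `num_required` free blocks, or -1 if none exists.

    block_states: 1D int tensor, 1 = free, 0 = occupied.
    """
    num_blocks = block_states.numel()
    if num_required > num_blocks:
        return -1
    # Prefix sums with a leading zero: states_cum[i] = number of free blocks among the first i blocks.
    states_cum = torch.zeros(num_blocks + 1, dtype=torch.int64)
    torch.cumsum(block_states, dim=-1, out=states_cum[1:])
    # Window sum of length `num_required` starting at each candidate position.
    window_sums = states_cum[num_required:] - states_cum[:-num_required]
    end_indexes = torch.nonzero(window_sums == num_required, as_tuple=False).view(-1)
    if end_indexes.numel() == 0:
        return -1  # no contiguous run; the manager falls back to scattered allocation
    return end_indexes[0].item()  # the window starting here is entirely free

# Blocks 1..4 are free, so the first run of 3 free blocks starts at index 1.
states = torch.tensor([0, 1, 1, 1, 1, 0, 1, 1], dtype=torch.int64)
print(find_contiguous_free_run(states, 3))  # -> 1

In the manager itself, `_block_states_cum` and `_block_finder` play the roles of `states_cum` and `window_sums`, kept as persistent buffers so the temporaries are not re-allocated on every scheduling step.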

8
colossalai/inference/modeling/models/nopadding_llama.py

@ -12,8 +12,8 @@ from transformers.models.llama.modeling_llama import (
LlamaModel,
)
from colossalai.inference.batch_bucket import BatchBucket
from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.inference.struct import BatchInfo
from colossalai.kernel.triton import (
context_attention_unpadded,
copy_kv_to_blocked_cache,
@ -34,7 +34,7 @@ except ImportError:
def llama_causal_lm_forward(
self: LlamaForCausalLM,
batch: BatchInfo = None,
batch: BatchBucket = None,
k_caches: List[torch.Tensor] = None,
v_caches: List[torch.Tensor] = None,
):
@ -59,7 +59,7 @@ def llama_causal_lm_forward(
def llama_model_forward(
self: LlamaModel,
batch: BatchInfo = None,
batch: BatchBucket = None,
k_caches: List[torch.Tensor] = None,
v_caches: List[torch.Tensor] = None,
):
@ -73,7 +73,7 @@ def llama_model_forward(
input_ids = batch.get_1D_inputs()
block_tables = batch.get_block_table_tensor()
sequence_lengths = batch.get_sequence_lengths()
batch_size = len(sequence_lengths)
batch_size = batch.current_batch_size
kv_seq_len = sequence_lengths.max().item()
hidden_states = self.embed_tokens(input_ids)

2
colossalai/inference/struct.py

@ -71,7 +71,6 @@ class Sequence:
input_token_id: List[int]
block_size: int
sample_params: Any # SampleParams needs to be imported later.
block_table: torch.Tensor
eos_token_id: int
pad_token_id: int
max_output_len: int = 256
@ -158,7 +157,6 @@ class Sequence:
f"prompt={self.prompt}, "
f"status={self.status.name}, "
f"sample_params={self.sample_params}, "
f"logical_block_number={self.block_table.shape[0]},"
f"input_len={self.input_len}),"
f"output_len={self.output_len})"
)

140
tests/test_infer/test_batch_bucket.py

@ -0,0 +1,140 @@
import torch
from transformers.models.llama import LlamaConfig
from colossalai.inference.batch_bucket import BatchBucket
from colossalai.inference.config import InferenceConfig
from colossalai.inference.kv_cache import KVCacheManager
from colossalai.inference.struct import Sequence
from colossalai.testing import parameterize
@parameterize(
"test_config",
[
{
"hidden_size": 128,
"num_attention_heads": 4,
"num_layers": 2,
"block_size": 4,
"max_batch_size": 4,
"max_input_len": 32,
"max_output_len": 8,
"dtype": torch.float16,
"tp_size": 1,
}
],
)
def test_bucket(test_config):
hidden_size = test_config.pop("hidden_size")
num_heads = test_config.pop("num_attention_heads")
num_layers = test_config.pop("num_layers")
model_config = LlamaConfig(
hidden_size=hidden_size,
num_hidden_layers=num_layers,
num_attention_heads=num_heads,
)
inference_config = InferenceConfig(**test_config)
# Just for testing usage. Don't create multiple cache_manager on the same device.
cache_manager = KVCacheManager(inference_config, model_config)
cache_manager_copy = KVCacheManager(inference_config, model_config)
seq_lens = [19, 20, 27]
seq1 = Sequence(
request_id=0,
prompt="", # Dummy for testing usage
input_token_id=list(range(seq_lens[0])),
block_size=4,
sample_params=None,
eos_token_id=2,
pad_token_id=2,
max_output_len=10,
)
seq2 = Sequence(
request_id=1,
prompt="", # Dummy for testing usage
input_token_id=list(range(seq_lens[1])),
block_size=4,
sample_params=None,
eos_token_id=2,
pad_token_id=2,
max_output_len=10,
)
seq3 = Sequence(
request_id=2,
prompt="", # Dummy for testing usage
input_token_id=list(range(seq_lens[2])),
block_size=4,
sample_params=None,
eos_token_id=2,
pad_token_id=2,
max_output_len=10,
)
block_size = test_config["block_size"]
max_batch_size = test_config["max_batch_size"]
max_length = test_config["max_input_len"] + test_config["max_output_len"]
assert max_batch_size >= 2, "max_batch_size should be greater than 1"
bb = BatchBucket(
num_heads, cache_manager.get_head_size(), max_batch_size, max_length, block_size, kv_max_split_num=2
)
bb_copy = BatchBucket(
num_heads, cache_manager.get_head_size(), max_batch_size, max_length, block_size, kv_max_split_num=2
)
block_tables = bb.add_seqs([seq1, seq2])
assert block_tables.shape == (2, cache_manager.max_blocks_per_sequence)
assert torch.all(block_tables < 0), "Initialized block_tables should be negative values"
cache_manager.allocate_context_from_block_tables(block_tables, bb.seq_lengths[: bb.current_batch_size])
bb_copy.add_seqs(
[seq1, seq2], alloc_block_tables_fn=cache_manager_copy.allocate_context_from_block_tables
) # This is just for testing usage. Don't add the same sequence to different buckets.
assert bb.seq_lengths.tolist() == [seq1.sentence_len, seq2.sentence_len] + [0] * (
max_batch_size - bb.current_batch_size
)
assert torch.equal(bb.block_tables, bb_copy.block_tables)
bb.append_batch_tokens(torch.tensor([99, 99]))
assert bb.seq_lengths.tolist() == [seq1.sentence_len, seq2.sentence_len] + [0] * (
max_batch_size - bb.current_batch_size
)
cache_manager.allocate_tokens_from_block_tables(bb.block_tables, bb.seq_lengths, bsz=bb.current_batch_size)
assert bb.seq_lengths.tolist() == [seq1.sentence_len, seq2.sentence_len] + [0] * (
max_batch_size - bb.current_batch_size
)
bb.append_batch_tokens(torch.tensor([99, 99]))
cache_manager.allocate_tokens_from_block_tables(bb.block_tables, bb.seq_lengths, bsz=bb.current_batch_size)
assert bb.seq_lengths.tolist() == [seq1.sentence_len, seq2.sentence_len] + [0] * (
max_batch_size - bb.current_batch_size
)
bb.pop_seq_update_batch(0, free_block_table_fn=cache_manager.free_block_table)
assert bb.seq_lengths.tolist() == [bb.seqs_li[0].sentence_len] + [0] * (max_batch_size - bb.current_batch_size)
assert bb.is_compact
bb2 = BatchBucket(
num_heads, cache_manager.get_head_size(), max_batch_size, max_length, block_size, kv_max_split_num=2
)
block_tables = bb2.add_seqs([seq3])
cache_manager.allocate_context_from_block_tables(block_tables, bb2.seq_lengths[: bb2.current_batch_size])
unmerged_ids = bb.merge(bb2)
assert not unmerged_ids
assert bb.is_compact
assert bb2.is_compact
assert bb.current_batch_size == 2
assert bb2.current_batch_size == 0
bb.clear(cache_manager.free_block_tables)
assert bb.current_batch_size == 0
assert bb.is_compact
assert bb.seq_lengths.tolist() == [0] * max_batch_size
assert torch.all(bb.block_tables < 0)
if __name__ == "__main__":
test_bucket()

3
tests/test_infer/test_config_and_struct.py

@ -15,7 +15,6 @@ def check_config_and_inference():
input_token_id=[1, 2, 3],
block_size=16,
sample_params=None,
block_table=None,
eos_token_id=2,
pad_token_id=2,
max_output_len=256,
@ -27,7 +26,6 @@ def check_config_and_inference():
input_token_id=[4, 5, 6],
block_size=16,
sample_params=None,
block_table=None,
eos_token_id=2,
pad_token_id=2,
max_output_len=256,
@ -39,7 +37,6 @@ def check_config_and_inference():
input_token_id=[7, 8, 9],
block_size=16,
sample_params=None,
block_table=None,
eos_token_id=2,
pad_token_id=2,
max_output_len=256,

14
tests/test_infer/test_kvcache_manager.py

@ -148,6 +148,20 @@ def check_cache_manager(test_config):
cache_manager.clear_all()
assert cache_manager.num_available_blocks == num_blocks
for cache_block in cache_manager._cache_blocks:
assert cache_block.available_space == block_size
# Mock batch operations (Prefill/Decoding updates)
context_lengths = torch.tensor([max_input_length, max_input_length - 1])
block_tables = torch.tensor(
[[-1 for _ in range(cache_manager.max_blocks_per_sequence)] for _ in range(2)], dtype=torch.int32
)
cache_manager.allocate_context_from_block_tables(block_tables, context_lengths)
cache_manager.allocate_tokens_from_block_tables(block_tables, context_lengths)
cache_manager.free_block_tables(block_tables)
for cache_block in cache_manager._cache_blocks:
assert cache_block.available_space == block_size
def run_dist(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")

26
tests/test_infer/test_request_handler.py

@ -1,5 +1,4 @@
import pytest
import torch
from transformers.models.llama import LlamaConfig
import colossalai
@ -22,17 +21,35 @@ def check_running_list():
eos_token_id=0,
pad_token_id=0,
sample_params=None,
block_table=1,
)
seq2 = Sequence(
request_id=2,
prompt="abc",
input_token_id=[1, 2, 3],
block_size=16,
eos_token_id=0,
pad_token_id=0,
sample_params=None,
)
running_list.append(seq1)
running_list.append(seq2)
assert running_list.ready_for_prefill()
assert running_list.decoding == [] and running_list.prefill[0] == seq1
assert len(running_list.decoding) == 0
assert len(running_list.prefill) > 0 and running_list.prefill[0] == seq1
seq = running_list.find_seq(seq1.request_id)
assert seq == seq1
running_list.mark_prefill_running()
for seq in running_list.prefill:
assert seq.status == RequestStatus.RUNNING
running_list.move_prefill_to_decoding([seq1.request_id, seq2.request_id])
assert len(running_list.prefill) == 0
assert len(running_list.decoding) > 0 and running_list.decoding[0] == seq1
running_list.remove(seq1)
running_list.remove(seq2)
assert running_list.is_empty()
@ -59,7 +76,6 @@ def check_request_handler():
eos_token_id=0,
pad_token_id=0,
sample_params=None,
block_table=torch.tensor([-1, -1]),
)
request_handler.add_sequence(seq1)
# the priority should be 1
