from typing import List, Optional, Tuple

import torch
from transformers.configuration_utils import PretrainedConfig

from colossalai.inference.config import InferenceConfig
from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.inference.kv_cache import KVCacheManager
from colossalai.inference.logit_processors import logit_processor
from colossalai.inference.sampler import beam_search_sample, greedy_sample, multinomial_sample
from colossalai.inference.struct import BatchInfo, RequestStatus, Sequence
from colossalai.logging import get_dist_logger

__all__ = ["RunningList", "RequestHandler"]

logger = get_dist_logger(__name__)


class RunningList:
    """
    RunningList is a structure for recording the running sequences; it contains a prefill list and a decoding list.
    Prefill samples are held until the ratio of prefill samples to decoding samples exceeds prefill_ratio.

    Args:
        prefill_ratio (float): A ratio used to determine whether to perform prefill.
        prefill (List[Sequence]): List of sequences to prefill initially, defaults to [].
    """

    def __init__(self, prefill_ratio: float, prefill: Optional[List[Sequence]] = None):
        self.prefill_ratio = prefill_ratio
        self.decoding: List[Sequence] = []
        self.prefill: List[Sequence] = prefill if prefill is not None else []

    def append(self, seq: Sequence):
        # Add the sequence to the prefill list first.
        self.prefill.append(seq)

    def find_seq(self, request_id):
        for seq in self.decoding:
            if request_id == seq.request_id:
                return seq
        for seq in self.prefill:
            if request_id == seq.request_id:
                return seq
        return None

    def remove(self, seq: Sequence):
        if seq in self.decoding:
            self.decoding.remove(seq)
        elif seq in self.prefill:
            self.prefill.remove(seq)
        else:
            raise ValueError(f"sequence {seq.request_id} is not in running list")

    def ready_for_prefill(self):
        if not self.decoding:
            return len(self.prefill) > 0
        return len(self.prefill) / len(self.decoding) >= self.prefill_ratio
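
    # Illustrative example of the prefill gate above (hypothetical values): with prefill_ratio = 1.2,
    # three sequences in self.prefill and two in self.decoding give 3 / 2 = 1.5 >= 1.2, so
    # ready_for_prefill() returns True; with two prefill and two decoding, 1.0 < 1.2 and prefill waits.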

    def is_empty(self):
        return not self.decoding and not self.prefill

    def total_seq_num(self):
        return len(self.decoding) + len(self.prefill)


class RequestHandler:
    """
    RequestHandler is the core component for handling existing requests and updating the current batch.
    During generation, the schedule function is called at each iteration to update the current batch.

    Args:
        inference_config: Configuration for initializing and managing the KV cache.
        model_config: Configuration of the model.
    """

    def __init__(self, inference_config: InferenceConfig, model_config: PretrainedConfig) -> None:
        self.inference_config = inference_config
        self.running_list: RunningList = RunningList(inference_config.prefill_ratio)
        self.waiting_list: List[List] = [[], [], []]
        self.done_list: List[Sequence] = []
        self.dtype = inference_config.dtype
        self.max_batch_size = inference_config.max_batch_size

        # initialize cache
        self._init_cache(model_config)

        # initialize batch
        device = torch.cuda.current_device()
        kv_max_split_num = (
            inference_config.max_input_len + inference_config.max_output_len + inference_config.block_size - 1
        ) // inference_config.block_size
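        # For example (hypothetical config values): with max_input_len = 512, max_output_len = 512
        # and block_size = 16, kv_max_split_num = (512 + 512 + 16 - 1) // 16 = 64, an upper bound
        # on the number of KV-cache blocks a single sequence can occupy.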
        head_dim = model_config.hidden_size // model_config.num_attention_heads

        fd_inter_tensor = FDIntermTensors()
        fd_inter_tensor.initialize(
            max_batch_size=self.max_batch_size,
            num_attn_heads=model_config.num_attention_heads,
            kv_max_split_num=kv_max_split_num,
            head_dim=head_dim,
            dtype=self.dtype,
            device=device,
        )

        # TODO In the continuous batching scenario, the batch size may be greater than max_batch_size,
        # which may cause bugs and this issue should be fixed later.
        self.running_batch = BatchInfo(
            max_batch_size=self.max_batch_size,
            kv_max_split_num=kv_max_split_num,
            num_heads=model_config.num_attention_heads,
            head_dim=head_dim,
            is_prompts=False,
            device=device,
            dtype=self.dtype,
            fd_inter_tensor=fd_inter_tensor,
        )
        self.prefill_batch = BatchInfo(
            max_batch_size=self.max_batch_size,
            kv_max_split_num=kv_max_split_num,
            num_heads=model_config.num_attention_heads,
            head_dim=head_dim,
            is_prompts=True,
            device=device,
            dtype=self.dtype,
            fd_inter_tensor=fd_inter_tensor,
        )

    def _init_cache(self, model_config):
        self.cache_manager = KVCacheManager(self.inference_config, model_config)

    def _has_waiting(self) -> bool:
        return any(lst for lst in self.waiting_list)

    def get_kvcache(self):
        return self.cache_manager.get_kv_cache()

    def schedule(self):
        """
        The main logic of the request handler: admit waiting sequences into the running list
        (prioritized by prompt length), then return either the prefill batch or the decoding
        batch for this iteration.
        """
        if self._has_waiting():
            # Try to allocate cache blocks for the waiting sequences, prioritized by prompt length.
            for lst in reversed(self.waiting_list):
                if lst:
                    remove_list = []
                    for seq in lst:
                        if seq.input_len > self.inference_config.max_input_len:
                            # If the prompt is longer than max_input_len, abort the sequence.
                            logger.warning(
                                f"the prompt (request id = {seq.request_id}) is longer than max_input_len, aborting this sequence."
                            )
                            self.abort_sequence(seq.request_id)
                            remove_list.append(seq)
                            break

                        # Stop feeding new sequences into the running list to keep spare cache blocks
                        # available for the sequences that are already running.
                        if self.cache_manager.num_available_blocks <= self.running_list.total_seq_num():
                            break

                        # Try to allocate cache blocks for the sequence.
                        if (
                            self.cache_manager.check_allocation(seq)
                            and (len(self.running_list.prefill) + len(self.running_list.decoding))
                            < self.max_batch_size  # There are some bugs in continuous batching, so we disable it here.
                        ):
                            # If allocation succeeds, add the sequence to the running list.
                            remove_list.append(seq)
                            self.running_list.append(seq)
                            self.cache_manager.allocate_context_from_block_table(seq.block_table, seq.sentence_len)
                    for seq in remove_list:
                        lst.remove(seq)

        if self.running_list.ready_for_prefill():
            for seq in self.running_list.prefill:
                seq.mark_running()
            self.prefill_batch.add_seqs(self.running_list.prefill)
            return self.prefill_batch

        if not self.running_batch.is_empty:
            for seq in self.running_batch.sequences_set:
                recycle = self.cache_manager.allocate_token_from_block_table(seq.block_table, seq.sentence_len)
                if recycle:
                    seq.recycle()
                    self.running_batch.del_seq(seq)
                    self.running_list.remove(seq)
                    # Recycled sequences go to the last waiting bucket, which is scheduled first (highest priority).
                    self.waiting_list[-1].append(seq)

        return self.running_batch

    def add_sequence(self, req: Sequence):
        """
        Add the request to the waiting list.
        """
        assert not self._find_sequence(req.request_id), f"Sequence {req.request_id} already exists."
        assert (
            req.input_len <= self.inference_config.max_input_len
        ), f"Sequence {req.request_id} exceeds input length limit"
        self.waiting_list[req.input_len * 3 // (self.inference_config.max_input_len + 1)].append(req)
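
    # Note on add_sequence's waiting-list index: input_len * 3 // (max_input_len + 1) maps a prompt
    # to one of three priority buckets by length. For example (hypothetical value), with
    # max_input_len = 1024 a 100-token prompt lands in bucket 0, a 500-token prompt in bucket 1,
    # and a 900-token prompt in bucket 2; schedule() walks the buckets in reverse, so longer
    # prompts are admitted first.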

    def abort_sequence(self, request_id: str):
        """
        Abort the request.
        """
        result = self._find_sequence(request_id)
        if result is None:
            # Nothing to abort: the request is unknown or has already been removed.
            return
        seq, priority = result
        if seq.status == RequestStatus.WAITING:
            seq.mark_aborted()
            self.waiting_list[priority].remove(seq)
        elif seq.status.is_running():
            self.cache_manager.free_block_table(seq.block_table)
            self.running_list.remove(seq)
        else:
            try:
                self.done_list.remove(seq)
            except ValueError:
                return

    def _find_sequence(self, request_id: str) -> Optional[Tuple[Sequence, Optional[int]]]:
        """
        Find the request by request_id.

        Returns (sequence, priority) if the request is in the waiting list,
        (sequence, None) if it is in the running list, and None if it is not found.
        """
        for priority, lst in enumerate(self.waiting_list):
            for seq in lst:
                if seq.request_id == request_id:
                    return seq, priority

        seq = self.running_list.find_seq(request_id)
        if seq is not None:
            return seq, None

        return None

    def _sample(self, probs: torch.Tensor, logprobs: torch.Tensor, generation_config):
        if generation_config.num_beams == 1:
            if generation_config.do_sample:
                sample_tokens = multinomial_sample(generation_config, probs)
            else:
                sample_tokens = greedy_sample(generation_config, logprobs)
        else:
            sample_tokens = beam_search_sample(generation_config, logprobs, is_prompt=not self.prefill_batch.is_empty)

        return sample_tokens

    def mark_finished(self, sequence: Sequence, generation_config):
        if (
            sequence.output_token_id[-1] == generation_config.eos_id
            or sequence.output_len >= generation_config.max_output_len
        ):
            sequence.mark_finished()

    def check_unfinished_seqs(self) -> bool:
        return self._has_waiting() or not self.running_list.is_empty()

    def search_tokens(self, generation_config, logits):
        """
        Sample the next tokens for the current batch.
        """
        # Apply logit processors.
        # NOTE: need to decide the granularity to process logits (sequence or batch)
        config_dict = generation_config.to_dict()
        for processor_type in ["top_k", "top_p", "min_p"]:
            if processor_type in config_dict and config_dict[processor_type] is not None:
                logits = logit_processor(processor_type, logits, config_dict[processor_type])
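        # For example (hypothetical config): if generation_config.top_k is set to 50, the "top_k"
        # processor is expected to keep only the 50 highest logits and mask out the rest before
        # sampling; "top_p" and "min_p" are applied analogously when they are set.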

        # calculate probs
        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
        logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)

        # sample the next tokens
        sample_tokens = self._sample(probs, logprobs, generation_config)
        if not self.prefill_batch.is_empty:
            self.prefill_batch.update_batch_tokens(sample_tokens)
        else:
            self.running_batch.update_batch_tokens(sample_tokens)

    def update(self):
        """
        Update current running list and done list
        """
        if not self.prefill_batch.is_empty:
            self.running_list.decoding.extend(self.running_list.prefill)
            self.running_batch.add_seqs(self.running_list.prefill)
            self.running_list.prefill.clear()
            self.prefill_batch.clear_batch()

        finish_seqs = self.running_batch.fliter_batch()

        for seq in finish_seqs:
            self.running_list.remove(seq)
            self.cache_manager.free_block_table(seq.block_table)

        self.done_list.extend(finish_seqs)

        return finish_seqs