import time
from itertools import count
from typing import Dict, List, Optional, Tuple, Type, Union

import numpy as np
import torch
import torch.nn as nn
from torch import distributed as dist
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    GenerationConfig,
    PreTrainedTokenizer,
    PreTrainedTokenizerFast,
)
from transformers.models.llama.modeling_llama import LlamaForCausalLM

from colossalai.accelerator import get_accelerator
from colossalai.cluster import ProcessGroupMesh
from colossalai.inference.batch_bucket import BatchBucket
from colossalai.inference.config import InferenceConfig, InputMetaData
from colossalai.inference.graph_runner import CUDAGraphRunner
from colossalai.inference.modeling.policy import model_policy_map
from colossalai.inference.sampler import search_tokens
from colossalai.inference.spec import Drafter, GlideInput
from colossalai.inference.struct import Sequence
from colossalai.inference.utils import get_model_size
from colossalai.interface import ModelWrapper
from colossalai.logging import get_dist_logger
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.shardformer.policies.base_policy import Policy

from .request_handler import RequestHandler

__all__ = ["InferenceEngine"]

PP_AXIS, TP_AXIS = 0, 1

_supported_models = {
    "LlamaForCausalLM": LlamaForCausalLM,
    "BaichuanForCausalLM": AutoModelForCausalLM,
}

_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)]
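
# NOTE _BATCH_SIZES_TO_CAPTURE expands to [1, 2, 4] plus the multiples of 8 up to 256; during
# graph capture only the sizes not exceeding `inference_config.max_batch_size` are actually
# captured (see `InferenceEngine.capture_model`).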


class InferenceEngine:
    """
    InferenceEngine manages the inference process.

    Args:
        model_or_path (nn.Module or str): Path or nn.Module of this model.
        tokenizer (Union[PreTrainedTokenizer, PreTrainedTokenizerFast]): The tokenizer to use.
        inference_config (InferenceConfig): Stores the configuration information related to inference.
        verbose (bool): Determine whether or not to log the generation process.
        model_policy (Policy, optional): The policy used to shard the model with ShardFormer. It will be determined by the model type if not provided.
    """

    def __init__(
        self,
        model_or_path: Union[nn.Module, str],
        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
        inference_config: InferenceConfig,
        verbose: bool = False,
        model_policy: Union[Policy, Type[Policy]] = None,
    ) -> None:
        self.inference_config = inference_config
        self.dtype = inference_config.dtype
        self.high_precision = inference_config.high_precision

        self.verbose = verbose
        self.logger = get_dist_logger(__name__)

        self.init_model(model_or_path, model_policy)

        self.generation_config = inference_config.to_generation_config(self.model_config)
        self.generation_config_dict = self.generation_config.to_dict()

        self.tokenizer = tokenizer
        self.tokenizer.pad_token = self.tokenizer.eos_token

        self.request_handler = RequestHandler(self.inference_config, self.model_config)
        self.k_cache, self.v_cache = self.request_handler.get_kvcache()
        # DISCUSS maybe move this into batch info?

        self.counter = count()

        self.use_cuda_graph = self.inference_config.use_cuda_graph
        if self.use_cuda_graph:
            self.graph_runners: Dict[int, CUDAGraphRunner] = {}
            self.graph_memory_pool = None  # Set during graph capture.
            if verbose:
                self.logger.info("Colossal AI CUDA Graph Capture on")

            self.capture_model(self.k_cache, self.v_cache)

        # Model and related attrs of speculative decoding will be set by `enable_spec_dec`
        self.use_spec_dec = False
        self.drafter_model = None
        self.drafter = None
        self.use_glide = False
        self.n_spec_tokens = self.inference_config.max_n_spec_tokens

        self._verify_args()
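
    # Usage sketch (illustrative only; the checkpoint path below is hypothetical):
    #
    #   engine = InferenceEngine(
    #       model_or_path="path/to/llama-checkpoint",
    #       tokenizer=tokenizer,
    #       inference_config=InferenceConfig(),
    #       verbose=True,
    #   )
    #   outputs = engine.generate(prompts=["Hello, world!"])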

    def init_model(self, model_or_path: Union[nn.Module, str], model_policy: Union[Policy, Type[Policy]] = None):
        """
        Shard the model and/or load its weights.

        Args:
            model_or_path (Union[nn.Module, str]): Path to the checkpoint or a model in transformers format.
            model_policy (Policy): The policy used to replace (shard) the model.
        """

        if isinstance(model_or_path, str):
            try:
                hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True)
                arch = getattr(hf_config, "architectures")[0]
                if arch in _supported_models.keys():
                    # NOTE(lry89757) Currently we load the model using the transformers API,
                    # but we will use lazy tensor and checkpoint io to accelerate
                    # the model load process in the future.
                    model = _supported_models[arch].from_pretrained(model_or_path, trust_remote_code=True)
                else:
                    raise ValueError(f"Model {arch} is not supported.")

            except Exception as e:
                self.logger.error(
                    f"An exception occurred during loading model: {e}, the model should be loadable by transformers\n"
                )
        else:
            model = model_or_path

        self.model_config = model.config

        torch.cuda.empty_cache()
        init_gpu_memory = torch.cuda.mem_get_info()[0]

        self.device = get_accelerator().get_current_device()
        if self.verbose:
            self.logger.info(f"the device is {self.device}")

        model = model.to(self.dtype).eval()

        if self.verbose:
            self.logger.info(
                f"Before the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(model)} GB, model's device is: {model.device}"
            )

        if model_policy is None:
            prefix = "nopadding" if not self.inference_config.pad_input else "padding"
            model_policy_key = f"{prefix}_{getattr(self.model_config, 'model_type', None)}"
            model_policy = model_policy_map.get(model_policy_key)

        if not isinstance(model_policy, Policy):
            try:
                model_policy = model_policy()
            except Exception as e:
                raise ValueError(f"Unable to instantiate model policy: {e}")

        assert isinstance(model_policy, Policy), f"Invalid type of model policy: {type(model_policy)}"
        pg_mesh = ProcessGroupMesh(self.inference_config.pp_size, self.inference_config.tp_size)
        tp_group = pg_mesh.get_group_along_axis(TP_AXIS)

        self.model = self._shardformer(
            model,
            model_policy,
            None,
            tp_group=tp_group,
        )

        self.model = ModelWrapper(model).to(self.device)

        if self.verbose:
            self.logger.info(
                f"After the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(self.model)} GB, model's device is: {model.device}"
            )

        # NOTE(lry89757) Deprecated currently, will be reused when introducing lazy tensor
        # if isinstance(model_or_path, str) and not isinstance(casuallm, AutoModelForCausalLM):
        #     from colossalai.inference.core.plugin import InferCheckpoint_io

        #     cpt_io = InferCheckpoint_io()
        #     if_has_index_file, model_index_file = has_index_file(model_or_path)
        #     assert if_has_index_file, "the model path is invalid"
        #     cpt_io.load_model(self.model, model_index_file)

        free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
        peak_memory = init_gpu_memory - free_gpu_memory
        if self.verbose:
            self.logger.info(
                f"Rank [{dist.get_rank()}], Model Weight Max Occupy {peak_memory / (1024 ** 3)} GB, Model size: {get_model_size(self.model)} GB"
            )
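
    # NOTE When no `model_policy` is given, `init_model` above resolves one from `model_policy_map`
    # using a key of the form f"{prefix}_{model_type}" (e.g. "nopadding_llama" when `pad_input` is
    # False and the HuggingFace config reports `model_type == "llama"`), assuming such a key is
    # registered in the policy map.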

    @torch.inference_mode()
    def capture_model(self, k_cache: List[torch.Tensor], v_cache: List[torch.Tensor]):
        assert self.use_cuda_graph, "please turn on the cuda graph"

        if self.verbose:
            self.logger.info("Colossal AI CUDA Graph Capture begin")

        t_capture_begin = time.perf_counter()

        block_size = self.inference_config.block_size
        head_dim = self.model_config.hidden_size // self.model_config.num_attention_heads

        # Prepare dummy inputs. These will be reused for all batch sizes.
        max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
        max_context_len_to_capture = self.inference_config.max_context_len_to_capture
        max_num_blocks = (max_context_len_to_capture + block_size - 1) // block_size
        input_tokens_ids = torch.zeros(max_batch_size, dtype=torch.long).cuda()
        # self.graph_block_tables = np.zeros((max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), dtype=np.int32)
        self.graph_block_tables = np.full((max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), -1, dtype=np.int32)
        self.graph_block_tables[:, 0] = np.arange(max_num_blocks, max_num_blocks + max(_BATCH_SIZES_TO_CAPTURE))
        self.graph_block_tables[0, :] = np.arange(
            0, max_num_blocks
        )  # NOTE this is a hack to ensure the CUDA graph can capture the fixed CUDA kernel grid in flash decoding, by making the first seqlen the max_seq_len
        block_tables = torch.from_numpy(self.graph_block_tables).cuda()
        output_tensor = torch.zeros(
            (max_batch_size, self.model_config.num_attention_heads * head_dim), dtype=self.dtype, device=self.device
        )
        fd_inter_tensor = self.request_handler.running_bb.fd_inter_tensor

        max_num_seqs = self.inference_config.max_batch_size
        batch_size_capture_list = [bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= max_num_seqs]
        sequence_lengths = torch.ones(max_batch_size, dtype=torch.int).cuda()
        # NOTE this is a hack to ensure the CUDA graph can capture the fixed CUDA kernel grid in flash decoding, by making the first seqlen the max_seq_len
        sequence_lengths[0] = torch.tensor(
            self.inference_config.max_context_len_to_capture - 1, dtype=torch.int32
        ).cuda()

        # NOTE: Capturing the largest batch size first may help reduce the
        # memory usage of CUDA graph.
        for batch_size in reversed(batch_size_capture_list):
            if self.verbose:
                self.logger.info(f"batch size {batch_size} graph capturing")

            input_meta_data = InputMetaData(
                block_tables=block_tables[:batch_size],
                sequence_lengths=sequence_lengths[:batch_size],
                fd_inter_tensor=fd_inter_tensor,
                batch_size=batch_size,
                is_prompts=False,
                use_cuda_graph=True,
                high_precision=False,
                kv_seq_len=sequence_lengths[:batch_size].max().item(),
                head_dim=head_dim,
                dtype=self.dtype,
            )

            graph_runner = CUDAGraphRunner(self.model)
            graph_runner.capture(
                input_tokens_ids[:batch_size],
                output_tensor[:batch_size],
                input_meta_data,
                k_caches=k_cache,
                v_caches=v_cache,
                memory_pool=self.graph_memory_pool,
            )
            self.graph_memory_pool = graph_runner.graph.pool()
            self.graph_runners[batch_size] = graph_runner

        t_capture_end = time.perf_counter()

        if self.verbose:
            self.logger.info(f"CUDA Graph capture time: {t_capture_end - t_capture_begin} s")
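
    # NOTE The captured runners are keyed by batch size in `self.graph_runners`. At decode time,
    # `prepare_input` sets `use_cuda_graph=True` only for non-prompt batches whose size has a
    # captured graph, and `step`/`steps_spec_dec` then dispatch to the matching CUDAGraphRunner.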

    def _verify_args(self) -> None:
        """Verify the input args"""
        if not isinstance(self.inference_config, InferenceConfig):
            raise TypeError("Invalid type of inference config provided.")
        if not isinstance(self.model, nn.Module):
            raise TypeError(f"the model type must be nn.Module, but got {type(self.model)}")
        if not isinstance(self.tokenizer, (PreTrainedTokenizerFast, PreTrainedTokenizer)):
            raise TypeError(
                f"the tokenizer type must be PreTrainedTokenizer or PreTrainedTokenizerFast, but got {type(self.tokenizer)}"
            )
        if isinstance(self.model, ModelWrapper):
            model = self.model.module
            assert (
                model.__class__.__name__ in _supported_models.keys()
            ), f"Model {self.model.__class__.__name__} is not supported."

    def _shardformer(
        self,
        model: nn.Module,
        model_policy: Policy,
        stage_manager: PipelineStageManager = None,
        tp_group: ProcessGroupMesh = None,
    ) -> nn.Module:
        """
        Initialize ShardConfig and replace the model with shardformer.

        Args:
            model (nn.Module): The model to be sharded.
            model_policy (Policy): The policy to shardformer model which is determined by the model type.
            stage_manager (PipelineStageManager, optional): Used to manage pipeline stages. Defaults to None.
            tp_group (ProcessGroupMesh, optional): Used to manage the process TP group mesh. Defaults to None.

        Returns:
            nn.Module: The model optimized by Shardformer.
        """

        shardconfig = ShardConfig(
            tensor_parallel_process_group=tp_group,
            pipeline_stage_manager=stage_manager,
            enable_tensor_parallelism=(self.inference_config.tp_size > 1),
            enable_fused_normalization=False,
            enable_all_optimization=False,
            enable_flash_attention=False,
            enable_jit_fused=False,
            enable_sequence_parallelism=False,
        )
        shardformer = ShardFormer(shard_config=shardconfig)
        shard_model, _ = shardformer.optimize(model, model_policy)
        return shard_model
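
    # NOTE Only tensor parallelism is enabled above (and only when `tp_size > 1`); the other
    # ShardFormer optimizations (fused normalization, flash attention, JIT fusion, sequence
    # parallelism) are left disabled for this inference path.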

    def enable_spec_dec(
        self,
        drafter_model: nn.Module = None,
        n_spec_tokens: int = None,
        use_glide_drafter: bool = False,
    ) -> None:
        """Initialize the drafter (if it has not been yet), and enable Speculative Decoding for subsequent generations.

        Args:
            drafter_model (nn.Module): The drafter model (small model) used to speculate tokens.
                If provided, the previous drafter and drafter model, if they exist, will be overwritten.
            n_spec_tokens (Optional[int]): The number of tokens to speculate in each round of speculating-verifying.
                If not provided, `max_n_spec_tokens` in InferenceConfig will be used.
            use_glide_drafter (bool): Whether to use a GLIDE model for speculative decoding. Defaults to False.
                If True, the drafter model will be replaced by a GLIDE model.

        ```python
        ...
        engine = InferenceEngine(model, tokenizer, inference_config)

        engine.enable_spec_dec(drafter_model, n_spec_tokens=5)
        engine.generate(...)  # Speculative Decoding

        engine.disable_spec_dec()
        engine.generate(...)  # Normal generation

        engine.enable_spec_dec()
        engine.generate(...)  # Speculative Decoding using the previously set drafter model and number of spec tokens
        engine.clear_spec_dec()
        ```
        """
        if drafter_model is None and self.drafter is None:
            raise ValueError("Drafter not initialized. Please provide a Drafter Model")
        if n_spec_tokens is not None:
            assert 1 < n_spec_tokens <= self.inference_config.max_n_spec_tokens
            self.n_spec_tokens = n_spec_tokens
        if drafter_model is not None:
            assert isinstance(drafter_model, nn.Module)
            # overwrite the drafter, if it exists
            self.clear_spec_dec()
            self.drafter_model = drafter_model
            self.drafter = Drafter(
                self.drafter_model,
                self.tokenizer,
                device=self.device,
                dtype=self.dtype,
            )

        # check if the provided drafter model is compatible with the GLIDE structure
        # when `use_glide_drafter` is set to True
        if (
            use_glide_drafter
            and hasattr(drafter_model, "model")
            and hasattr(drafter_model.model, "layers")
            and hasattr(drafter_model.model.layers[0], "cross_attn")
        ):
            self.use_glide = use_glide_drafter
        elif use_glide_drafter:
            self.logger.warning(
                f"`use_glide_drafter` is provided as {use_glide_drafter}, "
                f"but the provided drafter model is not compatible with the GLIDE structure. "
                f"Falling back to the default (non-GLIDE) drafter model."
            )
        self.request_handler.set_spec_dec_mode(self.n_spec_tokens)
        # use speculative decoding for subsequent generations
        self.use_spec_dec = True

    def disable_spec_dec(self) -> None:
        """Disable using speculative decoding for subsequent generations."""
        self.request_handler.unset_spec_dec_mode()
        # set back to the maximum number of tokens to speculate
        self.n_spec_tokens = self.inference_config.max_n_spec_tokens
        self.use_glide = False
        self.use_spec_dec = False

    def clear_spec_dec(self) -> None:
        """Clear related structures of speculative decoding, if they exist."""
        if self.use_spec_dec:
            self.disable_spec_dec()
        if self.drafter_model or self.drafter:
            self.drafter_model = None
            self.drafter = None
            torch.cuda.empty_cache()
        self.use_glide = False
        self.use_spec_dec = False

    def steps_spec_dec(self) -> List[Sequence]:
        """
        Run Speculative Decoding steps. This retrieves a single batch and launches inference
        with multiple rounds of speculation by a drafter model, each verified by the main model.

        Returns:
            List[Sequence]: finished sequences generated by one step.
        """
        batch = self.request_handler.schedule()  # prefill batch
        assert batch.current_batch_size == 1, "Only support bsz 1 for speculative decoding for now."

        input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)

        if input_meta_data.use_cuda_graph:
            model_executable = self.graph_runners[input_meta_data.batch_size]
        else:
            model_executable = self.model

        # 1. Prefill small model (Drafter) - fill past kv cache for drafter model
        # NOTE For glide drafter models, we won't actually apply glide during prefill stage
        drafter_out = self.drafter.speculate(input_token_ids, 1, None)
        next_token_ids_spec = drafter_out.next_tokens
        drafter_past_key_values = drafter_out.past_key_values

        # 2. Prefill main model (Verifier) - fill past kv cache for main model
        logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
        next_tokens = search_tokens(self.generation_config, logits, batch_token_ids=batch.batch_token_ids)
        # append new inputs to the batch, temporarily
        batch.append_batch_tokens(next_tokens)
        self.request_handler.allocate_batch_spec_dec(batch, 1)
        already_allocated_kv_len = batch.seq_lengths[0].item()
        input_token_ids = batch.get_1D_inputs_spec_dec(1)

        finished_sequences = self.request_handler.update()

        while True:
            # HACK Retrieve the running batch.
            #      Using RequestHandler.schedule here would re-allocate the same kv cache for the batch.
            batch = self.request_handler.running_bb  # running batch
            assert batch.current_batch_size == 1, "Only support bsz 1 for speculative decoding for now."

            # 3. Decoding - Drafter model speculates `n` tokens
            glide_input = None
            if self.use_glide:
                glide_input = GlideInput(
                    batch.get_block_table_tensor(),
                    self.k_cache[-1],  # use kv caches of the last layer
                    self.v_cache[-1],
                    batch.get_sequence_lengths(),
                )

            drafter_out = self.drafter.speculate(
                input_token_ids,
                self.n_spec_tokens,
                drafter_past_key_values,
                glide_input=glide_input,
            )
            next_token_ids_spec = drafter_out.next_tokens
            drafter_past_key_values = drafter_out.past_key_values
            drafter_spec_length = drafter_out.speculated_length

            for next_token_id_spec in next_token_ids_spec:
                self.request_handler.append_next_tokens(next_token_id_spec.unsqueeze(0))
                cur_length = batch.seq_lengths[0].item()
                if already_allocated_kv_len < cur_length:
                    self.request_handler.allocate_batch_spec_dec(batch, n=cur_length - already_allocated_kv_len)
                    already_allocated_kv_len = cur_length

            # 4. Decoding - Main model verifies `n` tokens in parallel
            if drafter_spec_length < batch.num_tokens_to_verify:
                batch.set_use_spec_dec(num_tokens_to_verify=drafter_spec_length)
            input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)
            logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)

            next_tokens = search_tokens(self.generation_config, logits, batch_token_ids=batch.batch_token_ids)

            # 5. Compare and process the results
            diff_indexes = torch.nonzero(~(next_tokens[:-1] == next_token_ids_spec))
            n_matches = drafter_spec_length if diff_indexes.size(0) == 0 else diff_indexes[0][0].item()

            # revoke appended tokens for each Sequence in the current batch
            batch.revoke_batch_tokens(drafter_spec_length - n_matches)  # revoke drafted tokens

            # append the last correct token generated by the main model
            self.request_handler.append_next_tokens(next_tokens[n_matches].unsqueeze(0))

            # trim past key values of the drafter model
            drafter_past_key_values = Drafter.trim_kv_cache(
                drafter_past_key_values, drafter_spec_length - n_matches - 1
            )

            # prepare inputs for the next round of speculation
            n = 1 if n_matches < drafter_spec_length else 2
            input_token_ids = batch.get_1D_inputs_spec_dec(n)

            self.request_handler.update_batch_finished(batch, generation_config=self.generation_config)
            finished_sequences = self.request_handler.update()
            if len(finished_sequences) > 0:
                break

        # Reset the number of speculated tokens of the batch:
        # this handles the last round of speculation, in which the number of tokens speculated
        # by the drafter may be less than the number of speculated tokens set on the engine.
        batch.set_use_spec_dec(num_tokens_to_verify=self.n_spec_tokens)

        return finished_sequences
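
    # Worked example for the accept/reject logic above (illustrative): if the drafter speculates
    # 5 tokens and the verifier's outputs first disagree at index 2, then n_matches = 2, the
    # 5 - 2 = 3 unmatched drafted tokens are revoked, and the verifier's token at index 2
    # (`next_tokens[n_matches]`) is appended as the corrected continuation.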

    def generate(
        self,
        request_ids: Union[List[int], int] = None,
        prompts: Union[List[str], str] = None,
        prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
        return_token_ids: bool = False,
        generation_config: Optional[GenerationConfig] = None,
    ) -> List[str]:
        """
        Execute the inference steps until all requests are finished.

        Args:
            prompts (Union[List[str], str], optional): Input prompts. Defaults to None.
            prompts_token_ids (List[List[int]], optional): token ids of input prompts. Defaults to None.
            request_ids (List[int], optional): The request ID. Defaults to None.
            return_token_ids (bool): Whether to return output token ids. Defaults to False.
            generation_config (GenerationConfig, optional): Huggingface GenerationConfig used for inference. Defaults to None.

        Returns:
            List[str]: Inference results returned by one generation.
        """

        gen_config_dict = generation_config.to_dict() if generation_config is not None else {}
        prompts = [prompts] if isinstance(prompts, str) else prompts
        request_ids = [request_ids] if isinstance(request_ids, int) else request_ids

        with torch.inference_mode():
            if prompts is not None or prompts_token_ids is not None:
                self.add_request(
                    request_ids=request_ids,
                    prompts=prompts,
                    prompts_token_ids=prompts_token_ids,
                    **gen_config_dict,
                )

            output_seqs_list = []
            total_tokens_list = []

            # intuition: if the user provides a generation config, it should replace the existing one.
            if generation_config is not None:
                self.generation_config = generation_config
                self.generation_config_dict = gen_config_dict

            if self.use_spec_dec:
                assert self.drafter is not None, "Drafter Model is not initialized."
                while self.request_handler.check_unfinished_seqs():
                    output_seqs_list += self.steps_spec_dec()
            else:
                while self.request_handler.check_unfinished_seqs():
                    output_seqs_list += self.step()

            output_seqs_list = sorted(output_seqs_list, key=lambda x: int(x.request_id))

            for seq in output_seqs_list:
                total_tokens_list.append(seq.input_token_id + seq.output_token_id)

            output_str = self.tokenizer.batch_decode(total_tokens_list, skip_special_tokens=True)

            if return_token_ids:
                output_tokens_list = [seq.output_token_id for seq in output_seqs_list]
                return output_str, output_tokens_list
            else:
                return output_str
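
    # Usage sketch (illustrative): the decoded strings and the generated token ids can be
    # returned together, e.g.
    #   out_str, out_token_ids = engine.generate(prompts=["Hello"], return_token_ids=True)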

    @property
    def has_prompt_template(self) -> bool:
        """Whether a prompt template is set in the InferenceConfig."""
        return self.inference_config.prompt_template is not None

    def format_prompt(self, prompts: Union[List[str], str]) -> Union[List[str], str]:
        """
        Format the input prompts according to the prompt template given in the InferenceConfig.
        """
        assert (
            self.has_prompt_template
        ), "Found the prompt_template is None. Please provide a valid prompt_template in InferenceConfig."

        if isinstance(prompts, (list, tuple)):
            return [self.inference_config.prompt_template.format(input_text=prompt) for prompt in prompts]
        elif isinstance(prompts, str):
            return self.inference_config.prompt_template.format(input_text=prompts)
        else:
            raise TypeError(f"Expected the input prompt to be one of list, tuple, or str, but got {type(prompts)}.")

    def add_request(
        self,
        request_ids: Union[List[int], int] = None,
        prompts: Union[List[str], str] = None,
        prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
        **kwargs,
    ) -> None:
        """
        Add requests.

        Args:
            request_ids (List[int], optional): The request ID. Defaults to None.
            prompts (Union[List[str], str], optional): Input prompts. Defaults to None.
            prompts_token_ids (List[List[int]], optional): token ids of input prompts. Defaults to None.
        """

        # apply the prompt template to the input prompts
        if self.has_prompt_template and prompts is not None:
            prompts = self.format_prompt(prompts)

        block_size = self.inference_config.block_size

        if request_ids is not None and not isinstance(request_ids, list):
            request_ids = [request_ids]

        if prompts is not None and not isinstance(prompts, list):
            prompts = [prompts]

        if prompts_token_ids is None:
            assert prompts, "When prompts_token_ids is None, the input prompt list must be provided."
            prompts_token_ids = self.tokenizer.batch_encode_plus(prompts, padding=self.inference_config.pad_input)[
                "input_ids"
            ]

        # list of torch Tensor
        if isinstance(prompts_token_ids, list):
            if isinstance(prompts_token_ids[0], torch.Tensor):
                prompts_token_ids = [prompt_token_id.tolist() for prompt_token_id in prompts_token_ids]
        elif isinstance(prompts_token_ids, torch.Tensor) or isinstance(prompts_token_ids, np.ndarray):
            prompts_token_ids = prompts_token_ids.tolist()
        else:
            raise TypeError(
                f"The dtype of prompts_token_ids must be one of list, torch.Tensor, np.ndarray, but got {type(prompts_token_ids)}."
            )

        assert (
            len(prompts_token_ids[0]) <= self.inference_config.max_input_len
        ), f"The length of input prompts {len(prompts_token_ids[0])} must not exceed max_input_len {self.inference_config.max_input_len}."

        prompts_num = len(prompts_token_ids)

        for i in range(prompts_num):
            if request_ids:
                assert isinstance(
                    request_ids[0], int
                ), f"The request_id type must be int, but got {type(request_ids[0])}"
                assert len(request_ids) == prompts_num
                request_id = request_ids[i]
            else:
                request_id = next(self.counter)
            if prompts is None:
                prompt = None
            else:
                prompt = prompts[i]

            max_length = kwargs.get("max_length", None)
            max_new_tokens = kwargs.get("max_new_tokens", None)
            if max_length is None and max_new_tokens is None:
                max_new_tokens = self.generation_config.max_new_tokens or self.inference_config.max_output_len
            elif max_length is not None:
                max_new_tokens = max_length - len(prompts_token_ids[i])

            if not self.inference_config.enable_streamingllm:
                assert (
                    self.inference_config.max_output_len >= max_new_tokens
                ), f"max_new_tokens={max_new_tokens} must not exceed max_output_len={self.inference_config.max_output_len}."

            sequence = Sequence(
                request_id,
                prompt,
                prompts_token_ids[i],
                block_size,
                None,
                self.tokenizer.eos_token_id,
                self.tokenizer.pad_token_id,
                max_output_len=max_new_tokens,
                ignore_eos=self.inference_config.ignore_eos,
            )
            self.request_handler.add_sequence(sequence)
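
    # Worked example for the output-length resolution above (illustrative): passing
    # `max_length=128` for a 100-token prompt yields `max_new_tokens = 28`; if neither
    # `max_length` nor `max_new_tokens` is given, `generation_config.max_new_tokens` or,
    # failing that, `inference_config.max_output_len` is used.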

    def prepare_input(self, batch: BatchBucket) -> Tuple[torch.Tensor, torch.Tensor, InputMetaData]:
        input_ids = batch.get_1D_inputs()
        sequence_lengths = batch.get_sequence_lengths()

        if batch.is_prompts:
            n_tokens = sequence_lengths.sum().item()
        else:
            n_tokens = batch.current_batch_size
            if batch.use_spec_dec:
                n_tokens = batch.num_tokens_to_verify + 1
                assert n_tokens == input_ids.size(0)
                n_tokens = n_tokens * batch.current_batch_size
        output_tensor = torch.zeros(
            (n_tokens, batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device
        )

        batch_token_ids = None
        if (
            self.generation_config.repetition_penalty != 1.0
            or self.generation_config.no_repeat_ngram_size > 0
            or self.generation_config.forced_eos_token_id is not None
        ):
            batch_token_ids = batch.batch_token_ids

        # only when we have the graph for specific decoding batch size can we use the cuda graph for inference
        use_cuda_graph = False
        if self.use_cuda_graph and not batch.is_prompts and batch.current_batch_size in self.graph_runners.keys():
            use_cuda_graph = True

        input_meta_data = InputMetaData(
            block_tables=batch.get_block_table_tensor(),
            sequence_lengths=sequence_lengths,
            fd_inter_tensor=batch.fd_inter_tensor,
            batch_size=batch.current_batch_size,
            is_prompts=batch.is_prompts,
            use_cuda_kernel=self.inference_config.use_cuda_kernel,
            use_cuda_graph=use_cuda_graph,
            high_precision=self.high_precision,
            kv_seq_len=sequence_lengths.max().item(),
            head_dim=batch.head_dim,
            dtype=batch.dtype,
            use_spec_dec=batch.use_spec_dec,
            num_tokens_to_verify=batch.num_tokens_to_verify,
            batch_token_ids=batch_token_ids,
        )

        return input_ids, output_tensor, input_meta_data
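
    # NOTE Sizing of `output_tensor` above (illustrative): for a prefill batch the row count is
    # the total number of prompt tokens; for a plain decoding batch it is the batch size; with
    # speculative decoding, a batch of 1 sequence verifying 5 tokens allocates (5 + 1) * 1 = 6 rows.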

    def step(self) -> List[str]:
        """
        In each step, do the following:
            1. Run RequestHandler.schedule() and get the batch used for inference.
            2. Get the input, input info, and output placeholder from the BatchBucket.
            3. Run the model to generate the next token.
            4. Update the waiting list and running list in RequestHandler and get finished sequences.
            5. Decode and return finished sequences.

        Returns:
            List[str]: Decoded finished sequences generated by one step.
        """

        batch = self.request_handler.schedule()

        input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)

        if input_meta_data.use_cuda_graph:
            model_executable = self.graph_runners[input_meta_data.batch_size]
        else:
            model_executable = self.model

        # TODO: padding_id is used for generating attn_mask and will be removed if nopad version is supported.
        logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
        if self.inference_config.pad_input:
            logits = logits[:, -1, :]

        if self.inference_config.enable_streamingllm:
            updated_block_ids = batch.streamingllm_update_batch(
                self.inference_config.start_token_size, self.inference_config.generated_token_size
            )
            self.request_handler.streamingllm_free_block_tables(updated_block_ids)

        next_tokens = search_tokens(
            self.generation_config, logits, input_meta_data.is_prompts, batch_token_ids=input_meta_data.batch_token_ids
        )
        self.request_handler.append_next_tokens(next_tokens)
        finished_sequences = self.request_handler.update()

        return finished_sequences
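
# Minimal driver sketch (illustrative, not executed on import): instead of calling `generate`,
# callers may drive decoding manually, mirroring what `generate` does internally, e.g.
#   engine.add_request(prompts=["Hello"])
#   while engine.request_handler.check_unfinished_seqs():
#       finished_sequences = engine.step()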