import time from itertools import count from typing import Dict, List, Optional, Tuple, Type, Union import numpy as np import torch import torch.nn as nn from torch import distributed as dist from transformers import ( AutoConfig, AutoModelForCausalLM, GenerationConfig, PreTrainedTokenizer, PreTrainedTokenizerFast, ) from transformers.models.llama.modeling_llama import LlamaForCausalLM from colossalai.accelerator import get_accelerator from colossalai.cluster import ProcessGroupMesh from colossalai.inference.batch_bucket import BatchBucket from colossalai.inference.config import InferenceConfig, InputMetaData, ModelShardInferenceConfig from colossalai.inference.graph_runner import CUDAGraphRunner from colossalai.inference.modeling.policy import model_policy_map from colossalai.inference.sampler import search_tokens from colossalai.inference.spec import Drafter, GlideInput from colossalai.inference.struct import Sequence from colossalai.inference.utils import get_model_size, has_index_file from colossalai.interface import ModelWrapper from colossalai.lazy import LazyInitContext from colossalai.logging import get_dist_logger from colossalai.shardformer.policies.base_policy import Policy from .base_engine import BaseEngine from .request_handler import RequestHandler PP_AXIS, TP_AXIS = 0, 1 _supported_models = { "LlamaForCausalLM": LlamaForCausalLM, "BaichuanForCausalLM": AutoModelForCausalLM, } _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)] class LLMEngine(BaseEngine): """ InferenceEngine which manages the inference process.. Args: model_or_path (nn.Module or str): Path or nn.Module of this model. tokenizer Optional[(Union[PreTrainedTokenizer, PreTrainedTokenizerFast])]: Path of the tokenizer to use. inference_config (Optional[InferenceConfig], optional): Store the configuration information related to inference. verbose (bool): Determine whether or not to log the generation process. model_policy ("Policy"): the policy to shardformer model. It will be determined by the model type if not provided. """ def __init__( self, model_or_path: Union[nn.Module, str], tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast] = None, inference_config: InferenceConfig = None, verbose: bool = False, model_policy: Union[Policy, type[Policy]] = None, ) -> None: self.inference_config = inference_config self.dtype = inference_config.dtype self.high_precision = inference_config.high_precision self.verbose = verbose self.logger = get_dist_logger(__name__) self.model_shard_infer_config = inference_config.to_model_shard_inference_config() self.init_model(model_or_path, model_policy, self.model_shard_infer_config) self.generation_config = inference_config.to_generation_config(self.model_config) self.generation_config_dict = self.generation_config.to_dict() self.tokenizer = tokenizer self.tokenizer.pad_token = self.tokenizer.eos_token self.request_handler = RequestHandler(self.inference_config, self.model_config) self.k_cache, self.v_cache = self.request_handler.get_kvcache() # DISCUSS maybe move this into batch info? self.counter = count() self.use_cuda_graph = self.inference_config.use_cuda_graph if self.use_cuda_graph: self.graph_runners: Dict[int, CUDAGraphRunner] = {} self.graph_memory_pool = None # Set during graph capture. if verbose: self.logger.info("Colossal AI CUDA Graph Capture on") self.capture_model(self.k_cache, self.v_cache) # Model and relatable attrs of speculative decoding will be set by `enable_spec_dec` self.use_spec_dec = self.inference_config.use_spec_dec self.drafter_model = None self.drafter = None self.use_glide = False self.n_spec_tokens = self.inference_config.max_n_spec_tokens self._verify_args() def init_model( self, model_or_path: Union[nn.Module, str], model_policy: Union[Policy, Type[Policy]] = None, model_shard_infer_config: ModelShardInferenceConfig = None, ): """ Shard model or/and Load weight Args: model_or_path Union[nn.Module, str]: path to the checkpoint or model of transformer format. model_policy (Policy): the policy to replace the model. model_inference_config: the configuration for modeling initialization when inference. model_shard_infer_config (ModelShardInferenceConfig): the configuration for init of module when inference. """ pretrained_path = None if isinstance(model_or_path, str): import colossalai.interface.pretrained as pretrained_utils try: hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True, torch_dtype=self.dtype) arch = getattr(hf_config, "architectures")[0] if arch in _supported_models.keys(): if arch == "BaichuanForCausalLM": self.logger.warning( "Attention ! We use lazy init by default, which could be faster for model loading. For baichuan model, the output maybe have a slight difference with transformers" ) ctx = LazyInitContext(default_device="cuda") with ctx: model = _supported_models[arch].from_pretrained( model_or_path, trust_remote_code=True, torch_dtype=self.dtype ) pretrained_path = pretrained_utils.get_pretrained_path(model) else: # TODO(char-1ee): if the model not supported, use transformers APIs to load and generate raise ValueError(f"Model {arch} is not supported.") except Exception as e: self.logger.error( f"An exception occurred during loading model: {e}, model should be loaded by transformers\n" ) else: model = model_or_path self.model_config = model.config torch.cuda.empty_cache() init_gpu_memory = torch.cuda.mem_get_info()[0] self.device = get_accelerator().get_current_device() if self.verbose: self.logger.info(f"the device is {self.device}") model = model.to(self.dtype).eval() if self.verbose: self.logger.info( f"Before the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(model)} GB, model's device is: {model.device}" ) if model_policy is None: prefix = "nopadding" if not self.inference_config.pad_input else "padding" model_policy_key = f"{prefix}_{getattr(self.model_config, 'model_type', None)}" model_policy = model_policy_map.get(model_policy_key) if not isinstance(model_policy, Policy): try: model_policy = model_policy() except Exception as e: raise ValueError(f"Unable to instantiate model policy: {e}") assert isinstance(model_policy, Policy), f"Invalid type of model policy: {type(model_policy)}" pg_mesh = ProcessGroupMesh(self.inference_config.pp_size, self.inference_config.tp_size) tp_group = pg_mesh.get_group_along_axis(TP_AXIS) self.model = self._shardformer( model, model_policy, model_shard_infer_config, None, tp_group=tp_group, ) self.model = ModelWrapper(model).to(self.device) if self.verbose: self.logger.info( f"After the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(self.model)} GB, model's device is: {model.device}" ) if pretrained_path: from colossalai.inference.core.plugin import InferCheckpoint_io cpt_io = InferCheckpoint_io() if_has_index_file, model_index_file = has_index_file(pretrained_path) assert if_has_index_file, "the model path is invalid" cpt_io.load_model(self.model, model_index_file) free_gpu_memory, _ = torch.cuda.mem_get_info() peak_memory = init_gpu_memory - free_gpu_memory if self.verbose: self.logger.info( f"Rank [{dist.get_rank()}], Model Weight Max Occupy {peak_memory / (1024 ** 3)} GB, Model size: {get_model_size(self.model)} GB" ) @torch.inference_mode() def capture_model(self, k_cache: List[torch.Tensor], v_cache: List[torch.Tensor]): assert self.use_cuda_graph, "please turn on the cuda graph" if self.verbose: self.logger.info("Colossal AI CUDA Graph Capture begin") t_capture_begin = time.perf_counter() block_size = self.inference_config.block_size head_dim = self.model_config.hidden_size // self.model_config.num_attention_heads # Prepare dummy inputs. These will be reused for all batch sizes. max_batch_size = max(_BATCH_SIZES_TO_CAPTURE) max_context_len_to_capture = self.inference_config.max_context_len_to_capture max_num_blocks = (max_context_len_to_capture + block_size - 1) // block_size input_tokens_ids = torch.zeros(max_batch_size, dtype=torch.long).cuda() # self.graph_block_tables = np.zeros((max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), dtype=np.int32) self.graph_block_tables = np.full((max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), -1, dtype=np.int32) self.graph_block_tables[:, 0] = np.arange(max_num_blocks, max_num_blocks + max(_BATCH_SIZES_TO_CAPTURE)) self.graph_block_tables[0, :] = np.arange( 0, max_num_blocks ) # NOTE this is a hack to insure cuda grpah could capture the fixed cuda kernel grid in flash decoding, to make the first seqlen as the max_seq_len block_tables = torch.from_numpy(self.graph_block_tables).cuda() output_tensor = torch.zeros( (max_batch_size, self.model_config.num_attention_heads * head_dim), dtype=self.dtype, device=self.device ) fd_inter_tensor = self.request_handler.running_bb.fd_inter_tensor max_num_seqs = self.inference_config.max_batch_size batch_size_capture_list = [bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= max_num_seqs] sequence_lengths = torch.ones(max_batch_size, dtype=torch.int).cuda() # NOTE this is a hack to insure cuda grpah could capture the fixed cuda kernel grid in flash decoding, to make the first seqlen as the max_seq_len sequence_lengths[0] = torch.tensor( self.inference_config.max_context_len_to_capture - 1, dtype=torch.int32 ).cuda() # NOTE: Capturing the largest batch size first may help reduce the # memory usage of CUDA graph. for batch_size in reversed(batch_size_capture_list): if self.verbose: self.logger.info(f"batch size {batch_size} graph capturing") input_meta_data = InputMetaData( block_tables=block_tables[:batch_size], sequence_lengths=sequence_lengths[:batch_size], fd_inter_tensor=fd_inter_tensor, batch_size=batch_size, is_prompts=False, use_cuda_graph=True, high_precision=False, kv_seq_len=sequence_lengths[:batch_size].max().item(), head_dim=head_dim, dtype=self.dtype, ) graph_runner = CUDAGraphRunner(self.model) graph_runner.capture( input_tokens_ids[:batch_size], output_tensor[:batch_size], input_meta_data, k_caches=k_cache, v_caches=v_cache, memory_pool=self.graph_memory_pool, ) self.graph_memory_pool = graph_runner.graph.pool() self.graph_runners[batch_size] = graph_runner t_capture_end = time.perf_counter() if self.verbose: self.logger.info(f"CUDA Graph capture time: {t_capture_end - t_capture_begin} s") def _verify_args(self) -> None: """Verify the input args""" if not isinstance(self.inference_config, InferenceConfig): raise TypeError("Invalid type of inference config provided.") if not isinstance(self.model, nn.Module): raise TypeError(f"the model type must be nn.Module, but got {type(self.model)}") if not isinstance(self.tokenizer, (PreTrainedTokenizerFast, PreTrainedTokenizer)): raise TypeError( f"the tokenizer type must be PreTrainedTokenizer or PreTrainedTokenizerFast, but got {type(self.tokenizer)}" ) if isinstance(self.model, ModelWrapper): model = self.model.module assert ( model.__class__.__name__ in _supported_models.keys() ), f"Model {self.model.__class__.__name__} is not supported." def enable_spec_dec( self, drafter_model: nn.Module = None, n_spec_tokens: int = None, use_glide_drafter: bool = False, ) -> None: """Initialize drafter (if it has not yet), and enable Speculative Decoding for subsequent generations. Args: drafter_model (nn.Module): The drafter model (small model) used to speculate tokens. If provided, the previous drafter and drafter model, if exist, will be overwritten. n_spec_tokens (Optional[int]): The number of tokens to speculate in each round of speculating-verifying. If not provided, `max_n_spec_tokens` in InferenceConfig will be used. use_glide_drafter (bool): Whether to use glide model for speculative decoding. Defaults to False. If True, the drafter model will be replaced by a glide model. ```python ... engine = InferenceEngine(model, tokenizer, inference_config) engine.enable_spec_dec(drafter_model, n_spec_tokens=5) engine.generate(...) # Speculative Decoding engine.disable_spec_dec() engine.generate(...) # Normal generation engine.enable_spec_dec() engine.generate(...) # Speculative-Decoding using previously set drafter model and number of spec tokens engine.clear_spec_dec() ``` """ if drafter_model is None and self.drafter is None: raise ValueError("Drafter not initialized. Please provide a Drafter Model") if n_spec_tokens is not None: assert 1 < n_spec_tokens <= self.inference_config.max_n_spec_tokens self.n_spec_tokens = n_spec_tokens if drafter_model is not None: assert isinstance(drafter_model, nn.Module) # overwrite the drafter, if exists self.clear_spec_dec() self.drafter_model = drafter_model self.drafter = Drafter( self.drafter_model, self.tokenizer, device=self.device, dtype=self.dtype, ) # check if the provided drafter model is compatible with GLIDE structure # when `use_glide_drafter` is set to True if ( use_glide_drafter and hasattr(drafter_model, "model") and hasattr(drafter_model.model, "layers") and hasattr(drafter_model.model.layers[0], "cross_attn") ): self.use_glide = use_glide_drafter elif use_glide_drafter: self.logger.warning( f"`use_glide_drafter` is provided as {use_glide_drafter}, " f"but the provided drafter model is not compatible with GLIDE structure." f"Falling back to use the default drafter model (non-GLIDE)." ) self.request_handler.set_spec_dec_mode(self.n_spec_tokens) # using speculative decoding for subsequent generations self.use_spec_dec = True def disable_spec_dec(self) -> None: """Disable using speculative decoding for subsequent generations.""" self.request_handler.unset_spec_dec_mode() # set back to the maximum number of tokens to speculate self.n_spec_tokens = self.inference_config.max_n_spec_tokens self.use_glide = False self.use_spec_dec = False def clear_spec_dec(self) -> None: """Clear relatable structures of speculative decoding, if exist.""" if self.use_spec_dec: self.disable_spec_dec() if self.drafter_model or self.drafter: self.drafter_model = None self.drafter = None torch.cuda.empty_cache() self.use_glide = False self.use_spec_dec = False def steps_spec_dec(self) -> List[Sequence]: """ Run Speculative Decoding steps. This is like retrieving a single batch and launch inference with many steps of speculating by a drafter model as well as verifying by a main model. Returns: List[Sequence]: finished sequences generated by one step. """ batch = self.request_handler.schedule() # prefill batch assert batch.current_batch_size == 1, "Only support bsz 1 for speculative decoding for now." input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch) if input_meta_data.use_cuda_graph: model_executable = self.graph_runners[input_meta_data.batch_size] else: model_executable = self.model # 1. Prefill small model (Drafter) - fill past kv cache for drafter model # NOTE For glide drafter models, we won't actually apply glide during prefill stage drafter_out = self.drafter.speculate(input_token_ids, 1, None) next_token_ids_spec = drafter_out.next_tokens drafter_past_key_values = drafter_out.past_key_values # 2. Prefill main model (Verifier) - fill past kv cache for main model logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache) next_tokens = search_tokens(self.generation_config, logits, batch_token_ids=batch.batch_token_ids) # append new inputs to the batch, temporarily batch.append_batch_tokens(next_tokens) self.request_handler.allocate_batch_spec_dec(batch, 1) already_allocated_kv_len = batch.seq_lengths[0].item() input_token_ids = batch.get_1D_inputs_spec_dec(1) finished_sequences = self.request_handler.update() while True: # HACK Retrieve the running batch # Using RequestHandler.schedule here will re-allocate same kv cache for the batch batch = self.request_handler.running_bb # running batch assert batch.current_batch_size == 1, "Only support bsz 1 for speculative decoding for now." # 3. Decoding - Drafter model speculates `n` tokens glide_input = None if self.use_glide: glide_input = GlideInput( batch.get_block_table_tensor(), self.k_cache[-1], # use kv cahces of the last layer self.v_cache[-1], batch.get_sequence_lengths(), n_spec_tokens=self.n_spec_tokens, ) drafter_out = self.drafter.speculate( input_token_ids, self.n_spec_tokens, drafter_past_key_values, glide_input=glide_input, ) next_token_ids_spec = drafter_out.next_tokens drafter_past_key_values = drafter_out.past_key_values drafter_spec_length = drafter_out.speculated_length for next_token_id_spec in next_token_ids_spec: self.request_handler.append_next_tokens(next_token_id_spec.unsqueeze(0)) cur_length = batch.seq_lengths[0].item() if already_allocated_kv_len < cur_length: self.request_handler.allocate_batch_spec_dec(batch, n=cur_length - already_allocated_kv_len) already_allocated_kv_len = cur_length # 4. Decoding - Main model verifies `n` tokens in parallel if drafter_spec_length < batch.num_tokens_to_verify: batch.set_use_spec_dec(num_tokens_to_verify=drafter_spec_length) input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch) logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache) next_tokens = search_tokens(self.generation_config, logits, batch_token_ids=batch.batch_token_ids) # 5. Compare and process the results diff_indexes = torch.nonzero(~(next_tokens[:-1] == next_token_ids_spec)) n_matches = drafter_spec_length if diff_indexes.size(0) == 0 else diff_indexes[0][0].item() # revoke appended tokens for each Sequence in the current batch batch.revoke_batch_tokens(drafter_spec_length - n_matches) # revoke drafted tokens # append the last correct token generated by the main model self.request_handler.append_next_tokens(next_tokens[n_matches].unsqueeze(0)) # trim past key values of the drafter model drafter_past_key_values = Drafter.trim_kv_cache( drafter_past_key_values, drafter_spec_length - n_matches - 1 ) # prepare inputs for the next round of speculation n = 1 if n_matches < drafter_spec_length else 2 input_token_ids = batch.get_1D_inputs_spec_dec(n) self.request_handler.update_batch_finished(batch, generation_config=self.generation_config) finished_sequences = self.request_handler.update() if len(finished_sequences) > 0: break # Reset back the number of speculated tokens of the batch, # this is used to handle the last round of speculation, in which case the number of speculated tokens # by the drafter is less than the number of speculated tokens set to the engine. batch.set_use_spec_dec(num_tokens_to_verify=self.n_spec_tokens) return finished_sequences def generate( self, request_ids: Union[List[int], int] = None, prompts: Union[List[str], str] = None, prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None, return_token_ids: bool = False, generation_config: Optional[GenerationConfig] = None, ) -> Union[List[str], Tuple[List[str], List[List[int]]]]: """ Executing the inference step. Args: request_ids (List[int], optional): The request ID. Defaults to None. prompts (Union[List[str], optional): Input prompts. Defaults to None. prompts_token_ids (Union[List[int], torch.Tensor, np.ndarray], optional): token ids of input prompts. Defaults to None. return_token_ids (bool, optional): Whether to return output token ids. Defaults to False. generation_config (Optional[GenerationConfig], optional): Huggingface GenerationConfig used for inference. Defaults to None. Returns: Union[List[str], Tuple[List[str], List[List[int]]]]: Inference result returned by one generation. """ gen_config_dict = generation_config.to_dict() if generation_config is not None else {} prompts = [prompts] if isinstance(prompts, str) else prompts request_ids = [request_ids] if isinstance(request_ids, int) else request_ids with torch.inference_mode(): if prompts is not None or prompts_token_ids is not None: self.add_request( request_ids=request_ids, prompts=prompts, prompts_token_ids=prompts_token_ids, **gen_config_dict, ) output_seqs_list = [] total_tokens_list = [] # intuition: If user provide a generation config, we should replace the existing one. if generation_config is not None: self.generation_config = generation_config self.generation_config_dict = gen_config_dict if self.use_spec_dec: assert self.drafter is not None, "Drafter Model is not initialized." while self.request_handler.check_unfinished_reqs(): output_seqs_list += self.steps_spec_dec() else: while self.request_handler.check_unfinished_reqs(): output_seqs_list += self.step() output_seqs_list = sorted(output_seqs_list, key=lambda x: int(x.request_id)) for seq in output_seqs_list: total_tokens_list.append(seq.input_token_id + seq.output_token_id) output_str = self.tokenizer.batch_decode(total_tokens_list, skip_special_tokens=True) if return_token_ids: output_tokens_list = [seq.output_token_id for seq in output_seqs_list] return output_str, output_tokens_list else: return output_str @property def has_prompt_template(self) -> bool: """ """ return self.inference_config.prompt_template is not None def format_prompt(self, prompts: Union[List[str], str]) -> Union[List[str], str]: """ This method will format the input prompt according to the prompt template given to the InferenceConfig. """ assert ( self.has_prompt_template ), "Found the prompt_template is None. Please provide a valid prompt_template in InferenceConfig." if isinstance(prompts, (list, tuple)): return [self.inference_config.prompt_template.format(input_text=prompt) for prompt in prompts] elif isinstance(prompts, str): return self.inference_config.prompt_template.format(input_text=prompts) else: raise TypeError(f"Expected the input prompt to be one of list, tuple, or str, but got {type(prompts)}.") def add_request( self, request_ids: Union[List[int], int] = None, prompts: Union[List[str], str] = None, prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None, **kwargs, ) -> None: """ Add requests. Args: request_ids (List[int], optional): The request ID. Defaults to None. prompts (Union[List[str], optional): Input prompts. Defaults to None. prompts_token_ids (List[List[int]], optional): token ids of input prompts. Defaults to None. """ # apply the prompt template to the input prompts if self.has_prompt_template and prompts is not None: prompts = self.format_prompt(prompts) block_size = self.inference_config.block_size if request_ids is not None and not isinstance(request_ids, list): request_ids = [request_ids] if prompts is not None and not isinstance(prompts, list): prompts = [prompts] if prompts_token_ids is None: assert prompts, "When the prompts_token_ids is none, the input prompt list must be provided." prompts_token_ids = self.tokenizer.batch_encode_plus(prompts, padding=self.inference_config.pad_input)[ "input_ids" ] # list of torch Tensor if isinstance(prompts_token_ids, list): if isinstance(prompts_token_ids[0], torch.Tensor): prompts_token_ids = [prompt_token_id.tolist() for prompt_token_id in prompts_token_ids] elif isinstance(prompts_token_ids, torch.Tensor) or isinstance(prompts_token_ids, np.ndarray): prompts_token_ids = prompts_token_ids.tolist() else: raise TypeError( f"The dtype of prompts_token_ids must be one of list, torch.Tensor, np.ndarray, but got {type(prompts_token_ids)}." ) assert ( len(prompts_token_ids[0]) <= self.inference_config.max_input_len ), f"The length of input prompts {len(prompts_token_ids[0])} must be less than max_input_len {self.inference_config.max_input_len}." prompts_num = len(prompts_token_ids) for i in range(prompts_num): if request_ids: assert isinstance( request_ids[0], int ), f"The request_id type must be int, but got {type(request_ids[0])}" assert len(request_ids) == prompts_num request_id = request_ids[i] else: request_id = next(self.counter) if prompts == None: prompt = None else: prompt = prompts[i] max_length = kwargs.get("max_length", None) max_new_tokens = kwargs.get("max_new_tokens", None) if max_length is None and max_new_tokens is None: max_new_tokens = self.generation_config.max_new_tokens or self.inference_config.max_output_len elif max_length is not None: max_new_tokens = max_length - len(prompts_token_ids[i]) if not self.inference_config.enable_streamingllm: assert ( self.inference_config.max_output_len >= max_new_tokens ), f"max_new_tokens={max_new_tokens} must be less than max_output_len={self.inference_config.max_output_len}." sequence = Sequence( request_id, prompt, prompts_token_ids[i], block_size, None, self.tokenizer.eos_token_id, self.tokenizer.pad_token_id, max_output_len=max_new_tokens, ignore_eos=self.inference_config.ignore_eos, ) self.request_handler.add_sequence(sequence) def prepare_input(self, batch: BatchBucket) -> Tuple[torch.Tensor, torch.Tensor, InputMetaData]: input_ids = batch.get_1D_inputs() sequence_lengths = batch.get_sequence_lengths() if batch.is_prompts: n_tokens = sequence_lengths.sum().item() else: n_tokens = batch.current_batch_size if batch.use_spec_dec: n_tokens = batch.num_tokens_to_verify + 1 assert n_tokens == input_ids.size(0) n_tokens = n_tokens * batch.current_batch_size output_tensor = torch.zeros( (n_tokens, batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device ) batch_token_ids = None if ( self.generation_config.repetition_penalty != 1.0 or self.generation_config.no_repeat_ngram_size > 0 or self.generation_config.forced_eos_token_id is not None ): batch_token_ids = batch.batch_token_ids # only when we have the graph for specific decoding batch size can we use the cuda graph for inference use_cuda_graph = False if self.use_cuda_graph and not batch.is_prompts and batch.current_batch_size in self.graph_runners.keys(): use_cuda_graph = True input_meta_data = InputMetaData( block_tables=batch.get_block_table_tensor(), sequence_lengths=sequence_lengths, fd_inter_tensor=batch.fd_inter_tensor, batch_size=batch.current_batch_size, is_prompts=batch.is_prompts, use_cuda_kernel=self.inference_config.use_cuda_kernel, use_cuda_graph=use_cuda_graph, high_precision=self.high_precision, kv_seq_len=sequence_lengths.max().item(), head_dim=batch.head_dim, dtype=batch.dtype, use_spec_dec=batch.use_spec_dec, num_tokens_to_verify=batch.num_tokens_to_verify, batch_token_ids=batch_token_ids, ) return input_ids, output_tensor, input_meta_data def step(self) -> List[str]: """ In each step, do the follows: 1. Run RequestHandler.schedule() and get the batch used for inference. 2. Get the input, inputinfo and output placeholder from the batchbucket 3. Run model to generate the next token 4. Update waiting list and running list in RequestHandler and get finished sequences. 5. Decode and return finished sequences. Returns: List[str]: Decoded finished sequences generated by one step. """ batch = self.request_handler.schedule() input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch) if input_meta_data.use_cuda_graph: model_executable = self.graph_runners[input_meta_data.batch_size] else: model_executable = self.model # TODO: padding_id is used for generating attn_mask and will be removed if nopad version is supported. logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache) if self.inference_config.pad_input: logits = logits[:, -1, :] if self.inference_config.enable_streamingllm: updated_block_ids = batch.streamingllm_update_batch( self.inference_config.start_token_size, self.inference_config.generated_token_size ) self.request_handler.streamingllm_free_block_tables(updated_block_ids) next_tokens = search_tokens( self.generation_config, logits, input_meta_data.is_prompts, batch_token_ids=input_meta_data.batch_token_ids ) self.request_handler.append_next_tokens(next_tokens) finished_sequences = self.request_handler.update() return finished_sequences