diff --git a/colossalai/inference/README.md b/colossalai/inference/README.md
index 6131dacc3..33903f426 100644
--- a/colossalai/inference/README.md
+++ b/colossalai/inference/README.md
@@ -94,6 +94,8 @@ inference_config = InferenceConfig(
     max_batch_size=4,
     max_input_len=1024,
     max_output_len=512,
+    use_cuda_kernel=True,
+    use_cuda_graph=False,  # Turn on if you want to use CUDA Graph to accelerate inference
 )

 # Step 3: create an engine with model and config
diff --git a/colossalai/inference/config.py b/colossalai/inference/config.py
index 7b49e8f77..aad0310cb 100644
--- a/colossalai/inference/config.py
+++ b/colossalai/inference/config.py
@@ -10,11 +10,12 @@ import torch
 import torch.distributed as dist
 from transformers.generation import GenerationConfig

+from colossalai.inference.flash_decoding_utils import FDIntermTensors
+
 GibiByte = 1024**3

 logger = logging.Logger(__name__)

-
 _DTYPE_MAPPING = {
     "fp16": torch.float16,
     "bf16": torch.bfloat16,
@@ -23,13 +24,46 @@ _DTYPE_MAPPING = {

 _ALLOWED_DTYPES = [torch.float16, torch.bfloat16, torch.float32]

-
 _DEFAULT_PROMPT_TEMPLATES = {
     "llama": "[INST] <<SYS>>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n{input_text}[/INST]",
     "vicuna": "USER: {input_text}\n\nASSISTANT: ",
 }


+@dataclass
+class InputMetaData:
+    """The input info for a single step.
+
+    Args:
+        block_tables (torch.Tensor, optional): Sequences' block tables. Defaults to None.
+        sequence_lengths (torch.Tensor): A tensor containing sequence lengths.
+        fd_inter_tensor (torch.Tensor, optional): A tensor holding intermediate data for flash decoding. Defaults to None.
+        batch_size (int, optional): The current batch size. Defaults to 64.
+        is_prompts (bool, optional): Indicates whether this step runs prefill (True) or decoding (False). Defaults to False (decoding).
+        use_cuda_kernel (bool): Whether to use CUDA kernels, which are faster but may occasionally lose some precision. Defaults to False.
+        use_cuda_graph (bool, optional): Indicates whether to use the CUDA graph. Defaults to False.
+        kv_seq_len (int, optional): Key-value sequence length. Defaults to 512.
+        head_dim (int, optional): Head dimension. Defaults to 32.
+        high_precision (bool, optional): Whether to use float32 for underlying calculations of float16 data. Defaults to False.
+        dtype (torch.dtype, optional): The computation dtype of the model. Defaults to torch.float16.
+    """
+
+    block_tables: torch.Tensor = None
+    sequence_lengths: torch.Tensor = None
+    fd_inter_tensor: FDIntermTensors = None
+    batch_size: int = 64  # current_batch_size
+    is_prompts: bool = False
+    use_cuda_kernel: bool = False
+    use_cuda_graph: bool = False
+    kv_seq_len: int = 512
+    head_dim: int = 32
+    high_precision: bool = False
+    dtype: torch.dtype = torch.float16
+
+    def __repr__(self) -> str:
+        return f"InputMetaData(block_tables={self.block_tables}, sequence_lengths={self.sequence_lengths}, fd_inter_tensor={self.fd_inter_tensor}, batch_size={self.batch_size}, is_prompts={self.is_prompts}, use_cuda_graph={self.use_cuda_graph}, kv_seq_len={self.kv_seq_len}, head_dim={self.head_dim})"
+
+
 @dataclass
 class InferenceConfig:
     """The inference configuration.
@@ -55,6 +85,9 @@ class InferenceConfig:
         pp_size (int): Pipeline parallel size, defaults to 1.
         micro_batch_size (int): the micro batch size, defaults to 1. Only useful when `pp_size` > 1.
         micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages.
+        use_cuda_kernel (bool): Whether to use CUDA kernels, which are faster but may occasionally lose some precision. Defaults to False.
+        use_cuda_graph (bool): Whether to enable CUDA graph execution. If False, CUDA graph is disabled and the model always runs in eager mode. If True, decoding batches whose sizes have been captured run through the CUDA graph, and everything else falls back to eager execution.
+        max_context_len_to_capture (int): The maximum context length (per sequence) that can be captured by CUDA Graph.
         high_precision(Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False.
     """

@@ -91,7 +124,15 @@ class InferenceConfig:
     micro_batch_buffer_size: int = None
     high_precision: Optional[bool] = False

+    # cuda kernel option
+    use_cuda_kernel: bool = False
+
+    # cuda graph
+    use_cuda_graph: bool = False  # NOTE: the CUDA graph can only be used for inference when a graph has been captured for the specific decoding batch size
+    max_context_len_to_capture: int = 512
+
     def __post_init__(self):
+        self.max_context_len_to_capture = self.max_input_len + self.max_output_len
         self._verify_config()

     def _verify_config(self) -> None:
diff --git a/colossalai/inference/core/engine.py b/colossalai/inference/core/engine.py
index 4833e5b0c..a2388121b 100644
--- a/colossalai/inference/core/engine.py
+++ b/colossalai/inference/core/engine.py
@@ -1,5 +1,6 @@
+import time
 from itertools import count
-from typing import List, Optional, Union
+from typing import Dict, List, Optional, Tuple, Union

 import numpy as np
 import torch
@@ -7,7 +8,9 @@ import torch.nn as nn
 from transformers import GenerationConfig, PreTrainedTokenizer, PreTrainedTokenizerFast

 from colossalai.cluster import ProcessGroupMesh
-from colossalai.inference.config import InferenceConfig
+from colossalai.inference.batch_bucket import BatchBucket
+from colossalai.inference.config import InferenceConfig, InputMetaData
+from colossalai.inference.graph_runner import CUDAGraphRunner
 from colossalai.inference.modeling.policy import model_policy_map
 from colossalai.inference.struct import Sequence
 from colossalai.logging import get_dist_logger
@@ -25,6 +28,8 @@ _supported_models = [
     "LlamaForCausalLM",
 ]

+_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)]
+

 class InferenceEngine:

@@ -82,11 +87,93 @@ class InferenceEngine:
         self.logger = get_dist_logger(__name__)
         self.request_handler = RequestHandler(self.inference_config, self.model_config)
-        self.k_cahce, self.v_cache = self.request_handler.get_kvcache()
+        self.k_cache, self.v_cache = self.request_handler.get_kvcache()
         # DISCUSS maybe move this into batch info?

         self.counter = count()

+        self.use_cuda_graph = self.inference_config.use_cuda_graph
+        if self.use_cuda_graph:
+            self.graph_runners: Dict[int, CUDAGraphRunner] = {}
+            self.graph_memory_pool = None  # Set during graph capture.
+            if verbose:
+                self.logger.info("Colossal AI CUDA Graph Capture on")
+
+            self.capture_model(self.k_cache, self.v_cache)
+
+    @torch.inference_mode()
+    def capture_model(self, k_cache: List[torch.Tensor], v_cache: List[torch.Tensor]):
+        assert self.use_cuda_graph, "please turn on the cuda graph"
+
+        if self.verbose:
+            self.logger.info("Colossal AI CUDA Graph Capture begin")
+
+        t_capture_begin = time.perf_counter()
+
+        block_size = self.inference_config.block_size
+        head_dim = self.model_config.hidden_size // self.model_config.num_attention_heads
+
+        # Prepare dummy inputs. These will be reused for all batch sizes.
+        max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
+        max_context_len_to_capture = self.inference_config.max_context_len_to_capture
+        max_num_blocks = (max_context_len_to_capture + block_size - 1) // block_size
+        input_tokens_ids = torch.zeros(max_batch_size, dtype=torch.long).cuda()
+        # self.graph_block_tables = np.zeros((max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), dtype=np.int32)
+        self.graph_block_tables = np.full((max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), -1, dtype=np.int32)
+        self.graph_block_tables[:, 0] = np.arange(max_num_blocks, max_num_blocks + max(_BATCH_SIZES_TO_CAPTURE))
+        self.graph_block_tables[0, :] = np.arange(
+            0, max_num_blocks
+        )  # NOTE: this is a hack to ensure the CUDA graph captures the fixed CUDA kernel grid used in flash decoding, by treating the first sequence length as the max_seq_len
+        block_tables = torch.from_numpy(self.graph_block_tables).cuda()
+        output_tensor = torch.zeros(
+            (max_batch_size, self.model_config.num_attention_heads * head_dim), dtype=self.dtype, device=self.device
+        )
+        fd_inter_tensor = self.request_handler.running_bb.fd_inter_tensor
+
+        max_num_seqs = self.inference_config.max_batch_size
+        batch_size_capture_list = [bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= max_num_seqs]
+        sequence_lengths = torch.ones(max_batch_size, dtype=torch.int).cuda()
+        # NOTE: this is a hack to ensure the CUDA graph captures the fixed CUDA kernel grid used in flash decoding, by treating the first sequence length as the max_seq_len
+        sequence_lengths[0] = torch.tensor(
+            self.inference_config.max_context_len_to_capture - 1, dtype=torch.int32
+        ).cuda()
+
+        # NOTE: Capturing the largest batch size first may help reduce the
+        # memory usage of CUDA graph.
+        for batch_size in reversed(batch_size_capture_list):
+            if self.verbose:
+                self.logger.info(f"batch size {batch_size} graph capturing")
+
+            input_meta_data = InputMetaData(
+                block_tables=block_tables[:batch_size],
+                sequence_lengths=sequence_lengths[:batch_size],
+                fd_inter_tensor=fd_inter_tensor,
+                batch_size=batch_size,
+                is_prompts=False,
+                use_cuda_graph=True,
+                high_precision=False,
+                kv_seq_len=sequence_lengths[:batch_size].max().item(),
+                head_dim=head_dim,
+                dtype=self.dtype,
+            )
+
+            graph_runner = CUDAGraphRunner(self.model)
+            graph_runner.capture(
+                input_tokens_ids[:batch_size],
+                output_tensor[:batch_size],
+                input_meta_data,
+                k_caches=k_cache,
+                v_caches=v_cache,
+                memory_pool=self.graph_memory_pool,
+            )
+            self.graph_memory_pool = graph_runner.graph.pool()
+            self.graph_runners[batch_size] = graph_runner
+
+        t_capture_end = time.perf_counter()
+
+        if self.verbose:
+            self.logger.info(f"CUDA Graph capture time: {t_capture_end - t_capture_begin} s")
+
     def _verify_config(self) -> None:
         """
         Verify the input config
@@ -279,13 +366,50 @@
         )
         self.request_handler.add_sequence(sequence)

+    def prepare_input(self, batch: BatchBucket) -> Tuple[torch.Tensor, torch.Tensor, InputMetaData]:
+        input_ids = batch.get_1D_inputs()
+
+        sequence_lengths = batch.get_sequence_lengths()
+        if batch.is_prompts:
+            output_tensor = torch.zeros(
+                (sequence_lengths.sum().item(), batch.num_heads * batch.head_dim),
+                dtype=batch.dtype,
+                device=batch.device,
+            )
+        else:
+            output_tensor = torch.zeros(
+                (batch.current_batch_size, batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device
+            )
+
+        # Only when a graph has been captured for this specific decoding batch size can the CUDA graph be used for inference
+        use_cuda_graph = False
+        if self.use_cuda_graph and not batch.is_prompts and batch.current_batch_size in self.graph_runners.keys():
+            use_cuda_graph = True
+
+        input_meta_data = InputMetaData(
+            block_tables=batch.get_block_table_tensor(),
+            sequence_lengths=sequence_lengths,
+            fd_inter_tensor=batch.fd_inter_tensor,
+            batch_size=batch.current_batch_size,
+            is_prompts=batch.is_prompts,
+            use_cuda_kernel=self.inference_config.use_cuda_kernel,
+            use_cuda_graph=use_cuda_graph,
+            high_precision=self.high_precision,
+            kv_seq_len=sequence_lengths.max().item(),
+            head_dim=batch.head_dim,
+            dtype=batch.dtype,
+        )
+
+        return input_ids, output_tensor, input_meta_data
+
     def step(self) -> List[str]:
         """
         In each step, do the following:
             1. Run RequestHandler.schedule() and get the batch used for inference.
-            2. Run model to generate the next token
-            3. Update waiting list and running list in RequestHandler and get finished sequences.
-            4. Decode and return finished sequences.
+            2. Get the input token ids, input metadata, and output placeholder from the BatchBucket.
+            3. Run the model to generate the next token.
+            4. Update the waiting list and running list in RequestHandler and get finished sequences.
+            5. Decode and return finished sequences.

         Returns:
             List[str]: Decoded finished sequences generated by one step.
@@ -293,14 +417,15 @@

         batch = self.request_handler.schedule()

-        # TODO: padding_id is used for generating attn_mask and will be removed if nopad version is supported.
-        logits = self.model(
-            batch,
-            self.k_cahce,
-            self.v_cache,
-            self.high_precision,
-        )
+        input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)
+        if input_meta_data.use_cuda_graph:
+            model_executable = self.graph_runners[input_meta_data.batch_size]
+        else:
+            model_executable = self.model
+
+        # TODO: padding_id is used for generating attn_mask and will be removed if nopad version is supported.
+        logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)

         if self.inference_config.pad_input:
             logits = logits[:, -1, :]
         self.request_handler.search_tokens(self.generation_config, logits)
diff --git a/colossalai/inference/graph_runner.py b/colossalai/inference/graph_runner.py
new file mode 100644
index 000000000..e8b805574
--- /dev/null
+++ b/colossalai/inference/graph_runner.py
@@ -0,0 +1,100 @@
+from typing import Dict, List
+
+import torch
+from torch import nn
+
+from colossalai.inference.config import InputMetaData
+from colossalai.logging import get_dist_logger
+
+
+class CUDAGraphRunner:
+    def __init__(self, model: nn.Module):
+        self.model = model
+        self.graph = None
+        self.input_buffers: Dict[str, torch.Tensor] = {}
+        self.output_buffers: Dict[str, torch.Tensor] = {}
+        self.logger = get_dist_logger(__name__)
+
+    def capture(
+        self,
+        input_tokens_ids: torch.Tensor,
+        output_tensor: torch.Tensor,
+        inputmetadata: InputMetaData,
+        k_caches: List[torch.Tensor] = None,
+        v_caches: List[torch.Tensor] = None,
+        memory_pool=None,
+    ) -> None:
+        assert self.graph is None
+
+        # Run the model once first to compile and cache the kernels, avoiding stream capture errors
+        hidden_states_origin_model = self.model(
+            input_tokens_ids,
+            output_tensor,
+            inputmetadata,
+            k_caches,
+            v_caches,
+        )
+        torch.cuda.synchronize()
+
+        # Capture the graph.
+ # self.logger.info(f"begin capture model...") + self.graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(self.graph, pool=memory_pool): + hidden_states_cuda_graph = self.model( + input_tokens_ids, + output_tensor, + inputmetadata, + k_caches, + v_caches, + ) + torch.cuda.synchronize() + + # Save the input and output buffers, because replay always uses the same virtual memory space + self.input_buffers = { + "input_tokens_ids": input_tokens_ids, + "output_tensor": output_tensor, + "block_tables": inputmetadata.block_tables, + "sequence_lengths": inputmetadata.sequence_lengths, + # "fd_inter_tensor_mid_output": inputmetadata.fd_inter_tensor._mid_output, + # "fd_inter_tensor_mid_output_lse": inputmetadata.fd_inter_tensor._mid_output_lse, + "k_caches": k_caches, + "v_caches": v_caches, + } + self.output_buffers = {"logits": hidden_states_cuda_graph} + return + + def forward( + self, + input_tokens_ids: torch.Tensor, + output_tensor: torch.Tensor, + inputmetadata: InputMetaData, + k_caches: List[torch.Tensor] = None, + v_caches: List[torch.Tensor] = None, + ) -> torch.Tensor: + # Copy the input tensors to the input buffers. + self.input_buffers["input_tokens_ids"].copy_(input_tokens_ids, non_blocking=True) + self.input_buffers["output_tensor"].copy_(output_tensor, non_blocking=True) + + # for flexible block_table + self.input_buffers["block_tables"].fill_(-1) + M, N = inputmetadata.block_tables.shape + self.input_buffers["block_tables"][:M, :N].copy_(inputmetadata.block_tables, non_blocking=True) + + self.input_buffers["sequence_lengths"].copy_(inputmetadata.sequence_lengths, non_blocking=True) + + # we only have a global fd_inter_tensor so we don't need to copy them + # self.input_buffers["fd_inter_tensor_mid_output"].copy_(inputmetadata.fd_inter_tensor.mid_output, non_blocking=True) + # self.input_buffers["fd_inter_tensor_mid_output_lse"].copy_(inputmetadata.fd_inter_tensor.mid_output_lse, non_blocking=True) + + # KV caches are fixed tensors, so we don't need to copy them. + # self.input_buffers["k_caches"].copy_(k_caches, non_blocking=True) + # self.input_buffers["v_caches"].copy_(v_caches, non_blocking=True) + + # Run the graph. + self.graph.replay() + + # Return the output tensor. + return self.output_buffers["logits"] + + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) diff --git a/colossalai/inference/modeling/models/nopadding_llama.py b/colossalai/inference/modeling/models/nopadding_llama.py index 9ea79551e..37a714c83 100644 --- a/colossalai/inference/modeling/models/nopadding_llama.py +++ b/colossalai/inference/modeling/models/nopadding_llama.py @@ -13,7 +13,7 @@ from transformers.models.llama.modeling_llama import ( LlamaRMSNorm, ) -from colossalai.inference.batch_bucket import BatchBucket +from colossalai.inference.config import InputMetaData from colossalai.inference.flash_decoding_utils import FDIntermTensors from colossalai.kernel.kernel_loader import InferenceOpsLoader from colossalai.kernel.triton import ( @@ -41,11 +41,12 @@ except ImportError: def llama_causal_lm_forward( self: LlamaForCausalLM, - batch: BatchBucket, - k_caches: List[torch.Tensor], - v_caches: List[torch.Tensor], - high_precision: bool = False, -): + input_tokens_ids: torch.Tensor, + output_tensor: torch.Tensor, + inputmetadata: InputMetaData, + k_caches: List[torch.Tensor] = None, + v_caches: List[torch.Tensor] = None, +) -> torch.Tensor: """This function will replace the forward function of LlamaForCausalLM. 
     Args:
@@ -58,10 +59,13 @@ def llama_causal_lm_forward(
     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     hidden_states = llama_model_forward(
         self.model,
-        batch=batch,
+        input_tokens_ids=input_tokens_ids,
+        output_tensor=output_tensor,
+        inputmetadata=inputmetadata,
         k_caches=k_caches,
         v_caches=v_caches,
-        high_precision=high_precision,
+        use_cuda_kernel=inputmetadata.use_cuda_kernel,  # NOTE: currently the CUDA kernels for layernorm and rotary_embedding_and_cache_copy cannot pass the unit tests, while the Triton kernels can
+        high_precision=inputmetadata.high_precision,
     )
     logits = torch.mm(hidden_states, self.lm_head.weight)
     return logits
@@ -69,11 +73,14 @@ def llama_causal_lm_forward(

 def llama_model_forward(
     self: LlamaModel,
-    batch: BatchBucket,
-    k_caches: List[torch.Tensor],
-    v_caches: List[torch.Tensor],
+    input_tokens_ids: torch.Tensor,
+    output_tensor: torch.Tensor,
+    inputmetadata: InputMetaData,
+    k_caches: List[torch.Tensor] = None,
+    v_caches: List[torch.Tensor] = None,
+    use_cuda_kernel: Optional[bool] = True,
     high_precision: bool = False,
-):
+) -> torch.Tensor:
     """This function will replace the forward function of LlamaModel.

     Args:
@@ -82,36 +89,26 @@ def llama_model_forward(
         v_caches (List[torch.Tensor]): It holds the GPU memory for the value cache.
         high_precision(Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False.
     """
-    input_ids = batch.get_1D_inputs()
-    block_tables = batch.get_block_table_tensor()
-    sequence_lengths = batch.get_sequence_lengths()
-    batch_size = batch.current_batch_size
-    kv_seq_len = sequence_lengths.max().item()
-    use_cuda_kernel = True
+    block_tables = inputmetadata.block_tables
+    sequence_lengths = inputmetadata.sequence_lengths
+    batch_size = inputmetadata.batch_size
+    kv_seq_len = inputmetadata.kv_seq_len
+
     # NOTE: After testing, the performance of this configuration is relatively good. With updates
     # and optimizations to the CUDA kernel implementation, a more detailed analysis of this configuration's
     # selection should be conducted.
     if batch_size >= 32 and kv_seq_len > 512:
         use_cuda_kernel = False

-    if use_cuda_kernel and batch.dtype != torch.float32 and use_flash_attn2:
+    hidden_states = self.embed_tokens(input_tokens_ids)
+    if use_cuda_kernel and inputmetadata.dtype != torch.float32 and use_flash_attn2:
         cu_seqlens = F.pad(torch.cumsum(sequence_lengths, dim=0, dtype=torch.torch.int32), (1, 0))
     else:
         cu_seqlens = None

-    hidden_states = self.embed_tokens(input_ids)
-
-    cos_sin = get_xine_cache(sequence_lengths, self._cos_cached, self._sin_cached, batch.is_prompts)
+    cos_sin = get_xine_cache(sequence_lengths, self._cos_cached, self._sin_cached, inputmetadata.is_prompts)

-    if batch.is_prompts:
-        output_tensor = torch.zeros(
-            (sequence_lengths.sum().item(), batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device
-        )
-    else:
-        output_tensor = torch.zeros(
-            (batch_size, batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device
-        )
-    sm_scale = 1.0 / (batch.head_dim**0.5)
+    sm_scale = 1.0 / (inputmetadata.head_dim**0.5)

     norm_output = torch.empty_like(hidden_states)
     residual = None
@@ -123,10 +120,10 @@
             block_tables=block_tables,
             k_cache=k_caches[layer_id],
             v_cache=v_caches[layer_id],
+            is_prompts=inputmetadata.is_prompts,
             sequence_lengths=sequence_lengths,
             cos_sin=cos_sin,
-            fd_inter_tensor=batch.fd_inter_tensor,
-            is_prompts=batch.is_prompts,
+            fd_inter_tensor=inputmetadata.fd_inter_tensor,
             kv_seq_len=kv_seq_len,
             output_tensor=output_tensor,
             norm_output=norm_output,
@@ -136,7 +133,7 @@
             high_precision=high_precision,
         )

-    if batch.is_prompts:
+    if inputmetadata.is_prompts:
         last_token_indexs = sequence_lengths.cumsum(dim=-1)
         hidden_states = hidden_states[last_token_indexs - 1].contiguous()
         residual = residual[last_token_indexs - 1].contiguous()
diff --git a/colossalai/kernel/triton/rms_layernorm.py b/colossalai/kernel/triton/rms_layernorm.py
index dcf478561..fb3207503 100644
--- a/colossalai/kernel/triton/rms_layernorm.py
+++ b/colossalai/kernel/triton/rms_layernorm.py
@@ -1,5 +1,3 @@
-import torch
-
 try:
     import triton
     import triton.language as tl
@@ -94,7 +92,9 @@ if HAS_TRITON:

     def rms_layernorm(x, weight, eps, norm_output=None, residual=None):
         # allocate output
-        y = torch.empty_like(x) if norm_output is None else norm_output
+        y = (
+            x * 0 if norm_output is None else norm_output
+        )  # to make the operation non-functional, store y as the intermediate activation
         M, N = x.shape
         # Less than 64KB per feature: enqueue fused kernel
         MAX_FUSED_SIZE = 65536 // x.element_size()
diff --git a/tests/test_infer/test_cuda_graph.py b/tests/test_infer/test_cuda_graph.py
new file mode 100644
index 000000000..cc5f1c7a2
--- /dev/null
+++ b/tests/test_infer/test_cuda_graph.py
@@ -0,0 +1,96 @@
+import random
+
+import numpy as np
+import pytest
+import torch
+from transformers import AutoTokenizer, GenerationConfig, LlamaConfig, LlamaForCausalLM
+
+import colossalai
+from colossalai.inference.config import InferenceConfig
+from colossalai.inference.core.engine import InferenceEngine
+from colossalai.testing import rerun_if_address_is_in_use, spawn
+
+
+def setup_seed(seed):
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    np.random.seed(seed)
+    random.seed(seed)
+
+
+def check_inference_engine(use_cuda_graph=False, batch_size=32):
+    setup_seed(20)
+    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
+    model = (
+        LlamaForCausalLM(
+            LlamaConfig(
+                vocab_size=50000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=16
+            )
+        )
+        .cuda()
+        .half()
+    )
+    model = model.eval()
+
+    prompts_token_ids = []
+    for i in range(batch_size):
+        prompts_token_ids.append(
+            np.random.randint(low=0, high=100, size=random.randint(1, max(1024 // batch_size, 32))).tolist()
+        )
+
+    input_len = 1024
+    output_len = 128
+    do_sample = True
+    top_p = 0.5
+    top_k = 50
+
+    if use_cuda_graph:
+        inference_config = InferenceConfig(
+            max_batch_size=batch_size,
+            max_input_len=input_len,
+            max_output_len=output_len,
+            use_cuda_kernel=False,
+            use_cuda_graph=True,
+            block_size=16,
+        )
+    else:
+        inference_config = InferenceConfig(
+            max_batch_size=batch_size,
+            max_input_len=input_len,
+            max_output_len=output_len,
+            use_cuda_kernel=False,
+            use_cuda_graph=False,
+            block_size=16,
+        )
+
+    inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
+    assert inference_engine.generation_config.max_new_tokens == output_len
+    generation_config = GenerationConfig(do_sample=do_sample, top_p=top_p, top_k=top_k)
+    outputs = inference_engine.generate(prompts_token_ids=prompts_token_ids, generation_config=generation_config)
+
+    return outputs
+
+
+def check_output_consistency(batch_size):
+    cuda_graph_output = check_inference_engine(use_cuda_graph=True, batch_size=batch_size)
+    naive_model_output = check_inference_engine(use_cuda_graph=False, batch_size=batch_size)
+
+    for s1, s2 in zip(cuda_graph_output, naive_model_output):
+        assert s1 == s2, f"\nCUDA Graph Output: {s1}\nOrigin Output: {s2}"
+
+
+def run_dist(rank, world_size, port):
+    colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
+    check_output_consistency(32)
+    check_output_consistency(64)
+    check_output_consistency(128)
+
+
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+def test_cuda_graph_infer():
+    spawn(run_dist, 1)
+
+
+if __name__ == "__main__":
+    test_cuda_graph_infer()
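
Reviewer note (illustration only, not part of the patch): the CUDAGraphRunner introduced above follows the standard PyTorch capture/replay pattern. Below is a minimal, self-contained sketch of that pattern using only the public torch.cuda.CUDAGraph API; the toy linear model, the tensor shapes, and the run() helper are assumptions made for this example, not code from the PR.

import torch

# Toy model standing in for the inference model; shapes are arbitrary.
model = torch.nn.Linear(512, 512).cuda().half().eval()

# Static input buffer: capture records the kernels launched on this exact tensor,
# so replay must reuse the same memory and only copy fresh values into it.
static_input = torch.zeros(8, 512, dtype=torch.float16, device="cuda")

# Warm-up run outside of capture, mirroring CUDAGraphRunner.capture(), so that
# lazy kernel initialization does not happen during stream capture.
with torch.inference_mode():
    model(static_input)
torch.cuda.synchronize()

# Capture: record the kernels launched by one forward pass into a CUDA graph.
graph = torch.cuda.CUDAGraph()
with torch.inference_mode(), torch.cuda.graph(graph):
    static_output = model(static_input)


def run(new_input: torch.Tensor) -> torch.Tensor:
    # Replay path, analogous to CUDAGraphRunner.forward(): copy inputs into the
    # captured buffer, replay the recorded kernels, and read the fixed output buffer.
    static_input.copy_(new_input)
    graph.replay()
    return static_output

On top of this pattern, the engine code above keeps one CUDAGraphRunner per decoding batch size (graph_runners) and shares a single memory pool across captures (graph_memory_pool), so capturing many batch sizes does not multiply the memory footprint.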