diff --git a/colossalai/inference/config.py b/colossalai/inference/config.py
index 8dcdddf61..4e429f7b8 100644
--- a/colossalai/inference/config.py
+++ b/colossalai/inference/config.py
@@ -127,7 +127,7 @@ class InferenceConfig:
     use_cuda_kernel: bool = False
 
     # cuda_graph
-    use_cuda_graph: bool = False
+    use_cuda_graph: bool = False  # NOTE only when we have the graph for specific decoding batch size can we use the cuda graph for inference
     max_context_len_to_capture: int = 512
 
     def __post_init__(self):
diff --git a/colossalai/inference/core/engine.py b/colossalai/inference/core/engine.py
index 6b7c99300..e7bd1add7 100644
--- a/colossalai/inference/core/engine.py
+++ b/colossalai/inference/core/engine.py
@@ -101,7 +101,7 @@ class InferenceEngine:
             self.capture_model(self.k_cache, self.v_cache)
 
     @torch.inference_mode()
-    def capture_model(self, k_cache: torch.Tensor, v_cache: torch.Tensor):
+    def capture_model(self, k_cache: List[torch.Tensor], v_cache: List[torch.Tensor]):
         assert self.use_cuda_graph, "please turn on the cuda graph"
 
         if self.verbose:
@@ -395,13 +395,6 @@ class InferenceEngine:
             head_dim=batch.head_dim,
         )
 
-        # if not batch.is_prompts:
-        #     self.logger.info(f"decoding")
-        #     self.logger.info(f"input metadata is: {input_meta_data}")
-        # else:
-        #     self.logger.info(f"prefill")
-        #     self.logger.info(f"input metadata is: {input_meta_data}")
-
        return input_ids, output_tensor, input_meta_data
 
     def step(self) -> List[str]:
@@ -423,17 +416,12 @@
 
         if input_meta_data.use_cuda_graph:
             model_executable = self.graph_runners[input_meta_data.batch_size]
-            # self.logger.info("run cuda graph")
         else:
             model_executable = self.model
-            # self.logger.info("run original model")
 
         # TODO: padding_id is used for generating attn_mask and will be removed if nopad version is supported.
         logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
 
-        # logits_ = self.model(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
-        # assert torch.all(logits == logits_), f"error! not equal between origin model({logits_[-1]}) and CUDA Graph({logits[-1]})"
-
         if self.inference_config.pad_input:
             logits = logits[:, -1, :]
         self.request_handler.search_tokens(self.generation_config, logits)
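
For context on the `graph_runners` dispatch above: a CUDA graph is recorded once per decoding batch size and later replayed with fresh inputs copied into fixed buffers. The sketch below is a minimal, hypothetical illustration of that capture/replay pattern using PyTorch's `torch.cuda.CUDAGraph`; `SimpleGraphRunner` and its buffer names are invented for illustration and are not ColossalAI's actual `CUDAGraphRunner`.

```python
import torch


class SimpleGraphRunner:
    """Capture one model call into a CUDA graph and replay it for new inputs (illustrative sketch)."""

    def __init__(self, model):
        self.model = model
        self.graph = None
        self.static_input = None
        self.static_output = None

    @torch.inference_mode()
    def capture(self, example_input: torch.Tensor) -> None:
        # Warm up on a side stream so capture starts from a quiescent state.
        s = torch.cuda.Stream()
        s.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(s):
            self.model(example_input)
        torch.cuda.current_stream().wait_stream(s)

        # Record the graph; inputs and outputs must live in fixed buffers
        # whose shapes (e.g. the decoding batch size) cannot change afterwards.
        self.static_input = example_input.clone()
        self.graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(self.graph):
            self.static_output = self.model(self.static_input)

    def __call__(self, new_input: torch.Tensor) -> torch.Tensor:
        # Copy new data into the captured buffer, then replay the recorded kernels.
        self.static_input.copy_(new_input)
        self.graph.replay()
        return self.static_output
```

Because the captured buffers are shape-fixed, one runner is needed per decoding batch size, which is why `step()` indexes `self.graph_runners[input_meta_data.batch_size]` and why the added NOTE in `InferenceConfig` warns that the graph path is only taken when a graph exists for that batch size; batches without a matching captured graph fall back to the eager `self.model` branch shown in the diff.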