[fix] remove unused comment

pull/5434/head
Runyu Lu 2024-03-25 12:00:57 +08:00
parent 9fe61b4475
commit ff4998c6f3
2 changed files with 2 additions and 14 deletions

View File

@ -127,7 +127,7 @@ class InferenceConfig:
use_cuda_kernel: bool = False
# cuda_graph
use_cuda_graph: bool = False
use_cuda_graph: bool = False # NOTE only when we have the graph for specific decoding batch size can we use the cuda graph for inference
max_context_len_to_capture: int = 512
def __post_init__(self):

View File

@ -101,7 +101,7 @@ class InferenceEngine:
self.capture_model(self.k_cache, self.v_cache)
@torch.inference_mode()
def capture_model(self, k_cache: torch.Tensor, v_cache: torch.Tensor):
def capture_model(self, k_cache: List[torch.Tensor], v_cache: List[torch.Tensor]):
assert self.use_cuda_graph, "please turn on the cuda graph"
if self.verbose:
@ -395,13 +395,6 @@ class InferenceEngine:
head_dim=batch.head_dim,
)
# if not batch.is_prompts:
# self.logger.info(f"decoding")
# self.logger.info(f"input metadata is: {input_meta_data}")
# else:
# self.logger.info(f"prefill")
# self.logger.info(f"input metadata is: {input_meta_data}")
return input_ids, output_tensor, input_meta_data
def step(self) -> List[str]:
@ -423,17 +416,12 @@ class InferenceEngine:
if input_meta_data.use_cuda_graph:
model_executable = self.graph_runners[input_meta_data.batch_size]
# self.logger.info("run cuda graph")
else:
model_executable = self.model
# self.logger.info("run original model")
# TODO: padding_id is used for generating attn_mask and will be removed if nopad version is supported.
logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
# logits_ = self.model(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
# assert torch.all(logits == logits_), f"error! not equal between origin model({logits_[-1]}) and CUDA Graph({logits[-1]})"
if self.inference_config.pad_input:
logits = logits[:, -1, :]
self.request_handler.search_tokens(self.generation_config, logits)