mirror of https://github.com/hpcaitech/ColossalAI
[fix] remove unused comment
parent 9fe61b4475
commit ff4998c6f3
@@ -127,7 +127,7 @@ class InferenceConfig:
     use_cuda_kernel: bool = False
 
     # cuda_graph
-    use_cuda_graph: bool = False
+    use_cuda_graph: bool = False  # NOTE only when we have the graph for specific decoding batch size can we use the cuda graph for inference
     max_context_len_to_capture: int = 512
 
     def __post_init__(self):
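The NOTE added above points at the core constraint of CUDA graph inference: a graph records kernels with fixed tensor shapes, so it can only be replayed for the exact decoding batch size it was captured with. Below is a minimal, self-contained sketch of that capture/replay pattern in plain PyTorch; the model and shapes are placeholders, not ColossalAI's engine.

import torch

device = "cuda"
model = torch.nn.Linear(64, 64).to(device)
batch_size = 8  # the specific decoding batch size we capture for

static_input = torch.zeros(batch_size, 64, device=device)

# Warm up on a side stream before capture, as recommended by the PyTorch docs.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        model(static_input)
torch.cuda.current_stream().wait_stream(s)

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_output = model(static_input)  # output buffer is reused on every replay

# Replay: copy new data of the SAME shape into the captured input buffer.
new_batch = torch.randn(batch_size, 64, device=device)
static_input.copy_(new_batch)
graph.replay()  # static_output now holds the results for new_batch

A batch of any other size cannot reuse this graph, which is why the engine keeps one runner per captured decoding batch size.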
@@ -101,7 +101,7 @@ class InferenceEngine:
             self.capture_model(self.k_cache, self.v_cache)
 
     @torch.inference_mode()
-    def capture_model(self, k_cache: torch.Tensor, v_cache: torch.Tensor):
+    def capture_model(self, k_cache: List[torch.Tensor], v_cache: List[torch.Tensor]):
         assert self.use_cuda_graph, "please turn on the cuda graph"
 
         if self.verbose:
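The signature fix above reflects that the KV cache is passed as one tensor per transformer layer rather than a single tensor. A hedged sketch of what such per-layer cache lists might look like; the shapes are illustrative assumptions, not ColossalAI's actual cache layout.

import torch
from typing import List

# Hypothetical per-layer KV cache allocation (placeholder dimensions).
num_layers, num_blocks, num_heads, block_size, head_dim = 4, 8, 8, 16, 64

k_cache: List[torch.Tensor] = [
    torch.zeros(num_blocks, num_heads, block_size, head_dim, device="cuda")
    for _ in range(num_layers)
]
v_cache: List[torch.Tensor] = [torch.zeros_like(k) for k in k_cache]

# capture_model(k_cache, v_cache) now expects these lists, one tensor per layer.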
@@ -395,13 +395,6 @@ class InferenceEngine:
             head_dim=batch.head_dim,
         )
 
-        # if not batch.is_prompts:
-        #     self.logger.info(f"decoding")
-        #     self.logger.info(f"input metadata is: {input_meta_data}")
-        # else:
-        #     self.logger.info(f"prefill")
-        #     self.logger.info(f"input metadata is: {input_meta_data}")
-
         return input_ids, output_tensor, input_meta_data
 
     def step(self) -> List[str]:
@@ -423,17 +416,12 @@ class InferenceEngine:
 
         if input_meta_data.use_cuda_graph:
             model_executable = self.graph_runners[input_meta_data.batch_size]
-            # self.logger.info("run cuda graph")
         else:
             model_executable = self.model
-            # self.logger.info("run original model")
 
         # TODO: padding_id is used for generating attn_mask and will be removed if nopad version is supported.
         logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
 
-        # logits_ = self.model(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
-        # assert torch.all(logits == logits_), f"error! not equal between origin model({logits_[-1]}) and CUDA Graph({logits[-1]})"
-
         if self.inference_config.pad_input:
             logits = logits[:, -1, :]
         self.request_handler.search_tokens(self.generation_config, logits)
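The step() hunk above picks a captured graph runner keyed by input_meta_data.batch_size and falls back to self.model otherwise, which is exactly the constraint the new NOTE in InferenceConfig describes. A hypothetical standalone helper showing the same dispatch idea; the names and the explicit membership check are illustrative, not the engine's actual API.

from typing import Callable, Dict

# Replay a captured CUDA graph runner only when one exists for the current
# decoding batch size, otherwise fall back to the eager model.
def select_executable(
    graph_runners: Dict[int, Callable],
    eager_model: Callable,
    batch_size: int,
    use_cuda_graph: bool,
) -> Callable:
    if use_cuda_graph and batch_size in graph_runners:
        return graph_runners[batch_size]  # fixed-shape replay path
    return eager_model  # eager fallback for uncaptured batch sizes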