mirror of https://github.com/hpcaitech/ColossalAI
[fix] remove unused comment
parent
9fe61b4475
commit
ff4998c6f3
|
@ -127,7 +127,7 @@ class InferenceConfig:
|
||||||
use_cuda_kernel: bool = False
|
use_cuda_kernel: bool = False
|
||||||
|
|
||||||
# cuda_graph
|
# cuda_graph
|
||||||
use_cuda_graph: bool = False
|
use_cuda_graph: bool = False  # NOTE: the CUDA graph can be used for inference only when a graph is available for the specific decoding batch size
|
||||||
max_context_len_to_capture: int = 512
|
max_context_len_to_capture: int = 512
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
|
|
|
@ -101,7 +101,7 @@ class InferenceEngine:
|
||||||
self.capture_model(self.k_cache, self.v_cache)
|
self.capture_model(self.k_cache, self.v_cache)
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def capture_model(self, k_cache: torch.Tensor, v_cache: torch.Tensor):
|
def capture_model(self, k_cache: List[torch.Tensor], v_cache: List[torch.Tensor]):
|
||||||
assert self.use_cuda_graph, "please turn on the cuda graph"
|
assert self.use_cuda_graph, "please turn on the cuda graph"
|
||||||
|
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
|
@ -395,13 +395,6 @@ class InferenceEngine:
|
||||||
head_dim=batch.head_dim,
|
head_dim=batch.head_dim,
|
||||||
)
|
)
|
||||||
|
|
||||||
# if not batch.is_prompts:
|
|
||||||
# self.logger.info(f"decoding")
|
|
||||||
# self.logger.info(f"input metadata is: {input_meta_data}")
|
|
||||||
# else:
|
|
||||||
# self.logger.info(f"prefill")
|
|
||||||
# self.logger.info(f"input metadata is: {input_meta_data}")
|
|
||||||
|
|
||||||
return input_ids, output_tensor, input_meta_data
|
return input_ids, output_tensor, input_meta_data
|
||||||
|
|
||||||
def step(self) -> List[str]:
|
def step(self) -> List[str]:
|
||||||
|
@ -423,17 +416,12 @@ class InferenceEngine:
|
||||||
|
|
||||||
if input_meta_data.use_cuda_graph:
|
if input_meta_data.use_cuda_graph:
|
||||||
model_executable = self.graph_runners[input_meta_data.batch_size]
|
model_executable = self.graph_runners[input_meta_data.batch_size]
|
||||||
# self.logger.info("run cuda graph")
|
|
||||||
else:
|
else:
|
||||||
model_executable = self.model
|
model_executable = self.model
|
||||||
# self.logger.info("run original model")
|
|
||||||
|
|
||||||
# TODO: padding_id is used for generating attn_mask and will be removed if nopad version is supported.
|
# TODO: padding_id is used for generating attn_mask and will be removed if nopad version is supported.
|
||||||
logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
|
logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
|
||||||
|
|
||||||
# logits_ = self.model(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
|
|
||||||
# assert torch.all(logits == logits_), f"error! not equal between origin model({logits_[-1]}) and CUDA Graph({logits[-1]})"
|
|
||||||
|
|
||||||
if self.inference_config.pad_input:
|
if self.inference_config.pad_input:
|
||||||
logits = logits[:, -1, :]
|
logits = logits[:, -1, :]
|
||||||
self.request_handler.search_tokens(self.generation_config, logits)
|
self.request_handler.search_tokens(self.generation_config, logits)
|
||||||
|
|
Loading…
Reference in New Issue