[feat] add use_cuda_kernel option

pull/5434/head
Runyu Lu 2024-03-19 13:24:25 +08:00
parent 6e30248683
commit aabc9fb6aa
3 changed files with 11 additions and 2 deletions


@@ -40,6 +40,7 @@ class InputMetaData:
fd_inter_tensor (torch.Tensor, optional): A tensor representing intermediate data for flash decoding. Defaults to None.
batch_size (int, optional): The current batch size. Defaults to 64.
is_prompts (bool, optional): Indicates whether the stage is prefill or decoding. Defaults to False (decoding).
use_cuda_kernel (bool): Whether to use the CUDA kernel, which is faster but may occasionally lose some precision. Defaults to False.
use_cuda_graph (bool, optional): Indicates whether to use the CUDA graph. Defaults to False.
kv_seq_len (int, optional): Key-value sequence length. Defaults to 512.
head_dim (int, optional): Head dimension. Defaults to 32.
@@ -50,6 +51,7 @@ class InputMetaData:
fd_inter_tensor: FDIntermTensors = None
batch_size: int = 64 # current_batch_size
is_prompts: bool = False
use_cuda_kernel: bool = False
use_cuda_graph: bool = False
kv_seq_len: int = 512
head_dim: int = 32
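For orientation, a minimal sketch of the extended InputMetaData follows; it is not part of this diff, the import path is an assumption based on where the class appears to live, and in practice the inference engine builds this object itself rather than user code.

# Minimal sketch, assuming InputMetaData is importable from
# colossalai.inference.config and that the fields shown above all have
# defaults; the engine normally constructs this object internally.
from colossalai.inference.config import InputMetaData

meta = InputMetaData(
    batch_size=8,
    is_prompts=True,        # prefill stage
    use_cuda_kernel=True,   # new flag introduced by this commit
    use_cuda_graph=False,
    kv_seq_len=512,
    head_dim=32,
)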
@@ -83,6 +85,7 @@ class InferenceConfig:
pp_size (int): Pipeline parallel size, defaults to 1.
micro_batch_size (int): the micro batch size, defaults to 1. Only useful when `pp_size` > 1.
micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages.
use_cuda_kernel (bool): Whether to use the CUDA kernel, which is faster but may occasionally lose some precision. Defaults to False.
use_cuda_graph (bool): Whether to enforce CUDA graph execution. If False, CUDA graph is disabled and the model always runs in eager mode. If True, CUDA graph execution is mixed with eager execution (hybrid mode).
max_context_len_to_capture (int): max context len that could be captured by CUDA Graph, per sequence
@@ -120,6 +123,9 @@ class InferenceConfig:
micro_batch_size: int = 1
micro_batch_buffer_size: int = None
# cuda kernel option
use_cuda_kernel: bool = False
# cuda_graph
use_cuda_graph: bool = False
max_context_len_to_capture: int = 512
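On the user-facing side, enabling the option is a single keyword on InferenceConfig. The sketch below mirrors the arguments used in the test file further down; the import path is an assumption, since it is not shown in this diff.

# Sketch: turning the new option on at the config level. Keyword arguments
# mirror those in the test below; the import path is assumed.
from colossalai.inference.config import InferenceConfig

inference_config = InferenceConfig(
    max_batch_size=32,
    max_input_len=128,
    max_output_len=64,
    use_cuda_kernel=True,   # opt in to the custom CUDA kernels
    use_cuda_graph=False,
    block_size=16,
)
# The engine presumably copies this flag into InputMetaData.use_cuda_kernel,
# which the model forward reads (see the llama_model_forward change below).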


@@ -60,6 +60,7 @@ def llama_causal_lm_forward(
inputmetadata=inputmetadata,
k_caches=k_caches,
v_caches=v_caches,
use_cuda_kernel=inputmetadata.use_cuda_kernel, # NOTE: currently the CUDA kernels for layernorm and rotary_embedding_and_cache_copy do not pass the unit tests, but the Triton kernels do
)
logits = torch.mm(hidden_states, self.lm_head.weight)
return logits
@@ -72,6 +73,7 @@ def llama_model_forward(
inputmetadata: InputMetaData,
k_caches: List[torch.Tensor] = None,
v_caches: List[torch.Tensor] = None,
use_cuda_kernel: Optional[bool] = True,
) -> torch.Tensor:
"""This function will replace the forward function of LlamaModel.
@@ -84,8 +86,7 @@
sequence_lengths = inputmetadata.sequence_lengths
batch_size = inputmetadata.batch_size
kv_seq_len = inputmetadata.kv_seq_len
# use_cuda_kernel = False
use_cuda_kernel = True
# NOTE: After testing, the performance of this configuration is relatively good. As the CUDA kernel
# implementation is updated and optimized, the choice of this configuration should be revisited
# with a more detailed analysis.
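Inside the model forward, the flag acts as a per-call dispatch between the fused CUDA kernels and the Triton/eager path. The sketch below illustrates that pattern only: custom_rms_norm is a hypothetical stand-in for the real fused kernel and is stubbed with eager math so the sketch runs.

# Illustrative dispatch only; `custom_rms_norm` is a hypothetical stand-in for
# the fused CUDA kernel, not an API from this repository.
import torch

def rms_norm_eager(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Baseline (eager/Triton-style) RMSNorm used when use_cuda_kernel=False.
    variance = x.pow(2).mean(-1, keepdim=True)
    return weight * x * torch.rsqrt(variance + eps)

def custom_rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Hypothetical fused CUDA kernel: faster, but may differ slightly in precision.
    return rms_norm_eager(x, weight, eps)

def rms_norm(x: torch.Tensor, weight: torch.Tensor, use_cuda_kernel: bool) -> torch.Tensor:
    # The new flag selects the implementation per call, which is the role it
    # plays in llama_model_forward above.
    return custom_rms_norm(x, weight) if use_cuda_kernel else rms_norm_eager(x, weight)

out = rms_norm(torch.randn(2, 8, 32), torch.ones(32), use_cuda_kernel=False)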


@@ -49,6 +49,7 @@ def check_inference_engine(use_cuda_graph=False, batch_size=32):
max_batch_size=batch_size,
max_input_len=input_len,
max_output_len=output_len,
use_cuda_kernel=False,
use_cuda_graph=True,
block_size=16,
)
@@ -57,6 +58,7 @@ def check_inference_engine(use_cuda_graph=False, batch_size=32):
max_batch_size=batch_size,
max_input_len=input_len,
max_output_len=output_len,
use_cuda_kernel=False,
use_cuda_graph=False,
block_size=16,
)
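Since the CUDA path can occasionally lose precision, a natural follow-up (not part of this commit) is a parity check between the two kernel paths; the helper below is hypothetical and only sketches the comparison.

# Hypothetical parity check: run the same prompts with the flag off and on and
# compare outputs. `run_inference` stands in for a helper such as
# check_inference_engine, parameterized over use_cuda_kernel.
def assert_kernel_parity(run_inference, prompts):
    baseline = run_inference(prompts, use_cuda_kernel=False)
    candidate = run_inference(prompts, use_cuda_kernel=True)
    assert baseline == candidate, f"kernel paths diverged: {baseline} vs {candidate}"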