|
|
|
@ -84,6 +84,7 @@ def llama_model_forward(
|
|
|
|
|
sequence_lengths = inputmetadata.sequence_lengths
|
|
|
|
|
batch_size = inputmetadata.batch_size
|
|
|
|
|
kv_seq_len = inputmetadata.kv_seq_len
|
|
|
|
|
# use_cuda_kernel = False
|
|
|
|
|
use_cuda_kernel = True
|
|
|
|
|
# NOTE: After testing, the performance of this configuration is relatively good. With updates
|
|
|
|
|
# and optimizations to the CUDA kernel implementation, a more detailed analysis of this configuration's
|
|
|
|
@ -97,7 +98,7 @@ def llama_model_forward(
|
|
|
|
|
|
|
|
|
|
sm_scale = 1.0 / (inputmetadata.head_dim**0.5)
|
|
|
|
|
|
|
|
|
|
norm_output = None
|
|
|
|
|
norm_output = torch.empty_like(hidden_states)
|
|
|
|
|
residual = None
|
|
|
|
|
|
|
|
|
|
for layer_id, decoder_layer in enumerate(self.layers):
|
|
|
|
@ -122,10 +123,9 @@ def llama_model_forward(
|
|
|
|
|
last_token_indexs = sequence_lengths.cumsum(dim=-1)
|
|
|
|
|
hidden_states = hidden_states[last_token_indexs - 1].contiguous()
|
|
|
|
|
residual = residual[last_token_indexs - 1].contiguous()
|
|
|
|
|
norm_output = torch.empty_like(hidden_states) # NOTE non-functional, but cuda graph only capture decoding only
|
|
|
|
|
norm_output = torch.empty_like(hidden_states)
|
|
|
|
|
hidden_states, _ = self.norm(hidden_states, norm_output, residual, use_cuda_kernel)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return hidden_states
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -198,7 +198,8 @@ def llama_rmsnorm_forward(
|
|
|
|
|
residual: torch.Tensor = None,
|
|
|
|
|
use_cuda_kernel: bool = True,
|
|
|
|
|
):
|
|
|
|
|
if use_cuda_kernel:
|
|
|
|
|
# if use_cuda_kernel:
|
|
|
|
|
if False:
|
|
|
|
|
if residual is not None:
|
|
|
|
|
inference_ops.fused_add_rms_layernorm(hidden_states, residual, self.weight.data, self.variance_epsilon)
|
|
|
|
|
return hidden_states, residual
|
|
|
|
@ -338,7 +339,8 @@ class NopadLlamaAttention(LlamaAttention):
|
|
|
|
|
sm_scale=sm_scale,
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
if use_cuda_kernel:
|
|
|
|
|
# if use_cuda_kernel:
|
|
|
|
|
if False:
|
|
|
|
|
inference_ops.rotary_embedding_and_cache_copy(
|
|
|
|
|
query_states,
|
|
|
|
|
key_states,
|
|
|
|
|