|
|
|
@ -84,6 +84,7 @@ def llama_model_forward(
|
|
|
|
|
sequence_lengths = inputmetadata.sequence_lengths |
|
|
|
|
batch_size = inputmetadata.batch_size |
|
|
|
|
kv_seq_len = inputmetadata.kv_seq_len |
|
|
|
|
# use_cuda_kernel = False |
|
|
|
|
use_cuda_kernel = True |
|
|
|
|
# NOTE: After testing, the performance of this configuration is relatively good. With updates |
|
|
|
|
# and optimizations to the CUDA kernel implementation, a more detailed analysis of this configuration's |
|
|
|
@ -97,7 +98,7 @@ def llama_model_forward(
|
|
|
|
|
|
|
|
|
|
sm_scale = 1.0 / (inputmetadata.head_dim**0.5) |
|
|
|
|
|
|
|
|
|
norm_output = None |
|
|
|
|
norm_output = torch.empty_like(hidden_states) |
|
|
|
|
residual = None |
|
|
|
|
|
|
|
|
|
for layer_id, decoder_layer in enumerate(self.layers): |
|
|
|
@ -122,10 +123,9 @@ def llama_model_forward(
|
|
|
|
|
last_token_indexs = sequence_lengths.cumsum(dim=-1) |
|
|
|
|
hidden_states = hidden_states[last_token_indexs - 1].contiguous() |
|
|
|
|
residual = residual[last_token_indexs - 1].contiguous() |
|
|
|
|
norm_output = torch.empty_like(hidden_states) # NOTE: non-functional, but CUDA graph captures decoding only |
|
|
|
|
norm_output = torch.empty_like(hidden_states) |
|
|
|
|
hidden_states, _ = self.norm(hidden_states, norm_output, residual, use_cuda_kernel) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return hidden_states |
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -198,7 +198,8 @@ def llama_rmsnorm_forward(
|
|
|
|
|
residual: torch.Tensor = None, |
|
|
|
|
use_cuda_kernel: bool = True, |
|
|
|
|
): |
|
|
|
|
if use_cuda_kernel: |
|
|
|
|
# if use_cuda_kernel: |
|
|
|
|
if False: |
|
|
|
|
if residual is not None: |
|
|
|
|
inference_ops.fused_add_rms_layernorm(hidden_states, residual, self.weight.data, self.variance_epsilon) |
|
|
|
|
return hidden_states, residual |
|
|
|
@ -338,7 +339,8 @@ class NopadLlamaAttention(LlamaAttention):
|
|
|
|
|
sm_scale=sm_scale, |
|
|
|
|
) |
|
|
|
|
else: |
|
|
|
|
if use_cuda_kernel: |
|
|
|
|
# if use_cuda_kernel: |
|
|
|
|
if False: |
|
|
|
|
inference_ops.rotary_embedding_and_cache_copy( |
|
|
|
|
query_states, |
|
|
|
|
key_states, |
|
|
|
|