From 6e30248683c0e4ccc63d15f39f8149875cba1263 Mon Sep 17 00:00:00 2001 From: Runyu Lu Date: Thu, 14 Mar 2024 16:13:00 +0800 Subject: [PATCH] [fix] tmp for test --- .../inference/modeling/models/nopadding_llama.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/colossalai/inference/modeling/models/nopadding_llama.py b/colossalai/inference/modeling/models/nopadding_llama.py index 108b79174..29760f564 100644 --- a/colossalai/inference/modeling/models/nopadding_llama.py +++ b/colossalai/inference/modeling/models/nopadding_llama.py @@ -84,6 +84,7 @@ def llama_model_forward( sequence_lengths = inputmetadata.sequence_lengths batch_size = inputmetadata.batch_size kv_seq_len = inputmetadata.kv_seq_len + # use_cuda_kernel = False use_cuda_kernel = True # NOTE: After testing, the performance of this configuration is relatively good. With updates # and optimizations to the CUDA kernel implementation, a more detailed analysis of this configuration's @@ -97,7 +98,7 @@ def llama_model_forward( sm_scale = 1.0 / (inputmetadata.head_dim**0.5) - norm_output = None + norm_output = torch.empty_like(hidden_states) residual = None for layer_id, decoder_layer in enumerate(self.layers): @@ -122,10 +123,9 @@ def llama_model_forward( last_token_indexs = sequence_lengths.cumsum(dim=-1) hidden_states = hidden_states[last_token_indexs - 1].contiguous() residual = residual[last_token_indexs - 1].contiguous() - norm_output = torch.empty_like(hidden_states) # NOTE non-functional, but cuda graph only capture decoding only + norm_output = torch.empty_like(hidden_states) hidden_states, _ = self.norm(hidden_states, norm_output, residual, use_cuda_kernel) - return hidden_states @@ -198,7 +198,8 @@ def llama_rmsnorm_forward( residual: torch.Tensor = None, use_cuda_kernel: bool = True, ): - if use_cuda_kernel: + # if use_cuda_kernel: + if False: if residual is not None: inference_ops.fused_add_rms_layernorm(hidden_states, residual, self.weight.data, self.variance_epsilon) return hidden_states, residual @@ -338,7 +339,8 @@ class NopadLlamaAttention(LlamaAttention): sm_scale=sm_scale, ) else: - if use_cuda_kernel: + # if use_cuda_kernel: + if False: inference_ops.rotary_embedding_and_cache_copy( query_states, key_states,