mirror of https://github.com/hpcaitech/ColossalAI
[fix] tmp for test
parent
d02e257abd
commit
6e30248683
|
@ -84,6 +84,7 @@ def llama_model_forward(
|
|||
sequence_lengths = inputmetadata.sequence_lengths
|
||||
batch_size = inputmetadata.batch_size
|
||||
kv_seq_len = inputmetadata.kv_seq_len
|
||||
# use_cuda_kernel = False
|
||||
use_cuda_kernel = True
|
||||
# NOTE: After testing, the performance of this configuration is relatively good. With updates
|
||||
# and optimizations to the CUDA kernel implementation, a more detailed analysis of this configuration's
|
||||
|
@ -97,7 +98,7 @@ def llama_model_forward(
|
|||
|
||||
sm_scale = 1.0 / (inputmetadata.head_dim**0.5)
|
||||
|
||||
norm_output = None
|
||||
norm_output = torch.empty_like(hidden_states)
|
||||
residual = None
|
||||
|
||||
for layer_id, decoder_layer in enumerate(self.layers):
|
||||
|
@ -122,10 +123,9 @@ def llama_model_forward(
|
|||
last_token_indexs = sequence_lengths.cumsum(dim=-1)
|
||||
hidden_states = hidden_states[last_token_indexs - 1].contiguous()
|
||||
residual = residual[last_token_indexs - 1].contiguous()
|
||||
norm_output = torch.empty_like(hidden_states) # NOTE non-functional, but cuda graph only capture decoding only
|
||||
norm_output = torch.empty_like(hidden_states)
|
||||
hidden_states, _ = self.norm(hidden_states, norm_output, residual, use_cuda_kernel)
|
||||
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
|
@ -198,7 +198,8 @@ def llama_rmsnorm_forward(
|
|||
residual: torch.Tensor = None,
|
||||
use_cuda_kernel: bool = True,
|
||||
):
|
||||
if use_cuda_kernel:
|
||||
# if use_cuda_kernel:
|
||||
if False:
|
||||
if residual is not None:
|
||||
inference_ops.fused_add_rms_layernorm(hidden_states, residual, self.weight.data, self.variance_epsilon)
|
||||
return hidden_states, residual
|
||||
|
@ -338,7 +339,8 @@ class NopadLlamaAttention(LlamaAttention):
|
|||
sm_scale=sm_scale,
|
||||
)
|
||||
else:
|
||||
if use_cuda_kernel:
|
||||
# if use_cuda_kernel:
|
||||
if False:
|
||||
inference_ops.rotary_embedding_and_cache_copy(
|
||||
query_states,
|
||||
key_states,
|
||||
|
|
Loading…
Reference in New Issue