@@ -101,6 +101,13 @@ def llama_model_forward(
 
     if batch_size >= 32 and kv_seq_len > 512:
         use_cuda_kernel = False
 
+    # NOTE (yuanheng-zhao): for now, only triton kernels support the verification process
+    # during speculative-decoding (`q_len > 1`).
+    # We will explicitly disable `use_cuda_kernel` here when speculative-decoding is enabled.
+    if inputmetadata.use_spec_dec and use_cuda_kernel:
+        use_cuda_kernel = False
+        logger.warning("CUDA kernel is disabled for speculative-decoding.")
+
     hidden_states = self.embed_tokens(input_tokens_ids)
     cu_seqlens = None
@@ -415,6 +422,8 @@ class NopadLlamaAttention(LlamaAttention):
                 sm_scale=sm_scale,
             )
         else:
+            q_len = tokens_to_verify + 1 if is_verifier else 1
+
             if use_cuda_kernel:
                 inference_ops.rotary_embedding_and_cache_copy(
                     query_states,
@@ -429,7 +438,6 @@ class NopadLlamaAttention(LlamaAttention):
                     high_precision,
                 )
             else:
-                q_len = tokens_to_verify + 1 if is_verifier else 1
                 if is_verifier:
                     rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
                     copy_k_to_blocked_cache(
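
Taken together, the hunks above make speculative decoding fall back to the triton kernels and hoist the `q_len` computation above the kernel dispatch in the decode branch. The following is a minimal, self-contained sketch of that control flow only, not the ColossalAI implementation itself; `select_decode_path` is a hypothetical helper introduced purely for illustration, while the flag and variable names (`use_cuda_kernel`, `use_spec_dec`, `is_verifier`, `tokens_to_verify`) are taken from the diff.

import logging

logger = logging.getLogger(__name__)


def select_decode_path(use_cuda_kernel, use_spec_dec, is_verifier, tokens_to_verify):
    # Only the triton kernels handle the verification pass (q_len > 1), so the
    # fused CUDA kernels are switched off whenever speculative decoding is active.
    if use_spec_dec and use_cuda_kernel:
        use_cuda_kernel = False
        logger.warning("CUDA kernel is disabled for speculative-decoding.")

    # The verifier step processes the drafted tokens plus one bonus token;
    # ordinary decoding handles a single token per step.
    q_len = tokens_to_verify + 1 if is_verifier else 1
    return use_cuda_kernel, q_len


if __name__ == "__main__":
    print(select_decode_path(True, True, True, 4))    # (False, 5): spec-dec forces the triton path
    print(select_decode_path(True, False, False, 0))  # (True, 1): plain decoding keeps the CUDA path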