[Fix] Llama Modeling Control with Spec-Dec (#5580)

- fix ref before asgmt
- fall back to use triton kernels when using spec-dec
feat/speculative-decoding
Yuanheng 2024-04-10 11:14:04 +08:00 committed by ocd_with_naming
parent e60d430cf5
commit f8598e3ec5
1 changed files with 9 additions and 1 deletions

View File

@ -101,6 +101,13 @@ def llama_model_forward(
if batch_size >= 32 and kv_seq_len > 512:
use_cuda_kernel = False
# NOTE (yuanheng-zhao): fow now, only triton kernels support verification process
# during speculative-decoding (`q_len > 1`)
# We will expicitly disable `use_cuda_kernel` here when speculative-decoding is enabled
if inputmetadata.use_spec_dec and use_cuda_kernel:
use_cuda_kernel = False
logger.warning("CUDA kernel is disabled for speculative-decoding.")
hidden_states = self.embed_tokens(input_tokens_ids)
cu_seqlens = None
@ -415,6 +422,8 @@ class NopadLlamaAttention(LlamaAttention):
sm_scale=sm_scale,
)
else:
q_len = tokens_to_verify + 1 if is_verifier else 1
if use_cuda_kernel:
inference_ops.rotary_embedding_and_cache_copy(
query_states,
@ -429,7 +438,6 @@ class NopadLlamaAttention(LlamaAttention):
high_precision,
)
else:
q_len = tokens_to_verify + 1 if is_verifier else 1
if is_verifier:
rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
copy_k_to_blocked_cache(