From f8598e3ec56bbe6bc6dd9fd84a1e0543adbd3073 Mon Sep 17 00:00:00 2001
From: Yuanheng
Date: Wed, 10 Apr 2024 11:14:04 +0800
Subject: [PATCH] [Fix] Llama Modeling Control with Spec-Dec (#5580)

- fix reference before assignment
- fall back to Triton kernels when using spec-dec
---
 .../inference/modeling/models/nopadding_llama.py | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/colossalai/inference/modeling/models/nopadding_llama.py b/colossalai/inference/modeling/models/nopadding_llama.py
index 1f0008b97..2b14190da 100644
--- a/colossalai/inference/modeling/models/nopadding_llama.py
+++ b/colossalai/inference/modeling/models/nopadding_llama.py
@@ -101,6 +101,13 @@ def llama_model_forward(
     if batch_size >= 32 and kv_seq_len > 512:
         use_cuda_kernel = False
 
+    # NOTE (yuanheng-zhao): for now, only triton kernels support the verification process
+    # during speculative-decoding (`q_len > 1`)
+    # We will explicitly disable `use_cuda_kernel` here when speculative-decoding is enabled
+    if inputmetadata.use_spec_dec and use_cuda_kernel:
+        use_cuda_kernel = False
+        logger.warning("CUDA kernel is disabled for speculative-decoding.")
+
     hidden_states = self.embed_tokens(input_tokens_ids)
 
     cu_seqlens = None
@@ -415,6 +422,8 @@ class NopadLlamaAttention(LlamaAttention):
                 sm_scale=sm_scale,
             )
         else:
+            q_len = tokens_to_verify + 1 if is_verifier else 1
+
             if use_cuda_kernel:
                 inference_ops.rotary_embedding_and_cache_copy(
                     query_states,
@@ -429,7 +438,6 @@ class NopadLlamaAttention(LlamaAttention):
                     high_precision,
                 )
             else:
-                q_len = tokens_to_verify + 1 if is_verifier else 1
                 if is_verifier:
                     rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
                     copy_k_to_blocked_cache(
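
Review note (not part of the patch): the plain-Python sketch below illustrates the two control-flow changes above. The variable names `use_spec_dec`, `use_cuda_kernel`, `is_verifier`, and `tokens_to_verify` mirror the diff; the helper functions and the kernel stubs are hypothetical stand-ins for the actual inference_ops / Triton calls.

# Minimal sketch of the spec-dec fallback and the hoisted q_len assignment.
import logging

logger = logging.getLogger(__name__)


def select_kernel_backend(use_cuda_kernel: bool, use_spec_dec: bool) -> bool:
    # Fall back to the Triton path when speculative decoding is enabled,
    # since only the Triton kernels support verification (q_len > 1).
    if use_spec_dec and use_cuda_kernel:
        use_cuda_kernel = False
        logger.warning("CUDA kernel is disabled for speculative-decoding.")
    return use_cuda_kernel


def decode_attention_step(use_cuda_kernel: bool, is_verifier: bool, tokens_to_verify: int) -> int:
    # Bind q_len BEFORE branching so both paths see it; previously it was
    # assigned only inside the Triton branch, so taking the CUDA branch could
    # leave later uses of q_len referenced before assignment.
    q_len = tokens_to_verify + 1 if is_verifier else 1
    if use_cuda_kernel:
        pass  # fused rotary embedding + KV-cache copy (CUDA kernel) would run here
    else:
        pass  # Triton rotary embedding + blocked KV-cache copy using n=q_len
    return q_len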