
[Fix] Llama Modeling Control with Spec-Dec (#5580)

- fix reference-before-assignment (`q_len` read before it was set)
- fall back to triton kernels when using spec-dec
Branch: feat/speculative-decoding
Authored by Yuanheng 8 months ago; committed by ocd_with_naming
Commit: f8598e3ec5
1 changed file with 9 additions and 1 deletion:
colossalai/inference/modeling/models/nopadding_llama.py
@@ -101,6 +101,13 @@ def llama_model_forward(
     if batch_size >= 32 and kv_seq_len > 512:
         use_cuda_kernel = False
+    # NOTE (yuanheng-zhao): for now, only triton kernels support the verification
+    # process during speculative decoding (`q_len > 1`).
+    # We explicitly disable `use_cuda_kernel` here when speculative decoding is enabled.
+    if inputmetadata.use_spec_dec and use_cuda_kernel:
+        use_cuda_kernel = False
+        logger.warning("CUDA kernel is disabled for speculative-decoding.")
+
     hidden_states = self.embed_tokens(input_tokens_ids)
     cu_seqlens = None
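
The hunk above falls back to the Triton path whenever speculative decoding is active, since only the triton kernels handle the verification step. A minimal standalone sketch of the same gating rule (the function name and signature are hypothetical, not part of this patch):

    import logging

    logger = logging.getLogger(__name__)

    def resolve_use_cuda_kernel(use_cuda_kernel: bool, use_spec_dec: bool,
                                batch_size: int, kv_seq_len: int) -> bool:
        # Large-batch / long-context decoding already prefers the triton path.
        if batch_size >= 32 and kv_seq_len > 512:
            use_cuda_kernel = False
        # Only triton kernels support the spec-dec verification step (q_len > 1),
        # so CUDA kernels are disabled while speculative decoding runs.
        if use_spec_dec and use_cuda_kernel:
            use_cuda_kernel = False
            logger.warning("CUDA kernel is disabled for speculative-decoding.")
        return use_cuda_kernel
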
@@ -415,6 +422,8 @@ class NopadLlamaAttention(LlamaAttention):
                 sm_scale=sm_scale,
             )
         else:
+            q_len = tokens_to_verify + 1 if is_verifier else 1
+
             if use_cuda_kernel:
                 inference_ops.rotary_embedding_and_cache_copy(
                     query_states,
@@ -429,7 +438,6 @@ class NopadLlamaAttention(LlamaAttention):
                     high_precision,
                 )
             else:
-                q_len = tokens_to_verify + 1 if is_verifier else 1
                 if is_verifier:
                     rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
                     copy_k_to_blocked_cache(
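
The second and third hunks fix the reference-before-assignment named in the commit message: `q_len` was previously assigned only in the non-CUDA (`else`) branch, so the CUDA-kernel path could read it before it was ever set. A minimal repro of the failure mode and the hoisting fix (names are illustrative, not the real module API):

    def before_fix(use_cuda_kernel: bool, is_verifier: bool, tokens_to_verify: int) -> int:
        if use_cuda_kernel:
            pass  # fused CUDA rotary-embedding + cache-copy path; q_len never assigned
        else:
            q_len = tokens_to_verify + 1 if is_verifier else 1
        return q_len  # UnboundLocalError when use_cuda_kernel is True

    def after_fix(use_cuda_kernel: bool, is_verifier: bool, tokens_to_verify: int) -> int:
        # Fix: hoist the assignment above the branch so both paths see it.
        q_len = tokens_to_verify + 1 if is_verifier else 1
        if use_cuda_kernel:
            pass  # fused CUDA path
        else:
            pass  # triton path
        return q_len
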
