mirror of https://github.com/hpcaitech/ColossalAI
[Fix] Llama Modeling Control with Spec-Dec (#5580)
- fix reference before assignment
- fall back to triton kernels when using spec-dec

branch: feat/speculative-decoding
parent e60d430cf5
commit f8598e3ec5
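For context, the "reference before assignment" bullet refers to q_len being assigned only on the non-CUDA branch of NopadLlamaAttention's decoding path but read on both. A minimal sketch of the bug pattern and the fix; the helper names and scaffolding are illustrative, and only q_len, tokens_to_verify, and is_verifier come from the diff below.

# Hypothetical, simplified reproduction of the pattern this commit fixes.
# Only `q_len`, `tokens_to_verify`, and `is_verifier` mirror the real code.

def attention_step_buggy(use_cuda_kernel: bool, is_verifier: bool, tokens_to_verify: int) -> int:
    if use_cuda_kernel:
        pass  # fused rotary-embedding + cache-copy path; q_len never assigned here
    else:
        q_len = tokens_to_verify + 1 if is_verifier else 1
    return q_len  # UnboundLocalError whenever use_cuda_kernel is True


def attention_step_fixed(use_cuda_kernel: bool, is_verifier: bool, tokens_to_verify: int) -> int:
    # Fix: hoist the assignment above the branch so both paths see it.
    q_len = tokens_to_verify + 1 if is_verifier else 1
    if use_cuda_kernel:
        pass  # fused CUDA path
    else:
        pass  # triton path
    return q_len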
@@ -101,6 +101,13 @@ def llama_model_forward(
     if batch_size >= 32 and kv_seq_len > 512:
         use_cuda_kernel = False
 
+    # NOTE (yuanheng-zhao): for now, only triton kernels support the verification process
+    # during speculative-decoding (`q_len > 1`)
+    # We will explicitly disable `use_cuda_kernel` here when speculative-decoding is enabled
+    if inputmetadata.use_spec_dec and use_cuda_kernel:
+        use_cuda_kernel = False
+        logger.warning("CUDA kernel is disabled for speculative-decoding.")
+
     hidden_states = self.embed_tokens(input_tokens_ids)
     cu_seqlens = None
 
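The gist of the new guard: the verification step of speculative decoding runs with q_len > 1, which only the triton kernels currently handle, so the CUDA fast path is switched off before the layers run. A standalone sketch of that gating, assuming a module-level logger; the function name select_backend is illustrative, not from the repo:

import logging

logger = logging.getLogger(__name__)

def select_backend(use_cuda_kernel: bool, use_spec_dec: bool,
                   batch_size: int, kv_seq_len: int) -> bool:
    # Existing heuristic: large decoding batches with long KV caches
    # already fall back to the triton kernels.
    if batch_size >= 32 and kv_seq_len > 512:
        use_cuda_kernel = False
    # New guard from this commit: speculative decoding verifies
    # q_len > 1 tokens per step, which the CUDA kernels do not support yet.
    if use_spec_dec and use_cuda_kernel:
        use_cuda_kernel = False
        logger.warning("CUDA kernel is disabled for speculative-decoding.")
    return use_cuda_kernel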
@@ -415,6 +422,8 @@ class NopadLlamaAttention(LlamaAttention):
                 sm_scale=sm_scale,
             )
         else:
+            q_len = tokens_to_verify + 1 if is_verifier else 1
+
             if use_cuda_kernel:
                 inference_ops.rotary_embedding_and_cache_copy(
                     query_states,
@@ -429,7 +438,6 @@ class NopadLlamaAttention(LlamaAttention):
                     high_precision,
                 )
             else:
-                q_len = tokens_to_verify + 1 if is_verifier else 1
                 if is_verifier:
                     rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
                     copy_k_to_blocked_cache(
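The hoisted line also encodes the verifier's query length: on a verification step the main model scores every drafted token plus one extra position (hence tokens_to_verify + 1), while ordinary decoding processes one token per step. A toy check of that arithmetic; compute_q_len is an illustrative name, not a function in the repo:

def compute_q_len(is_verifier: bool, tokens_to_verify: int) -> int:
    # Verifier processes all drafted tokens plus one bonus position;
    # regular decoding handles a single new token per step.
    return tokens_to_verify + 1 if is_verifier else 1

assert compute_q_len(is_verifier=True, tokens_to_verify=4) == 5
assert compute_q_len(is_verifier=False, tokens_to_verify=4) == 1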