ColossalAI/colossalai/inference/modeling/backends/pre_attention_backend.py

147 lines
5.6 KiB
Python

from abc import ABC, abstractmethod
from colossalai.inference.config import ModelShardInferenceConfig
from colossalai.inference.modeling.backends.attention_backend import AttentionMetaData
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.kernel.triton import copy_k_to_blocked_cache, decoding_fused_rotary_embedding, rotary_embedding
class PreAttentionBackend(ABC):
@abstractmethod
def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
raise NotImplementedError
@abstractmethod
def decode(self, attn_metadata: AttentionMetaData, **kwargs):
raise NotImplementedError
class CudaPreAttentionBackend(PreAttentionBackend):
"""
CudaPreAttentionBackend handles KV cache initialization and positional encoding for CudaAttentionBackend.
"""
def __init__(self, use_flash_attn: bool):
super().__init__()
self.inference_ops = InferenceOpsLoader().load()
self.use_flash_attn = use_flash_attn
def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
if self.use_flash_attn:
if not attn_metadata.use_alibi_attn:
self.inference_ops.rotary_embedding(
attn_metadata.query_states,
attn_metadata.key_states,
kwargs.get("cos", None),
kwargs.get("sin", None),
kwargs.get("high_precision", False),
)
self.inference_ops.context_kv_cache_memcpy(
attn_metadata.key_states,
attn_metadata.value_states,
attn_metadata.k_cache,
attn_metadata.v_cache,
attn_metadata.sequence_lengths,
attn_metadata.cu_seqlens,
attn_metadata.block_tables,
attn_metadata.kv_seq_len,
)
elif not attn_metadata.use_alibi_attn:
rotary_embedding(
attn_metadata.query_states,
attn_metadata.key_states,
kwargs.get("cos", None),
kwargs.get("sin", None),
)
def decode(self, attn_metadata: AttentionMetaData, **kwargs):
if not attn_metadata.use_alibi_attn:
self.inference_ops.rotary_embedding_and_cache_copy(
attn_metadata.query_states,
attn_metadata.key_states,
attn_metadata.value_states,
kwargs.get("cos", None),
kwargs.get("sin", None),
attn_metadata.k_cache,
attn_metadata.v_cache,
attn_metadata.sequence_lengths,
attn_metadata.block_tables,
kwargs.get("high_precision", None),
)
else:
self.inference_ops.decode_kv_cache_memcpy(
attn_metadata.key_states,
attn_metadata.value_states,
attn_metadata.k_cache,
attn_metadata.v_cache,
attn_metadata.sequence_lengths,
attn_metadata.block_tables,
)
class TritonPreAttentionBackend(PreAttentionBackend):
"""
TritonPreAttentionBackend handles KV cache initialization and positional encoding for TritonAttentionBackend.
"""
def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
if not attn_metadata.use_alibi_attn:
rotary_embedding(
attn_metadata.query_states,
attn_metadata.key_states,
kwargs.get("cos", None),
kwargs.get("sin", None),
)
def decode(self, attn_metadata: AttentionMetaData, **kwargs):
if not attn_metadata.use_spec_dec and not attn_metadata.use_alibi_attn:
decoding_fused_rotary_embedding(
attn_metadata.query_states,
attn_metadata.key_states,
attn_metadata.value_states,
kwargs.get("cos", None),
kwargs.get("sin", None),
attn_metadata.k_cache,
attn_metadata.v_cache,
attn_metadata.block_tables,
attn_metadata.sequence_lengths,
)
else: # else if using speculative decoding
if not attn_metadata.use_alibi_attn:
rotary_embedding(
attn_metadata.query_states,
attn_metadata.key_states,
kwargs.get("cos", None),
kwargs.get("sin", None),
)
copy_k_to_blocked_cache(
attn_metadata.key_states,
attn_metadata.k_cache,
kv_lengths=attn_metadata.sequence_lengths,
block_tables=attn_metadata.block_tables,
n=kwargs.get("q_len", 1),
)
copy_k_to_blocked_cache(
attn_metadata.value_states,
attn_metadata.v_cache,
kv_lengths=attn_metadata.sequence_lengths,
block_tables=attn_metadata.block_tables,
n=kwargs.get("q_len", 1),
)
def get_pre_attention_backend(
model_shard_infer_config: ModelShardInferenceConfig,
) -> PreAttentionBackend:
"""
Get the backend for pre-attention computations, including potisional encoding like
RoPE and KV cache initialization. It adopt the same selection logic as attention_backend/get_attention_backend.
"""
if model_shard_infer_config.use_spec_dec:
return TritonPreAttentionBackend()
if model_shard_infer_config.use_cuda_kernel:
return CudaPreAttentionBackend(model_shard_infer_config.use_flash_attn)
return TritonPreAttentionBackend()