ColossalAI/colossalai/inference/modeling/backends/pre_attention_backend.py

from abc import ABC, abstractmethod

import torch

from colossalai.inference.modeling.backends.attention_backend import AttentionMetaData
from colossalai.inference.utils import can_use_flash_attn2
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.kernel.triton import (
    copy_k_to_blocked_cache,
    decoding_fused_rotary_embedding,
    rotary_embedding,
)
from colossalai.logging import get_dist_logger

logger = get_dist_logger(__name__)
inference_ops = InferenceOpsLoader().load()


class PreAttentionBackend(ABC):
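    """
    Interface for pre-attention backends. `prefill` and `decode` run the computations that
    come before the attention kernel itself: positional encoding (e.g. RoPE) and copying the
    key/value states into the blocked KV cache.
    """
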
    @abstractmethod
    def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
        raise NotImplementedError

    @abstractmethod
    def decode(self, attn_metadata: AttentionMetaData, **kwargs):
        raise NotImplementedError


class CudaPreAttentionBackend(PreAttentionBackend):
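    """
    CudaPreAttentionBackend uses the fused CUDA inference ops to apply RoPE and to copy
    key/value states into the blocked KV cache.
    """
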
    def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
        if not attn_metadata.use_alibi_attn:
            # Apply rotary positional embedding to the query/key states (skipped for ALiBi models).
            inference_ops.rotary_embedding(
                attn_metadata.query_states,
                attn_metadata.key_states,
                kwargs.get("cos", None),
                kwargs.get("sin", None),
                kwargs.get("high_precision", False),
            )
        # Copy the prompt's key/value states into the blocked KV cache.
        inference_ops.context_kv_cache_memcpy(
            attn_metadata.key_states,
            attn_metadata.value_states,
            attn_metadata.k_cache,
            attn_metadata.v_cache,
            attn_metadata.sequence_lengths,
            attn_metadata.cu_seqlens,
            attn_metadata.block_tables,
            attn_metadata.kv_seq_len,
        )

    def decode(self, attn_metadata: AttentionMetaData, **kwargs):
        if not attn_metadata.use_alibi_attn:
            # Fused kernel: apply RoPE and copy the new KV entries into the cache in one pass.
            inference_ops.rotary_embedding_and_cache_copy(
                attn_metadata.query_states,
                attn_metadata.key_states,
                attn_metadata.value_states,
                kwargs.get("cos", None),
                kwargs.get("sin", None),
                attn_metadata.k_cache,
                attn_metadata.v_cache,
                attn_metadata.sequence_lengths,
                attn_metadata.block_tables,
                kwargs.get("high_precision", None),
            )
        else:
            # ALiBi models skip RoPE; only copy the new KV entries into the cache.
            inference_ops.decode_kv_cache_memcpy(
                attn_metadata.key_states,
                attn_metadata.value_states,
                attn_metadata.k_cache,
                attn_metadata.v_cache,
                attn_metadata.sequence_lengths,
                attn_metadata.block_tables,
            )


class TritonPreAttentionBackend(PreAttentionBackend):
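    """
    TritonPreAttentionBackend uses Triton kernels to apply RoPE and, during decoding, to copy
    key/value states into the blocked KV cache.
    """
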
    def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
        if not attn_metadata.use_alibi_attn:
            # Apply rotary positional embedding to the query/key states (skipped for ALiBi models).
            rotary_embedding(
                attn_metadata.query_states,
                attn_metadata.key_states,
                kwargs.get("cos", None),
                kwargs.get("sin", None),
            )

    def decode(self, attn_metadata: AttentionMetaData, **kwargs):
        if not attn_metadata.use_spec_dec and not attn_metadata.use_alibi_attn:
            # Fused Triton kernel: apply RoPE and copy the new KV entries into the blocked cache
            # in one pass.
            decoding_fused_rotary_embedding(
                attn_metadata.query_states,
                attn_metadata.key_states,
                attn_metadata.value_states,
                kwargs.get("cos", None),
                kwargs.get("sin", None),
                attn_metadata.k_cache,
                attn_metadata.v_cache,
                attn_metadata.block_tables,
                attn_metadata.sequence_lengths,
            )
        else:  # speculative decoding and/or ALiBi attention
            if not attn_metadata.use_alibi_attn:
                rotary_embedding(
                    attn_metadata.query_states,
                    attn_metadata.key_states,
                    kwargs.get("cos", None),
                    kwargs.get("sin", None),
                )
            # Copy keys and values into their blocked caches with the standalone kernel;
            # `n` covers the q_len tokens produced per step under speculative decoding.
            copy_k_to_blocked_cache(
                attn_metadata.key_states,
                attn_metadata.k_cache,
                kv_lengths=attn_metadata.sequence_lengths,
                block_tables=attn_metadata.block_tables,
                n=kwargs.get("q_len", 1),
            )
            copy_k_to_blocked_cache(
                attn_metadata.value_states,
                attn_metadata.v_cache,
                kv_lengths=attn_metadata.sequence_lengths,
                block_tables=attn_metadata.block_tables,
                n=kwargs.get("q_len", 1),
            )


def get_pre_attention_backend(
    use_spec_dec: bool, use_cuda_kernel: bool, dtype: torch.dtype
) -> PreAttentionBackend:
    """
    Get the backend for pre-attention computations, including positional encoding like RoPE
    and KV cache initialization. The CUDA backend is returned only when CUDA kernels are
    enabled, flash-attn 2 is usable for the given dtype, and speculative decoding is off;
    otherwise the Triton backend is used.
    """
    use_flash_attn = can_use_flash_attn2(dtype)
    if use_cuda_kernel and use_flash_attn and not use_spec_dec:
        return CudaPreAttentionBackend()
    else:
        return TritonPreAttentionBackend()
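

# Usage sketch (illustrative only, not executed as part of this module). The names
# `attn_metadata`, `cos`, and `sin` are assumed to be prepared by the calling
# attention layer; they are not defined here.
#
#     backend = get_pre_attention_backend(
#         use_spec_dec=False, use_cuda_kernel=True, dtype=torch.float16
#     )
#     # Prefill: apply RoPE and write the prompt's KV states into the blocked cache.
#     backend.prefill(attn_metadata, cos=cos, sin=sin, high_precision=True)
#     # Decode: apply RoPE and append the new token's KV entries to the cache.
#     backend.decode(attn_metadata, cos=cos, sin=sin)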