# mirror of https://github.com/hpcaitech/ColossalAI
from abc import ABC, abstractmethod

from colossalai.inference.config import ModelShardInferenceConfig
from colossalai.inference.modeling.backends.attention_backend import AttentionMetaData
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.kernel.triton import copy_k_to_blocked_cache, decoding_fused_rotary_embedding, rotary_embedding


class PreAttentionBackend(ABC):
    @abstractmethod
    def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
        raise NotImplementedError

    @abstractmethod
    def decode(self, attn_metadata: AttentionMetaData, **kwargs):
        raise NotImplementedError

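# A minimal sketch (illustrative, not part of the library) of how a new backend
# could satisfy the interface above; `NoOpPreAttentionBackend` is a hypothetical
# name used only for this example:
#
#     class NoOpPreAttentionBackend(PreAttentionBackend):
#         def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
#             pass  # e.g. positional encoding already applied upstream
#
#         def decode(self, attn_metadata: AttentionMetaData, **kwargs):
#             pass
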

class CudaPreAttentionBackend(PreAttentionBackend):
    """
    CudaPreAttentionBackend handles KV cache initialization and positional encoding for CudaAttentionBackend.
    """

    def __init__(self, use_flash_attn: bool):
        super().__init__()
        self.inference_ops = InferenceOpsLoader().load()
        self.use_flash_attn = use_flash_attn

    def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
        if self.use_flash_attn:
            if not attn_metadata.use_alibi_attn:
                self.inference_ops.rotary_embedding(
                    attn_metadata.query_states,
                    attn_metadata.key_states,
                    kwargs.get("cos", None),
                    kwargs.get("sin", None),
                    kwargs.get("high_precision", False),
                )
            self.inference_ops.context_kv_cache_memcpy(
                attn_metadata.key_states,
                attn_metadata.value_states,
                attn_metadata.k_cache,
                attn_metadata.v_cache,
                attn_metadata.sequence_lengths,
                attn_metadata.cu_seqlens,
                attn_metadata.block_tables,
                attn_metadata.kv_seq_len,
            )
        elif not attn_metadata.use_alibi_attn:
            # without flash-attn, fall back to the Triton RoPE kernel
            rotary_embedding(
                attn_metadata.query_states,
                attn_metadata.key_states,
                kwargs.get("cos", None),
                kwargs.get("sin", None),
            )

    def decode(self, attn_metadata: AttentionMetaData, **kwargs):
        if not attn_metadata.use_alibi_attn:
            # fused op: applies RoPE and copies K/V into the blocked cache in one pass
            self.inference_ops.rotary_embedding_and_cache_copy(
                attn_metadata.query_states,
                attn_metadata.key_states,
                attn_metadata.value_states,
                kwargs.get("cos", None),
                kwargs.get("sin", None),
                attn_metadata.k_cache,
                attn_metadata.v_cache,
                attn_metadata.sequence_lengths,
                attn_metadata.block_tables,
                kwargs.get("high_precision", None),
            )
        else:
            # ALiBi models skip RoPE; only copy K/V into the blocked cache
            self.inference_ops.decode_kv_cache_memcpy(
                attn_metadata.key_states,
                attn_metadata.value_states,
                attn_metadata.k_cache,
                attn_metadata.v_cache,
                attn_metadata.sequence_lengths,
                attn_metadata.block_tables,
            )

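# Illustrative call pattern for the backend above (`attn_metadata`, `cos_cache`,
# and `sin_cache` are assumed to be prepared by the caller; the names are not
# defined in this module):
#
#     cuda_backend = CudaPreAttentionBackend(use_flash_attn=True)
#     cuda_backend.prefill(attn_metadata, cos=cos_cache, sin=sin_cache, high_precision=True)
#     cuda_backend.decode(attn_metadata, cos=cos_cache, sin=sin_cache)
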

class TritonPreAttentionBackend(PreAttentionBackend):
    """
    TritonPreAttentionBackend handles KV cache initialization and positional encoding for TritonAttentionBackend.
    """

    def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
        if not attn_metadata.use_alibi_attn:
            rotary_embedding(
                attn_metadata.query_states,
                attn_metadata.key_states,
                kwargs.get("cos", None),
                kwargs.get("sin", None),
            )

    def decode(self, attn_metadata: AttentionMetaData, **kwargs):
        if not attn_metadata.use_spec_dec and not attn_metadata.use_alibi_attn:
            # fused kernel: applies RoPE and writes K/V into the blocked cache in one pass
            decoding_fused_rotary_embedding(
                attn_metadata.query_states,
                attn_metadata.key_states,
                attn_metadata.value_states,
                kwargs.get("cos", None),
                kwargs.get("sin", None),
                attn_metadata.k_cache,
                attn_metadata.v_cache,
                attn_metadata.block_tables,
                attn_metadata.sequence_lengths,
            )
        else:  # speculative decoding or ALiBi: RoPE (if applicable) and cache copies run separately
            if not attn_metadata.use_alibi_attn:
                rotary_embedding(
                    attn_metadata.query_states,
                    attn_metadata.key_states,
                    kwargs.get("cos", None),
                    kwargs.get("sin", None),
                )
            copy_k_to_blocked_cache(
                attn_metadata.key_states,
                attn_metadata.k_cache,
                kv_lengths=attn_metadata.sequence_lengths,
                block_tables=attn_metadata.block_tables,
                n=kwargs.get("q_len", 1),
            )
            # the same copy kernel is reused for the value cache
            copy_k_to_blocked_cache(
                attn_metadata.value_states,
                attn_metadata.v_cache,
                kv_lengths=attn_metadata.sequence_lengths,
                block_tables=attn_metadata.block_tables,
                n=kwargs.get("q_len", 1),
            )


def get_pre_attention_backend(
    model_shard_infer_config: ModelShardInferenceConfig,
) -> PreAttentionBackend:
    """
    Get the backend for pre-attention computations, including positional encoding like
    RoPE and KV cache initialization. It adopts the same selection logic as attention_backend/get_attention_backend.
    """
    if model_shard_infer_config.use_spec_dec:
        return TritonPreAttentionBackend()

    if model_shard_infer_config.use_cuda_kernel:
        return CudaPreAttentionBackend(model_shard_infer_config.use_flash_attn)

    return TritonPreAttentionBackend()
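
# Usage sketch for the selector above (illustrative; the exact constructor
# arguments of ModelShardInferenceConfig are an assumption here):
#
#     config = ModelShardInferenceConfig(use_cuda_kernel=True, use_flash_attn=True)
#     backend = get_pre_attention_backend(config)   # -> CudaPreAttentionBackend
#     backend.prefill(attn_metadata, cos=cos_cache, sin=sin_cache)
#
# Note that use_spec_dec is checked before use_cuda_kernel, so speculative
# decoding always selects the Triton backend, matching the use_spec_dec branch
# in TritonPreAttentionBackend.decode.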