From f5981e808e5ef226504210753e849fe4db0f26fa Mon Sep 17 00:00:00 2001
From: char-1ee
Date: Fri, 7 Jun 2024 10:02:19 +0000
Subject: [PATCH] Remove flash attention backend

Signed-off-by: char-1ee
---
 .../modeling/backends/attention_backend.py    | 106 ++++++------------
 .../backends/pre_attention_backend.py         |  82 ++++----------
 2 files changed, 59 insertions(+), 129 deletions(-)

diff --git a/colossalai/inference/modeling/backends/attention_backend.py b/colossalai/inference/modeling/backends/attention_backend.py
index ed0ccda8a..e0a4ec33d 100644
--- a/colossalai/inference/modeling/backends/attention_backend.py
+++ b/colossalai/inference/modeling/backends/attention_backend.py
@@ -38,82 +38,51 @@ class AttentionBackend(ABC):
         raise NotImplementedError
 
 
-class FlashAttentionBackend(AttentionBackend):
-    """
-    Attention backend when use_cuda_kernel is True and flash-attn is installed. It uses
-    `flash_attn_varlen_func` for prefilling and our cuda op `flash_decoding_attention` for decoding.
-    """
-
-    def __init__(self):
-        super().__init__()
-        self.inference_ops = InferenceOpsLoader().load()
-
-    def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
-        token_nums = kwargs.get("token_nums", -1)
-
-        attn_output = flash_attn_varlen_func(
-            attn_metadata.query_states,
-            attn_metadata.key_states,
-            attn_metadata.value_states,
-            cu_seqlens_q=attn_metadata.cu_seqlens,
-            cu_seqlens_k=attn_metadata.cu_seqlens,
-            max_seqlen_q=attn_metadata.kv_seq_len,
-            max_seqlen_k=attn_metadata.kv_seq_len,
-            dropout_p=0.0,
-            softmax_scale=attn_metadata.sm_scale,
-            causal=True,
-            alibi_slopes=attn_metadata.alibi_slopes,
-        )
-        attn_output = attn_output.view(token_nums, -1)
-        return attn_output
-
-    def decode(self, attn_metadata: AttentionMetaData, **kwargs):
-        fd_inter_tensor = kwargs.get("fd_inter_tensor", None)
-        output_tensor = attn_metadata.output_tensor
-        self.inference_ops.flash_decoding_attention(
-            output_tensor,
-            attn_metadata.query_states,
-            attn_metadata.k_cache,
-            attn_metadata.v_cache,
-            attn_metadata.sequence_lengths,
-            attn_metadata.block_tables,
-            attn_metadata.block_size,
-            attn_metadata.kv_seq_len,
-            fd_inter_tensor.mid_output,
-            fd_inter_tensor.exp_sums,
-            fd_inter_tensor.max_logits,
-            attn_metadata.alibi_slopes,
-            attn_metadata.sm_scale,
-        )
-        return output_tensor
-
-
 class CudaAttentionBackend(AttentionBackend):
     """
     Attention backend when use_cuda_kernel is True but flash-attn not found. If flash-attn is not found, it uses
     Triton op `context_attention_unpadded` for prefilling and our cuda op `flash_decoding_attention` for decoding.
     """
 
-    def __init__(self):
+    def __init__(self, use_flash_attn: bool):
         super().__init__()
         self.inference_ops = InferenceOpsLoader().load()
+        self.use_flash_attn = use_flash_attn
 
     def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
-        return context_attention_unpadded(
-            q=attn_metadata.query_states,
-            k=attn_metadata.key_states,
-            v=attn_metadata.value_states,
-            k_cache=attn_metadata.k_cache,
-            v_cache=attn_metadata.v_cache,
-            context_lengths=attn_metadata.sequence_lengths,
-            block_tables=attn_metadata.block_tables,
-            block_size=attn_metadata.block_size,
-            output=attn_metadata.output_tensor,
-            alibi_slopes=attn_metadata.alibi_slopes,
-            max_seq_len=attn_metadata.kv_seq_len,
-            sm_scale=attn_metadata.sm_scale,
-            use_new_kcache_layout=True,  # use new k cache layout for cuda kernels in this triton op
-        )
+        if self.use_flash_attn:
+            token_nums = kwargs.get("token_nums", -1)
+            attn_output = flash_attn_varlen_func(
+                attn_metadata.query_states,
+                attn_metadata.key_states,
+                attn_metadata.value_states,
+                cu_seqlens_q=attn_metadata.cu_seqlens,
+                cu_seqlens_k=attn_metadata.cu_seqlens,
+                max_seqlen_q=attn_metadata.kv_seq_len,
+                max_seqlen_k=attn_metadata.kv_seq_len,
+                dropout_p=0.0,
+                softmax_scale=attn_metadata.sm_scale,
+                causal=True,
+                alibi_slopes=attn_metadata.alibi_slopes,
+            )
+            attn_output = attn_output.view(token_nums, -1)
+        else:
+            attn_output = context_attention_unpadded(
+                q=attn_metadata.query_states,
+                k=attn_metadata.key_states,
+                v=attn_metadata.value_states,
+                k_cache=attn_metadata.k_cache,
+                v_cache=attn_metadata.v_cache,
+                context_lengths=attn_metadata.sequence_lengths,
+                block_tables=attn_metadata.block_tables,
+                block_size=attn_metadata.block_size,
+                output=attn_metadata.output_tensor,
+                alibi_slopes=attn_metadata.alibi_slopes,
+                max_seq_len=attn_metadata.kv_seq_len,
+                sm_scale=attn_metadata.sm_scale,
+                use_new_kcache_layout=True,  # use new k-cache layout
+            )
+        return attn_output
 
     def decode(self, attn_metadata: AttentionMetaData, **kwargs):
         fd_inter_tensor = kwargs.get("fd_inter_tensor", None)
@@ -155,7 +124,6 @@ class TritonAttentionBackend(AttentionBackend):
             alibi_slopes=attn_metadata.alibi_slopes,
             max_seq_len=attn_metadata.kv_seq_len,
             sm_scale=attn_metadata.sm_scale,
-            use_new_kcache_layout=False,
         )
 
     def decode(self, attn_metadata: AttentionMetaData, **kwargs):
@@ -195,8 +163,6 @@ def get_attention_backend(
         return TritonAttentionBackend()
 
     if model_shard_infer_config.use_cuda_kernel:
-        if model_shard_infer_config.use_flash_attn:
-            return FlashAttentionBackend()
-        return CudaAttentionBackend()
+        return CudaAttentionBackend(model_shard_infer_config.use_flash_attn)
 
     return TritonAttentionBackend()
diff --git a/colossalai/inference/modeling/backends/pre_attention_backend.py b/colossalai/inference/modeling/backends/pre_attention_backend.py
index d8911cb23..77804429d 100644
--- a/colossalai/inference/modeling/backends/pre_attention_backend.py
+++ b/colossalai/inference/modeling/backends/pre_attention_backend.py
@@ -16,71 +16,37 @@ class PreAttentionBackend(ABC):
         raise NotImplementedError
 
 
-class FlashPreAttentionBackend(PreAttentionBackend):
-    """
-    FlashPreAttentionBackend handles KV cache initialization and positional encoding for FlashAttentionBackend.
-    """
-
-    def __init__(self):
-        super().__init__()
-        self.inference_ops = InferenceOpsLoader().load()
-
-    def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
-        if not attn_metadata.use_alibi_attn:
-            self.inference_ops.rotary_embedding(
-                attn_metadata.query_states,
-                attn_metadata.key_states,
-                kwargs.get("cos", None),
-                kwargs.get("sin", None),
-                kwargs.get("high_precision", False),
-            )
-        self.inference_ops.context_kv_cache_memcpy(
-            attn_metadata.key_states,
-            attn_metadata.value_states,
-            attn_metadata.k_cache,
-            attn_metadata.v_cache,
-            attn_metadata.sequence_lengths,
-            attn_metadata.cu_seqlens,
-            attn_metadata.block_tables,
-            attn_metadata.kv_seq_len,
-        )
-
-    def decode(self, attn_metadata: AttentionMetaData, **kwargs):
-        if not attn_metadata.use_alibi_attn:
-            self.inference_ops.rotary_embedding_and_cache_copy(
-                attn_metadata.query_states,
-                attn_metadata.key_states,
-                attn_metadata.value_states,
-                kwargs.get("cos", None),
-                kwargs.get("sin", None),
-                attn_metadata.k_cache,
-                attn_metadata.v_cache,
-                attn_metadata.sequence_lengths,
-                attn_metadata.block_tables,
-                kwargs.get("high_precision", None),
-            )
-        else:
-            self.inference_ops.decode_kv_cache_memcpy(
-                attn_metadata.key_states,
-                attn_metadata.value_states,
-                attn_metadata.k_cache,
-                attn_metadata.v_cache,
-                attn_metadata.sequence_lengths,
-                attn_metadata.block_tables,
-            )
-
-
 class CudaPreAttentionBackend(PreAttentionBackend):
     """
     CudaPreAttentionBackend handles KV cache initialization and positional encoding for CudaAttentionBackend.
     """
 
-    def __init__(self):
+    def __init__(self, use_flash_attn: bool):
         super().__init__()
         self.inference_ops = InferenceOpsLoader().load()
+        self.use_flash_attn = use_flash_attn
 
     def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
-        if not attn_metadata.use_alibi_attn:
+        if self.use_flash_attn:
+            if not attn_metadata.use_alibi_attn:
+                self.inference_ops.rotary_embedding(
+                    attn_metadata.query_states,
+                    attn_metadata.key_states,
+                    kwargs.get("cos", None),
+                    kwargs.get("sin", None),
+                    kwargs.get("high_precision", False),
+                )
+            self.inference_ops.context_kv_cache_memcpy(
+                attn_metadata.key_states,
+                attn_metadata.value_states,
+                attn_metadata.k_cache,
+                attn_metadata.v_cache,
+                attn_metadata.sequence_lengths,
+                attn_metadata.cu_seqlens,
+                attn_metadata.block_tables,
+                attn_metadata.kv_seq_len,
+            )
+        elif not attn_metadata.use_alibi_attn:
             rotary_embedding(
                 attn_metadata.query_states,
                 attn_metadata.key_states,
@@ -175,8 +141,6 @@ def get_pre_attention_backend(
         return TritonPreAttentionBackend()
 
     if model_shard_infer_config.use_cuda_kernel:
-        if model_shard_infer_config.use_flash_attn:
-            return FlashPreAttentionBackend()
-        return CudaPreAttentionBackend()
+        return CudaPreAttentionBackend(model_shard_infer_config.use_flash_attn)
 
     return TritonPreAttentionBackend()