mirror of https://github.com/hpcaitech/ColossalAI

Remove flash attention backend

Signed-off-by: char-1ee <xingjianli59@gmail.com>
pull/5771/head
parent ceba662d22
commit f5981e808e
@@ -38,82 +38,51 @@ class AttentionBackend(ABC):
         raise NotImplementedError


-class FlashAttentionBackend(AttentionBackend):
-    """
-    Attention backend when use_cuda_kernel is True and flash-attn is installed. It uses
-    `flash_attn_varlen_func` for prefilling and our cuda op `flash_decoding_attention` for decoding.
-    """
-
-    def __init__(self):
-        super().__init__()
-        self.inference_ops = InferenceOpsLoader().load()
-
-    def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
-        token_nums = kwargs.get("token_nums", -1)
-
-        attn_output = flash_attn_varlen_func(
-            attn_metadata.query_states,
-            attn_metadata.key_states,
-            attn_metadata.value_states,
-            cu_seqlens_q=attn_metadata.cu_seqlens,
-            cu_seqlens_k=attn_metadata.cu_seqlens,
-            max_seqlen_q=attn_metadata.kv_seq_len,
-            max_seqlen_k=attn_metadata.kv_seq_len,
-            dropout_p=0.0,
-            softmax_scale=attn_metadata.sm_scale,
-            causal=True,
-            alibi_slopes=attn_metadata.alibi_slopes,
-        )
-        attn_output = attn_output.view(token_nums, -1)
-        return attn_output
-
-    def decode(self, attn_metadata: AttentionMetaData, **kwargs):
-        fd_inter_tensor = kwargs.get("fd_inter_tensor", None)
-        output_tensor = attn_metadata.output_tensor
-        self.inference_ops.flash_decoding_attention(
-            output_tensor,
-            attn_metadata.query_states,
-            attn_metadata.k_cache,
-            attn_metadata.v_cache,
-            attn_metadata.sequence_lengths,
-            attn_metadata.block_tables,
-            attn_metadata.block_size,
-            attn_metadata.kv_seq_len,
-            fd_inter_tensor.mid_output,
-            fd_inter_tensor.exp_sums,
-            fd_inter_tensor.max_logits,
-            attn_metadata.alibi_slopes,
-            attn_metadata.sm_scale,
-        )
-        return output_tensor
-
-
 class CudaAttentionBackend(AttentionBackend):
     """
     Attention backend when use_cuda_kernel is True but flash-attn not found. If flash-attn is not found,
     it uses Triton op `context_attention_unpadded` for prefilling and our cuda op `flash_decoding_attention` for decoding.
     """

-    def __init__(self):
+    def __init__(self, use_flash_attn: bool):
         super().__init__()
         self.inference_ops = InferenceOpsLoader().load()
+        self.use_flash_attn = use_flash_attn

     def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
-        return context_attention_unpadded(
-            q=attn_metadata.query_states,
-            k=attn_metadata.key_states,
-            v=attn_metadata.value_states,
-            k_cache=attn_metadata.k_cache,
-            v_cache=attn_metadata.v_cache,
-            context_lengths=attn_metadata.sequence_lengths,
-            block_tables=attn_metadata.block_tables,
-            block_size=attn_metadata.block_size,
-            output=attn_metadata.output_tensor,
-            alibi_slopes=attn_metadata.alibi_slopes,
-            max_seq_len=attn_metadata.kv_seq_len,
-            sm_scale=attn_metadata.sm_scale,
-            use_new_kcache_layout=True,  # use new k cache layout for cuda kernels in this triton op
-        )
+        if self.use_flash_attn:
+            token_nums = kwargs.get("token_nums", -1)
+            attn_output = flash_attn_varlen_func(
+                attn_metadata.query_states,
+                attn_metadata.key_states,
+                attn_metadata.value_states,
+                cu_seqlens_q=attn_metadata.cu_seqlens,
+                cu_seqlens_k=attn_metadata.cu_seqlens,
+                max_seqlen_q=attn_metadata.kv_seq_len,
+                max_seqlen_k=attn_metadata.kv_seq_len,
+                dropout_p=0.0,
+                softmax_scale=attn_metadata.sm_scale,
+                causal=True,
+                alibi_slopes=attn_metadata.alibi_slopes,
+            )
+            attn_output = attn_output.view(token_nums, -1)
+        else:
+            attn_output = context_attention_unpadded(
+                q=attn_metadata.query_states,
+                k=attn_metadata.key_states,
+                v=attn_metadata.value_states,
+                k_cache=attn_metadata.k_cache,
+                v_cache=attn_metadata.v_cache,
+                context_lengths=attn_metadata.sequence_lengths,
+                block_tables=attn_metadata.block_tables,
+                block_size=attn_metadata.block_size,
+                output=attn_metadata.output_tensor,
+                alibi_slopes=attn_metadata.alibi_slopes,
+                max_seq_len=attn_metadata.kv_seq_len,
+                sm_scale=attn_metadata.sm_scale,
+                use_new_kcache_layout=True,  # use new k-cache layout
+            )
+        return attn_output

     def decode(self, attn_metadata: AttentionMetaData, **kwargs):
         fd_inter_tensor = kwargs.get("fd_inter_tensor", None)
@@ -155,7 +124,6 @@ class TritonAttentionBackend(AttentionBackend):
             alibi_slopes=attn_metadata.alibi_slopes,
             max_seq_len=attn_metadata.kv_seq_len,
             sm_scale=attn_metadata.sm_scale,
-            use_new_kcache_layout=False,
         )

     def decode(self, attn_metadata: AttentionMetaData, **kwargs):
@@ -195,8 +163,6 @@ def get_attention_backend(
         return TritonAttentionBackend()

     if model_shard_infer_config.use_cuda_kernel:
-        if model_shard_infer_config.use_flash_attn:
-            return FlashAttentionBackend()
-        return CudaAttentionBackend()
+        return CudaAttentionBackend(model_shard_infer_config.use_flash_attn)

     return TritonAttentionBackend()
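The net effect of this hunk is that backend selection no longer needs a dedicated FlashAttentionBackend: the flash-attn decision is passed into CudaAttentionBackend instead. A minimal, self-contained sketch of that dispatch follows; the stub classes and the config object are hypothetical stand-ins (the real ModelShardInferenceConfig and backends live in ColossalAI), and the speculative-decoding guard is assumed because its condition sits outside the hunk shown above.

# Illustrative sketch only: stand-in types, not ColossalAI's implementation.
from dataclasses import dataclass


@dataclass
class StubShardInferConfig:        # hypothetical stand-in for ModelShardInferenceConfig
    use_spec_dec: bool = False     # assumed guard behind the first early return above
    use_cuda_kernel: bool = False
    use_flash_attn: bool = False


class StubTritonAttentionBackend:
    pass


class StubCudaAttentionBackend:
    def __init__(self, use_flash_attn: bool):
        # One CUDA backend now covers both paths; the flag decides whether prefill
        # goes through flash_attn_varlen_func or the Triton context_attention_unpadded op.
        self.use_flash_attn = use_flash_attn


def select_attention_backend(cfg: StubShardInferConfig):
    if cfg.use_spec_dec:
        return StubTritonAttentionBackend()
    if cfg.use_cuda_kernel:
        return StubCudaAttentionBackend(cfg.use_flash_attn)
    return StubTritonAttentionBackend()


backend = select_attention_backend(StubShardInferConfig(use_cuda_kernel=True, use_flash_attn=True))
assert isinstance(backend, StubCudaAttentionBackend) and backend.use_flash_attn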
@@ -16,71 +16,37 @@ class PreAttentionBackend(ABC):
         raise NotImplementedError


-class FlashPreAttentionBackend(PreAttentionBackend):
-    """
-    FlashPreAttentionBackend handles KV cache initialization and positional encoding for FlashAttentionBackend.
-    """
-
-    def __init__(self):
-        super().__init__()
-        self.inference_ops = InferenceOpsLoader().load()
-
-    def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
-        if not attn_metadata.use_alibi_attn:
-            self.inference_ops.rotary_embedding(
-                attn_metadata.query_states,
-                attn_metadata.key_states,
-                kwargs.get("cos", None),
-                kwargs.get("sin", None),
-                kwargs.get("high_precision", False),
-            )
-        self.inference_ops.context_kv_cache_memcpy(
-            attn_metadata.key_states,
-            attn_metadata.value_states,
-            attn_metadata.k_cache,
-            attn_metadata.v_cache,
-            attn_metadata.sequence_lengths,
-            attn_metadata.cu_seqlens,
-            attn_metadata.block_tables,
-            attn_metadata.kv_seq_len,
-        )
-
-    def decode(self, attn_metadata: AttentionMetaData, **kwargs):
-        if not attn_metadata.use_alibi_attn:
-            self.inference_ops.rotary_embedding_and_cache_copy(
-                attn_metadata.query_states,
-                attn_metadata.key_states,
-                attn_metadata.value_states,
-                kwargs.get("cos", None),
-                kwargs.get("sin", None),
-                attn_metadata.k_cache,
-                attn_metadata.v_cache,
-                attn_metadata.sequence_lengths,
-                attn_metadata.block_tables,
-                kwargs.get("high_precision", None),
-            )
-        else:
-            self.inference_ops.decode_kv_cache_memcpy(
-                attn_metadata.key_states,
-                attn_metadata.value_states,
-                attn_metadata.k_cache,
-                attn_metadata.v_cache,
-                attn_metadata.sequence_lengths,
-                attn_metadata.block_tables,
-            )
-
-
 class CudaPreAttentionBackend(PreAttentionBackend):
     """
     CudaPreAttentionBackend handles KV cache initialization and positional encoding for CudaAttentionBackend.
     """

-    def __init__(self):
+    def __init__(self, use_flash_attn: bool):
         super().__init__()
         self.inference_ops = InferenceOpsLoader().load()
+        self.use_flash_attn = use_flash_attn

     def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
-        if not attn_metadata.use_alibi_attn:
+        if self.use_flash_attn:
+            if not attn_metadata.use_alibi_attn:
+                self.inference_ops.rotary_embedding(
+                    attn_metadata.query_states,
+                    attn_metadata.key_states,
+                    kwargs.get("cos", None),
+                    kwargs.get("sin", None),
+                    kwargs.get("high_precision", False),
+                )
+            self.inference_ops.context_kv_cache_memcpy(
+                attn_metadata.key_states,
+                attn_metadata.value_states,
+                attn_metadata.k_cache,
+                attn_metadata.v_cache,
+                attn_metadata.sequence_lengths,
+                attn_metadata.cu_seqlens,
+                attn_metadata.block_tables,
+                attn_metadata.kv_seq_len,
+            )
+        elif not attn_metadata.use_alibi_attn:
             rotary_embedding(
                 attn_metadata.query_states,
                 attn_metadata.key_states,
@@ -175,8 +141,6 @@ def get_pre_attention_backend(
         return TritonPreAttentionBackend()

     if model_shard_infer_config.use_cuda_kernel:
-        if model_shard_infer_config.use_flash_attn:
-            return FlashPreAttentionBackend()
-        return CudaPreAttentionBackend()
+        return CudaPreAttentionBackend(model_shard_infer_config.use_flash_attn)

     return TritonPreAttentionBackend()
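The pre-attention backends above mirror the attention backends: one prepares rotary embeddings and KV-cache writes, the other runs the attention kernel itself, and after this change both are chosen from the same use_cuda_kernel / use_flash_attn flags. Below is a hedged sketch of how a model layer might drive the pair during prefill and decode; the dummy classes and the driver function are hypothetical and only illustrate the calling order, not ColossalAI's actual modeling code.

# Hypothetical driver: shows the intended pairing of pre-attention and attention
# backends, with prints standing in for the real kernels.
class DummyPreAttentionBackend:
    def prefill(self, attn_metadata, **kwargs):
        print("apply rotary embedding, copy K/V of the prompt into the paged cache")

    def decode(self, attn_metadata, **kwargs):
        print("apply rotary embedding fused with the decode-step cache copy")


class DummyAttentionBackend:
    def prefill(self, attn_metadata, **kwargs):
        print("run context (prefill) attention over the new tokens")
        return "prefill_output"

    def decode(self, attn_metadata, **kwargs):
        print("run flash-decoding attention over the paged KV cache")
        return "decode_output"


def attention_forward(pre_backend, attn_backend, attn_metadata, is_prompt, **kwargs):
    # Both backends are selected once per shard via get_*_backend and reused per layer:
    # the pre-attention step prepares q/k/v and the cache, the attention step consumes them.
    if is_prompt:
        pre_backend.prefill(attn_metadata, **kwargs)
        return attn_backend.prefill(attn_metadata, **kwargs)
    pre_backend.decode(attn_metadata, **kwargs)
    return attn_backend.decode(attn_metadata, **kwargs)


out = attention_forward(DummyPreAttentionBackend(), DummyAttentionBackend(), attn_metadata=None, is_prompt=True)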