mirror of https://github.com/hpcaitech/ColossalAI
Remove flash attention backend
Signed-off-by: char-1ee <xingjianli59@gmail.com>
Branch: pull/5771/head
parent ceba662d22
commit f5981e808e
Attention backend module:

@@ -38,82 +38,51 @@ class AttentionBackend(ABC):
         raise NotImplementedError


-class FlashAttentionBackend(AttentionBackend):
-    """
-    Attention backend when use_cuda_kernel is True and flash-attn is installed. It uses
-    `flash_attn_varlen_func` for prefilling and our cuda op `flash_decoding_attention` for decoding.
-    """
-
-    def __init__(self):
-        super().__init__()
-        self.inference_ops = InferenceOpsLoader().load()
-
-    def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
-        token_nums = kwargs.get("token_nums", -1)
-
-        attn_output = flash_attn_varlen_func(
-            attn_metadata.query_states,
-            attn_metadata.key_states,
-            attn_metadata.value_states,
-            cu_seqlens_q=attn_metadata.cu_seqlens,
-            cu_seqlens_k=attn_metadata.cu_seqlens,
-            max_seqlen_q=attn_metadata.kv_seq_len,
-            max_seqlen_k=attn_metadata.kv_seq_len,
-            dropout_p=0.0,
-            softmax_scale=attn_metadata.sm_scale,
-            causal=True,
-            alibi_slopes=attn_metadata.alibi_slopes,
-        )
-        attn_output = attn_output.view(token_nums, -1)
-        return attn_output
-
-    def decode(self, attn_metadata: AttentionMetaData, **kwargs):
-        fd_inter_tensor = kwargs.get("fd_inter_tensor", None)
-        output_tensor = attn_metadata.output_tensor
-        self.inference_ops.flash_decoding_attention(
-            output_tensor,
-            attn_metadata.query_states,
-            attn_metadata.k_cache,
-            attn_metadata.v_cache,
-            attn_metadata.sequence_lengths,
-            attn_metadata.block_tables,
-            attn_metadata.block_size,
-            attn_metadata.kv_seq_len,
-            fd_inter_tensor.mid_output,
-            fd_inter_tensor.exp_sums,
-            fd_inter_tensor.max_logits,
-            attn_metadata.alibi_slopes,
-            attn_metadata.sm_scale,
-        )
-        return output_tensor
-
-
 class CudaAttentionBackend(AttentionBackend):
     """
     Attention backend when use_cuda_kernel is True but flash-attn not found. If flash-attn is not found,
     it uses Triton op `context_attention_unpadded` for prefilling and our cuda op `flash_decoding_attention` for decoding.
     """

-    def __init__(self):
+    def __init__(self, use_flash_attn: bool):
         super().__init__()
         self.inference_ops = InferenceOpsLoader().load()
+        self.use_flash_attn = use_flash_attn

     def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
-        return context_attention_unpadded(
-            q=attn_metadata.query_states,
-            k=attn_metadata.key_states,
-            v=attn_metadata.value_states,
-            k_cache=attn_metadata.k_cache,
-            v_cache=attn_metadata.v_cache,
-            context_lengths=attn_metadata.sequence_lengths,
-            block_tables=attn_metadata.block_tables,
-            block_size=attn_metadata.block_size,
-            output=attn_metadata.output_tensor,
-            alibi_slopes=attn_metadata.alibi_slopes,
-            max_seq_len=attn_metadata.kv_seq_len,
-            sm_scale=attn_metadata.sm_scale,
-            use_new_kcache_layout=True,  # use new k cache layout for cuda kernels in this triton op
-        )
+        if self.use_flash_attn:
+            token_nums = kwargs.get("token_nums", -1)
+            attn_output = flash_attn_varlen_func(
+                attn_metadata.query_states,
+                attn_metadata.key_states,
+                attn_metadata.value_states,
+                cu_seqlens_q=attn_metadata.cu_seqlens,
+                cu_seqlens_k=attn_metadata.cu_seqlens,
+                max_seqlen_q=attn_metadata.kv_seq_len,
+                max_seqlen_k=attn_metadata.kv_seq_len,
+                dropout_p=0.0,
+                softmax_scale=attn_metadata.sm_scale,
+                causal=True,
+                alibi_slopes=attn_metadata.alibi_slopes,
+            )
+            attn_output = attn_output.view(token_nums, -1)
+        else:
+            attn_output = context_attention_unpadded(
+                q=attn_metadata.query_states,
+                k=attn_metadata.key_states,
+                v=attn_metadata.value_states,
+                k_cache=attn_metadata.k_cache,
+                v_cache=attn_metadata.v_cache,
+                context_lengths=attn_metadata.sequence_lengths,
+                block_tables=attn_metadata.block_tables,
+                block_size=attn_metadata.block_size,
+                output=attn_metadata.output_tensor,
+                alibi_slopes=attn_metadata.alibi_slopes,
+                max_seq_len=attn_metadata.kv_seq_len,
+                sm_scale=attn_metadata.sm_scale,
+                use_new_kcache_layout=True,  # use new k-cache layout
+            )
+        return attn_output

     def decode(self, attn_metadata: AttentionMetaData, **kwargs):
         fd_inter_tensor = kwargs.get("fd_inter_tensor", None)

@@ -155,7 +124,6 @@ class TritonAttentionBackend(AttentionBackend):
             alibi_slopes=attn_metadata.alibi_slopes,
             max_seq_len=attn_metadata.kv_seq_len,
             sm_scale=attn_metadata.sm_scale,
-            use_new_kcache_layout=False,
         )

     def decode(self, attn_metadata: AttentionMetaData, **kwargs):

@@ -195,8 +163,6 @@ def get_attention_backend(
         return TritonAttentionBackend()

     if model_shard_infer_config.use_cuda_kernel:
-        if model_shard_infer_config.use_flash_attn:
-            return FlashAttentionBackend()
-        return CudaAttentionBackend()
+        return CudaAttentionBackend(model_shard_infer_config.use_flash_attn)

     return TritonAttentionBackend()
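Taken together, the hunks above fold the removed FlashAttentionBackend into CudaAttentionBackend: whether flash-attn is used for prefill becomes a constructor flag checked at runtime, while decoding continues to go through the CUDA `flash_decoding_attention` op either way. A minimal sketch of the resulting selection logic, reconstructed from the last hunk (the early Triton return shown as context there is omitted; this is an outline of the diff, not a standalone snippet):

def get_attention_backend(model_shard_infer_config) -> AttentionBackend:
    if model_shard_infer_config.use_cuda_kernel:
        # one CUDA backend; flash-attn usage is a runtime flag, not a separate class
        return CudaAttentionBackend(model_shard_infer_config.use_flash_attn)
    # otherwise fall back to the pure-Triton backend
    return TritonAttentionBackend()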
Pre-attention backend module:

@@ -16,71 +16,37 @@ class PreAttentionBackend(ABC):
         raise NotImplementedError


-class FlashPreAttentionBackend(PreAttentionBackend):
-    """
-    FlashPreAttentionBackend handles KV cache initialization and positional encoding for FlashAttentionBackend.
-    """
-
-    def __init__(self):
-        super().__init__()
-        self.inference_ops = InferenceOpsLoader().load()
-
-    def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
-        if not attn_metadata.use_alibi_attn:
-            self.inference_ops.rotary_embedding(
-                attn_metadata.query_states,
-                attn_metadata.key_states,
-                kwargs.get("cos", None),
-                kwargs.get("sin", None),
-                kwargs.get("high_precision", False),
-            )
-        self.inference_ops.context_kv_cache_memcpy(
-            attn_metadata.key_states,
-            attn_metadata.value_states,
-            attn_metadata.k_cache,
-            attn_metadata.v_cache,
-            attn_metadata.sequence_lengths,
-            attn_metadata.cu_seqlens,
-            attn_metadata.block_tables,
-            attn_metadata.kv_seq_len,
-        )
-
-    def decode(self, attn_metadata: AttentionMetaData, **kwargs):
-        if not attn_metadata.use_alibi_attn:
-            self.inference_ops.rotary_embedding_and_cache_copy(
-                attn_metadata.query_states,
-                attn_metadata.key_states,
-                attn_metadata.value_states,
-                kwargs.get("cos", None),
-                kwargs.get("sin", None),
-                attn_metadata.k_cache,
-                attn_metadata.v_cache,
-                attn_metadata.sequence_lengths,
-                attn_metadata.block_tables,
-                kwargs.get("high_precision", None),
-            )
-        else:
-            self.inference_ops.decode_kv_cache_memcpy(
-                attn_metadata.key_states,
-                attn_metadata.value_states,
-                attn_metadata.k_cache,
-                attn_metadata.v_cache,
-                attn_metadata.sequence_lengths,
-                attn_metadata.block_tables,
-            )
-
-
 class CudaPreAttentionBackend(PreAttentionBackend):
     """
     CudaPreAttentionBackend handles KV cache initialization and positional encoding for CudaAttentionBackend.
     """

-    def __init__(self):
+    def __init__(self, use_flash_attn: bool):
         super().__init__()
         self.inference_ops = InferenceOpsLoader().load()
+        self.use_flash_attn = use_flash_attn

     def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
-        if not attn_metadata.use_alibi_attn:
+        if self.use_flash_attn:
+            if not attn_metadata.use_alibi_attn:
+                self.inference_ops.rotary_embedding(
+                    attn_metadata.query_states,
+                    attn_metadata.key_states,
+                    kwargs.get("cos", None),
+                    kwargs.get("sin", None),
+                    kwargs.get("high_precision", False),
+                )
+            self.inference_ops.context_kv_cache_memcpy(
+                attn_metadata.key_states,
+                attn_metadata.value_states,
+                attn_metadata.k_cache,
+                attn_metadata.v_cache,
+                attn_metadata.sequence_lengths,
+                attn_metadata.cu_seqlens,
+                attn_metadata.block_tables,
+                attn_metadata.kv_seq_len,
+            )
+        elif not attn_metadata.use_alibi_attn:
             rotary_embedding(
                 attn_metadata.query_states,
                 attn_metadata.key_states,

@@ -175,8 +141,6 @@ def get_pre_attention_backend(
         return TritonPreAttentionBackend()

     if model_shard_infer_config.use_cuda_kernel:
-        if model_shard_infer_config.use_flash_attn:
-            return FlashPreAttentionBackend()
-        return CudaPreAttentionBackend()
+        return CudaPreAttentionBackend(model_shard_infer_config.use_flash_attn)

     return TritonPreAttentionBackend()
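get_pre_attention_backend mirrors the attention factory: the same config flag now selects behavior inside a single CudaPreAttentionBackend. As a self-contained illustration of the pattern this commit adopts (one backend class branching on a constructor flag instead of two classes picked by the factory), here is a runnable toy sketch; every name in it is illustrative and not part of ColossalAI:

from abc import ABC, abstractmethod


class Backend(ABC):
    """Toy stand-in for AttentionBackend; the real methods take AttentionMetaData."""

    @abstractmethod
    def prefill(self, query: str) -> str:
        ...


class CudaBackend(Backend):
    """Single class covering both prefill paths, like the new CudaAttentionBackend."""

    def __init__(self, use_flash_attn: bool) -> None:
        self.use_flash_attn = use_flash_attn

    def prefill(self, query: str) -> str:
        # Runtime branch replaces the removed Flash* backend class.
        if self.use_flash_attn:
            return f"flash_attn_varlen({query})"
        return f"context_attention_unpadded({query})"


class TritonBackend(Backend):
    """Fallback when CUDA kernels are not requested."""

    def prefill(self, query: str) -> str:
        return f"triton_context_attention({query})"


def get_backend(use_cuda_kernel: bool, use_flash_attn: bool) -> Backend:
    # Mirrors the new factory shape: the flag is forwarded, not dispatched on here.
    if use_cuda_kernel:
        return CudaBackend(use_flash_attn)
    return TritonBackend()


if __name__ == "__main__":
    print(get_backend(use_cuda_kernel=True, use_flash_attn=True).prefill("q"))
    print(get_backend(use_cuda_kernel=True, use_flash_attn=False).prefill("q"))
    print(get_backend(use_cuda_kernel=False, use_flash_attn=False).prefill("q"))

The trade-off is one extra branch per prefill call in exchange for a single construction path in both factories and a smaller backend class hierarchy.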