"""
|
|
Fused Attention
|
|
===============
|
|
This is a Triton implementation of the Flash Attention algorithm
|
|
(see: Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf; Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf; Triton https://github.com/openai/triton)
|
|
"""

import torch

try:
    from flash_attn.flash_attn_interface import (flash_attn_unpadded_kvpacked_func,
                                                 flash_attn_unpadded_qkvpacked_func)
except ImportError:
    raise ImportError('Please install flash_attn from https://github.com/HazyResearch/flash-attention')


def flash_attention_qkv(qkv, sm_scale, batch_size, seq_len):
    """
    Arguments:
        qkv: (batch_size * seq_len, 3, nheads, headdim)
        sm_scale: float. The scaling of QK^T before applying softmax.
        batch_size: int.
        seq_len: int.
    Return:
        out: (total, nheads, headdim), where total = batch_size * seq_len.
    """
    max_s = seq_len
    # Cumulative sequence lengths for the packed layout: [0, seq_len, 2*seq_len, ...].
    cu_seqlens = torch.arange(0, (batch_size + 1) * seq_len, step=seq_len, dtype=torch.int32,
                              device=qkv.device)
    # The fourth positional argument (0.0) disables attention dropout.
    out = flash_attn_unpadded_qkvpacked_func(
        qkv, cu_seqlens, max_s, 0.0,
        softmax_scale=sm_scale, causal=False
    )
    return out


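def _example_flash_attention_qkv():
    # Illustrative usage sketch, not part of the original module: the shapes, the
    # dtype, and the _example_* name are assumptions chosen for demonstration.
    # flash_attn kernels generally expect fp16/bf16 tensors on a CUDA device.
    batch_size, seq_len, nheads, headdim = 2, 128, 8, 64
    qkv = torch.randn(batch_size * seq_len, 3, nheads, headdim,
                      dtype=torch.float16, device='cuda')
    # A common choice for the softmax scaling is 1 / sqrt(headdim).
    out = flash_attention_qkv(qkv, sm_scale=headdim ** -0.5,
                              batch_size=batch_size, seq_len=seq_len)
    # out: (batch_size * seq_len, nheads, headdim)
    return out

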
def flash_attention_q_kv(q, kv, sm_scale, batch_size, q_seqlen, kv_seqlen):
    """
    Arguments:
        q: (batch_size * q_seqlen, nheads, headdim)
        kv: (batch_size * kv_seqlen, 2, nheads, headdim)
        sm_scale: float. The scaling of QK^T before applying softmax.
        batch_size: int.
        q_seqlen: int.
        kv_seqlen: int.
    Return:
        out: (total_q, nheads, headdim), where total_q = batch_size * q_seqlen.
    """
    # Cumulative sequence lengths for the packed query and key/value layouts.
    cu_seqlens_q = torch.arange(0, (batch_size + 1) * q_seqlen, step=q_seqlen, dtype=torch.int32, device=q.device)
    cu_seqlens_k = torch.arange(0, (batch_size + 1) * kv_seqlen, step=kv_seqlen, dtype=torch.int32, device=kv.device)
    # Positional arguments: dropout_p=0.0 (no attention dropout), then the softmax scale.
    out = flash_attn_unpadded_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, q_seqlen, kv_seqlen, 0.0, sm_scale)
    return out


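def _example_flash_attention_q_kv():
    # Illustrative usage sketch, not part of the original module: a cross-attention
    # style call where queries and keys/values have different sequence lengths.
    # The shapes, the dtype, and the _example_* name are assumptions chosen for
    # demonstration; flash_attn kernels generally expect fp16/bf16 CUDA tensors.
    batch_size, q_seqlen, kv_seqlen, nheads, headdim = 2, 64, 128, 8, 64
    q = torch.randn(batch_size * q_seqlen, nheads, headdim,
                    dtype=torch.float16, device='cuda')
    kv = torch.randn(batch_size * kv_seqlen, 2, nheads, headdim,
                     dtype=torch.float16, device='cuda')
    out = flash_attention_q_kv(q, kv, sm_scale=headdim ** -0.5,
                               batch_size=batch_size,
                               q_seqlen=q_seqlen, kv_seqlen=kv_seqlen)
    # out: (batch_size * q_seqlen, nheads, headdim)
    return out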