ColossalAI/examples/images/diffusion/ldm/modules/flash_attention.py

"""
Fused Attention
===============
This module provides thin wrappers around the FlashAttention kernels from the flash-attn library
(see: Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf; Rabe and Staats, https://arxiv.org/pdf/2112.05682v2.pdf; Triton, https://github.com/openai/triton)
"""
import torch

try:
    from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func, flash_attn_unpadded_kvpacked_func
except ImportError:
    raise ImportError('please install flash_attn from https://github.com/HazyResearch/flash-attention')


def flash_attention_qkv(qkv, sm_scale, batch_size, seq_len):
    """
    Arguments:
        qkv: (batch*seq, 3, nheads, headdim)
        batch_size: int.
        seq_len: int.
        sm_scale: float. The scaling of QK^T before applying softmax.
    Return:
        out: (total, nheads, headdim), where total = batch_size * seq_len.
    """
    max_s = seq_len
    # Start offsets of each sequence in the packed (batch*seq) layout.
    cu_seqlens = torch.arange(0, (batch_size + 1) * seq_len, step=seq_len, dtype=torch.int32,
                              device=qkv.device)
    # dropout_p is fixed to 0.0; the softmax scaling is applied inside the kernel.
    out = flash_attn_unpadded_qkvpacked_func(
        qkv, cu_seqlens, max_s, 0.0,
        softmax_scale=sm_scale, causal=False
    )
    return out
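

# A minimal usage sketch (not part of the original module), assuming a CUDA device and fp16
# inputs (the flash-attn kernels require half/bfloat16 tensors on GPU); shapes are illustrative.
def _example_flash_attention_qkv():
    batch_size, seq_len, nheads, headdim = 2, 128, 8, 64
    # Packed layout: queries, keys and values stacked along dim 1 for each token.
    qkv = torch.randn(batch_size * seq_len, 3, nheads, headdim, dtype=torch.float16, device='cuda')
    sm_scale = headdim ** -0.5    # conventional 1/sqrt(headdim) scaling
    out = flash_attention_qkv(qkv, sm_scale, batch_size, seq_len)
    assert out.shape == (batch_size * seq_len, nheads, headdim)
    return out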


def flash_attention_q_kv(q, kv, sm_scale, batch_size, q_seqlen, kv_seqlen):
    """
    Arguments:
        q: (batch*q_seqlen, nheads, headdim)
        kv: (batch*kv_seqlen, 2, nheads, headdim)
        batch_size: int.
        q_seqlen: int. Sequence length of the queries.
        kv_seqlen: int. Sequence length of the keys/values.
        sm_scale: float. The scaling of QK^T before applying softmax.
    Return:
        out: (total, nheads, headdim), where total = batch_size * q_seqlen.
    """
    # Start offsets of the packed query and key/value sequences.
    cu_seqlens_q = torch.arange(0, (batch_size + 1) * q_seqlen, step=q_seqlen, dtype=torch.int32, device=q.device)
    cu_seqlens_k = torch.arange(0, (batch_size + 1) * kv_seqlen, step=kv_seqlen, dtype=torch.int32, device=kv.device)
    # Positional arguments after the offsets: max_seqlen_q, max_seqlen_k, dropout_p (0.0), softmax_scale.
    out = flash_attn_unpadded_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, q_seqlen, kv_seqlen, 0.0, sm_scale)
    return out
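

# A minimal usage sketch (not part of the original module) for the separate-q/kv path, e.g.
# cross-attention where queries and keys/values have different lengths. Assumes a CUDA device
# and fp16 inputs; sizes are illustrative.
def _example_flash_attention_q_kv():
    batch_size, q_seqlen, kv_seqlen, nheads, headdim = 2, 128, 77, 8, 64
    q = torch.randn(batch_size * q_seqlen, nheads, headdim, dtype=torch.float16, device='cuda')
    # Keys and values packed along dim 1.
    kv = torch.randn(batch_size * kv_seqlen, 2, nheads, headdim, dtype=torch.float16, device='cuda')
    sm_scale = headdim ** -0.5
    out = flash_attention_q_kv(q, kv, sm_scale, batch_size, q_seqlen, kv_seqlen)
    assert out.shape == (batch_size * q_seqlen, nheads, headdim)
    return out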