mirror of https://github.com/hpcaitech/ColossalAI
updated flash attention api
parent 36c0f3ea5b
commit 6877121377
@@ -1,3 +1,3 @@
from .layer_norm import MixedFusedLayerNorm as LayerNorm
from .scaled_softmax import FusedScaleMaskSoftmax
from .multihead_attention import MultiHeadAttention
from .scaled_softmax import FusedScaleMaskSoftmax

@@ -5,6 +5,7 @@ This is a Triton implementation of the Flash Attention algorithm
(see: Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf; Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf; Triton https://github.com/openai/triton)
"""

import math
import os
import subprocess

@@ -36,17 +37,17 @@ except ImportError:
    print('please install triton from https://github.com/openai/triton')
    HAS_TRITON = False
try:
    from flash_attn.flash_attention import FlashAttention
    from flash_attn.flash_attn_interface import (
        flash_attn_unpadded_func,
        flash_attn_unpadded_kvpacked_func,
        flash_attn_unpadded_qkvpacked_func,
        flash_attn_unpadded_qkvpacked_func,
    )
    HAS_FLASH_ATTN = True
except ImportError:
    HAS_FLASH_ATTN = False
    print('please install flash_attn from https://github.com/HazyResearch/flash-attention')


if HAS_TRITON:

    @triton.jit

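The guarded imports above leave two module-level flags, HAS_TRITON and HAS_FLASH_ATTN, that callers can branch on before touching either backend. A minimal selection sketch, assuming only the names visible in this diff; the else branch is illustrative:

from colossalai.kernel.cuda_native.flash_attention import HAS_FLASH_ATTN, HAS_TRITON

if HAS_FLASH_ATTN:
    # the flash_attn package imported successfully, so the fused wrappers exist
    from colossalai.kernel.cuda_native.flash_attention import MaskedFlashAttention, flash_attention_qkv
elif HAS_TRITON:
    # only the Triton kernel shipped in the same module is available
    from colossalai.kernel.cuda_native.flash_attention import triton_flash_attention
else:
    raise RuntimeError('neither triton nor flash_attn is installed')
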
@@ -409,6 +410,25 @@

if HAS_FLASH_ATTN:

    from einops import rearrange

    class MaskedFlashAttention(torch.nn.Module):

        def __init__(self, num_attention_heads: int, attention_head_size: int, attention_dropout: float) -> None:
            super().__init__()
            self.num_attention_heads = num_attention_heads
            self.attention_head_size = attention_head_size
            self.attention_func = FlashAttention(softmax_scale=math.sqrt(attention_head_size),
                                                 attention_dropout=attention_dropout)

        def forward(self, query_key_value: torch.Tensor, attention_mask: torch.Tensor, causal=False):
            if attention_mask.dtype is not torch.bool:
                attention_mask = attention_mask.bool()
            qkv = rearrange(query_key_value, 'b s (three h d) -> b s three h d', three=3, h=self.num_attention_heads)
            context, _ = self.attention_func(qkv, key_padding_mask=attention_mask, causal=causal)
            context = rearrange(context, 'b s h d -> b s (h d)')
            return context

    def flash_attention_qkv(qkv, sm_scale, batch_size, seq_len, dropout_p=0., causal=False):
        """
        Arguments:
@@ -423,15 +443,15 @@
            out: (total, nheads, headdim).
        """
        max_s = seq_len
        cu_seqlens = torch.arange(0, (batch_size + 1) * seq_len, step=seq_len, dtype=torch.int32,
                                  device=qkv.device)
        out = flash_attn_unpadded_qkvpacked_func(
            qkv, cu_seqlens, max_s, dropout_p,
            softmax_scale=sm_scale, causal=causal
        )
        cu_seqlens = torch.arange(0, (batch_size + 1) * seq_len, step=seq_len, dtype=torch.int32, device=qkv.device)
        out = flash_attn_unpadded_qkvpacked_func(qkv,
                                                 cu_seqlens,
                                                 max_s,
                                                 dropout_p,
                                                 softmax_scale=sm_scale,
                                                 causal=causal)
        return out


    def flash_attention_q_kv(q, kv, sm_scale, batch_size, q_seqlen, kv_seqlen, dropout_p=0., causal=False):
        """
        Arguments:
@@ -447,19 +467,14 @@
            out: (total, nheads, headdim).
        """
        cu_seqlens_q = torch.arange(0, (batch_size + 1) * q_seqlen, step=q_seqlen, dtype=torch.int32, device=q.device)
        cu_seqlens_k = torch.arange(0, (batch_size + 1) * kv_seqlen, step=kv_seqlen, dtype=torch.int32, device=kv.device)
        out = flash_attn_unpadded_kvpacked_func(q,
                                                kv,
                                                cu_seqlens_q,
                                                cu_seqlens_k,
                                                q_seqlen,
                                                kv_seqlen,
                                                dropout_p,
                                                sm_scale,
                                                causal)
        cu_seqlens_k = torch.arange(0, (batch_size + 1) * kv_seqlen,
                                    step=kv_seqlen,
                                    dtype=torch.int32,
                                    device=kv.device)
        out = flash_attn_unpadded_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, q_seqlen, kv_seqlen, dropout_p,
                                                sm_scale, causal)
        return out


    def flash_attention_q_k_v(q, k, v, sm_scale, batch_size, q_seqlen, kv_seqlen, dropout_p=0., causal=False):
        """
        Arguments:
@@ -476,14 +491,9 @@
            out: (total, nheads, headdim).
        """
        cu_seqlens_q = torch.arange(0, (batch_size + 1) * q_seqlen, step=q_seqlen, dtype=torch.int32, device=q.device)
        cu_seqlens_kv = torch.arange(0, (batch_size + 1) * kv_seqlen, step=kv_seqlen, dtype=torch.int32, device=k.device)
        return flash_attn_unpadded_func(q,
                                        k,
                                        v,
                                        cu_seqlens_q,
                                        cu_seqlens_kv,
                                        q_seqlen,
                                        kv_seqlen,
                                        dropout_p,
                                        sm_scale,
        cu_seqlens_kv = torch.arange(0, (batch_size + 1) * kv_seqlen,
                                     step=kv_seqlen,
                                     dtype=torch.int32,
                                     device=k.device)
        return flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, q_seqlen, kv_seqlen, dropout_p, sm_scale,
                                        causal)

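For context, a minimal sketch of driving the reformatted packed-QKV helper. The (batch_size * seq_len, 3, nheads, headdim) layout and the 1/sqrt(headdim) scale are assumptions drawn from the cu_seqlens construction above and the usual flash-attn convention, not something this commit states:

import math
import torch

from colossalai.kernel.cuda_native.flash_attention import HAS_FLASH_ATTN

if HAS_FLASH_ATTN:
    from colossalai.kernel.cuda_native.flash_attention import flash_attention_qkv

    batch_size, seq_len, nheads, headdim = 2, 128, 8, 64

    # packed QKV for the unpadded interface: (batch_size * seq_len, 3, nheads, headdim), fp16 on CUDA
    qkv = torch.randn(batch_size * seq_len, 3, nheads, headdim, dtype=torch.float16, device='cuda')

    out = flash_attention_qkv(qkv,
                              sm_scale=1.0 / math.sqrt(headdim),    # conventional scale; illustrative only
                              batch_size=batch_size,
                              seq_len=seq_len,
                              dropout_p=0.0,
                              causal=True)
    # out: (batch_size * seq_len, nheads, headdim), per the docstring above
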
@@ -6,7 +6,11 @@ from colossalai.kernel.cuda_native.flash_attention import HAS_FLASH_ATTN, HAS_TRITON

if HAS_FLASH_ATTN:
    from colossalai.kernel.cuda_native.flash_attention import (
        flash_attention_q_k_v, flash_attention_q_kv, flash_attention_qkv)
        MaskedFlashAttention,
        flash_attention_q_k_v,
        flash_attention_q_kv,
        flash_attention_qkv,
    )

if HAS_TRITON:
    from colossalai.kernel.cuda_native.flash_attention import triton_flash_attention

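The widened import list also pulls in the cross-attention helper flash_attention_q_kv. A hedged sketch of calling it, with the q layout (batch_size * q_seqlen, nheads, headdim) and the packed kv layout (batch_size * kv_seqlen, 2, nheads, headdim) assumed from the kvpacked flash-attn interface rather than stated in this diff:

import math
import torch

from colossalai.kernel.cuda_native.flash_attention import HAS_FLASH_ATTN

if HAS_FLASH_ATTN:
    from colossalai.kernel.cuda_native.flash_attention import flash_attention_q_kv

    batch_size, q_seqlen, kv_seqlen, nheads, headdim = 2, 64, 128, 8, 64

    # query: (batch_size * q_seqlen, nheads, headdim); packed key/value: (batch_size * kv_seqlen, 2, nheads, headdim)
    q = torch.randn(batch_size * q_seqlen, nheads, headdim, dtype=torch.float16, device='cuda')
    kv = torch.randn(batch_size * kv_seqlen, 2, nheads, headdim, dtype=torch.float16, device='cuda')

    out = flash_attention_q_kv(q, kv,
                               sm_scale=1.0 / math.sqrt(headdim),
                               batch_size=batch_size,
                               q_seqlen=q_seqlen,
                               kv_seqlen=kv_seqlen)
    # out: (batch_size * q_seqlen, nheads, headdim)
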
@@ -87,17 +91,17 @@ def test_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
        if i == 0:
            tri_dq, tri_dk, tri_dv, = torch.autograd.grad(tri_out, (q, k, v), dout)
            tri_out, tri_dq, tri_dk, tri_dv = map(lambda x: rearrange(x, '(z n) h d -> z h n d', z=Z),
                                                  (tri_out, tri_dq, tri_dk, tri_dv))
                                                  (tri_out, tri_dq, tri_dk, tri_dv))
        elif i == 1:
            tri_dq, tri_dkv, = torch.autograd.grad(tri_out, (q, kv), dout)
            tri_dk, tri_dv = torch.chunk(tri_dkv, 2, dim=1)
            tri_out, tri_dq, tri_dk, tri_dv = map(lambda x: rearrange(x, '(z n) h d -> z h n d', z=Z),
                                                  (tri_out, tri_dq, tri_dk.squeeze(1), tri_dv.squeeze(1)))
                                                  (tri_out, tri_dq, tri_dk.squeeze(1), tri_dv.squeeze(1)))
        else:
            tri_dqkv, = torch.autograd.grad(tri_out, (qkv), dout)
            tri_dq, tri_dk, tri_dv = torch.chunk(tri_dqkv, 3, dim=1)
            tri_out, tri_dq, tri_dk, tri_dv = map(lambda x: rearrange(x, '(z n) h d -> z h n d', z=Z),
                                                  (tri_out, tri_dq.squeeze(1), tri_dk.squeeze(1), tri_dv.squeeze(1)))
                                                  (tri_out, tri_dq.squeeze(1), tri_dk.squeeze(1), tri_dv.squeeze(1)))

        # compare
        assert torch.allclose(ref_out, tri_out, atol=1e-3)

@@ -106,5 +110,19 @@ def test_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
        assert torch.allclose(ref_dq, tri_dq, atol=1e-3)


@pytest.mark.skipif(HAS_FLASH_ATTN == False, reason="flash is not available")
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 4, 2, 16)])
def test_masked_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
    attn = MaskedFlashAttention(N_CTX, D_HEAD, 0.1)

    qkv = torch.randn((Z, H, 3 * N_CTX * D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
    attention_mask = torch.randint(2, (Z, H)).cuda().bool()

    out = attn(qkv, attention_mask)

    dout = torch.rand_like(out)
    out.backward(dout)


if __name__ == '__main__':
    test_flash_attention(3, 4, 2, 16)
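One closing note on the new wrapper exercised by the test above: its constructor arguments map to (num_attention_heads, attention_head_size, attention_dropout). A hedged sketch with the keywords spelled out and purely illustrative sizes:

import torch

from colossalai.kernel.cuda_native.flash_attention import HAS_FLASH_ATTN

if HAS_FLASH_ATTN:
    from colossalai.kernel.cuda_native.flash_attention import MaskedFlashAttention

    attn = MaskedFlashAttention(num_attention_heads=8, attention_head_size=64, attention_dropout=0.1)

    batch, seq = 4, 128
    # fused projection output: (batch, seq, 3 * heads * head_size), fp16 on CUDA
    qkv = torch.randn(batch, seq, 3 * 8 * 64, dtype=torch.float16, device='cuda')
    # key_padding_mask-style bool mask; forward() casts non-bool masks to bool itself
    attention_mask = torch.ones(batch, seq, dtype=torch.bool, device='cuda')

    context = attn(qkv, attention_mask)    # (batch, seq, heads * head_size)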