updated flash attention api

pull/1868/head^2
zbian 2022-11-14 17:11:33 +08:00 committed by アマデウス
parent 36c0f3ea5b
commit 6877121377
3 changed files with 64 additions and 36 deletions


@@ -1,3 +1,3 @@
from .layer_norm import MixedFusedLayerNorm as LayerNorm
from .scaled_softmax import FusedScaleMaskSoftmax
from .multihead_attention import MultiHeadAttention
from .scaled_softmax import FusedScaleMaskSoftmax


@@ -5,6 +5,7 @@ This is a Triton implementation of the Flash Attention algorithm
(see: Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf; Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf; Triton https://github.com/openai/triton)
"""
import math
import os
import subprocess
@@ -36,17 +37,17 @@ except ImportError:
print('please install triton from https://github.com/openai/triton')
HAS_TRITON = False
try:
from flash_attn.flash_attention import FlashAttention
from flash_attn.flash_attn_interface import (
flash_attn_unpadded_func,
flash_attn_unpadded_kvpacked_func,
flash_attn_unpadded_qkvpacked_func,
)
HAS_FLASH_ATTN = True
except ImportError:
HAS_FLASH_ATTN = False
print('please install flash_attn from https://github.com/HazyResearch/flash-attention')
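Downstream code can branch on these availability flags instead of importing the optional kernels unconditionally; a minimal sketch of the guarded import pattern (it mirrors what the updated test file does):

from colossalai.kernel.cuda_native.flash_attention import HAS_FLASH_ATTN, HAS_TRITON

# Only pull in the kernels that are actually installed.
if HAS_FLASH_ATTN:
    from colossalai.kernel.cuda_native.flash_attention import flash_attention_qkv
if HAS_TRITON:
    from colossalai.kernel.cuda_native.flash_attention import triton_flash_attention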
if HAS_TRITON:
@triton.jit
@@ -409,6 +410,25 @@ if HAS_TRITON:
if HAS_FLASH_ATTN:
from einops import rearrange
class MaskedFlashAttention(torch.nn.Module):
def __init__(self, num_attention_heads: int, attention_head_size: int, attention_dropout: float) -> None:
super().__init__()
self.num_attention_heads = num_attention_heads
self.attention_head_size = attention_head_size
self.attention_func = FlashAttention(softmax_scale=math.sqrt(attention_head_size),
attention_dropout=attention_dropout)
def forward(self, query_key_value: torch.Tensor, attention_mask: torch.Tensor, causal=False):
if attention_mask.dtype is not torch.bool:
attention_mask = attention_mask.bool()
qkv = rearrange(query_key_value, 'b s (three h d) -> b s three h d', three=3, h=self.num_attention_heads)
context, _ = self.attention_func(qkv, key_padding_mask=attention_mask, causal=causal)
context = rearrange(context, 'b s h d -> b s (h d)')
return context
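A minimal usage sketch of the new MaskedFlashAttention wrapper (sizes are illustrative; flash_attn requires fp16/bf16 tensors on CUDA, and the mask is assumed to follow flash_attn's key-padding convention where True marks valid tokens):

import torch
from colossalai.kernel.cuda_native.flash_attention import MaskedFlashAttention

# Illustrative sizes: batch 2, sequence 128, 8 heads of dim 64.
attn = MaskedFlashAttention(num_attention_heads=8, attention_head_size=64, attention_dropout=0.1)
qkv = torch.randn(2, 128, 3 * 8 * 64, dtype=torch.float16, device='cuda')
mask = torch.ones(2, 128, dtype=torch.bool, device='cuda')   # assumed: True = keep this token
context = attn(qkv, mask)                                    # (2, 128, 8 * 64)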
def flash_attention_qkv(qkv, sm_scale, batch_size, seq_len, dropout_p=0., causal=False):
"""
Arguments:
@@ -423,15 +443,15 @@ if HAS_FLASH_ATTN:
out: (total, nheads, headdim).
"""
max_s = seq_len
cu_seqlens = torch.arange(0, (batch_size + 1) * seq_len, step=seq_len, dtype=torch.int32,
device=qkv.device)
out = flash_attn_unpadded_qkvpacked_func(
qkv, cu_seqlens, max_s, dropout_p,
softmax_scale=sm_scale, causal=causal
)
cu_seqlens = torch.arange(0, (batch_size + 1) * seq_len, step=seq_len, dtype=torch.int32, device=qkv.device)
out = flash_attn_unpadded_qkvpacked_func(qkv,
cu_seqlens,
max_s,
dropout_p,
softmax_scale=sm_scale,
causal=causal)
return out
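For reference, a minimal call sketch for flash_attention_qkv (sizes are illustrative; the flash_attn kernels expect fp16/bf16 CUDA tensors in the packed layout described in the docstring):

import math
import torch
from colossalai.kernel.cuda_native.flash_attention import flash_attention_qkv

batch_size, seq_len, n_heads, head_dim = 2, 128, 8, 64          # illustrative sizes
qkv = torch.randn(batch_size * seq_len, 3, n_heads, head_dim,
                  dtype=torch.float16, device='cuda', requires_grad=True)
out = flash_attention_qkv(qkv, sm_scale=1.0 / math.sqrt(head_dim),
                          batch_size=batch_size, seq_len=seq_len, causal=True)
# out: (batch_size * seq_len, n_heads, head_dim)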
def flash_attention_q_kv(q, kv, sm_scale, batch_size, q_seqlen, kv_seqlen, dropout_p=0., causal=False):
"""
Arguments:
@@ -447,19 +467,14 @@ if HAS_FLASH_ATTN:
out: (total, nheads, headdim).
"""
cu_seqlens_q = torch.arange(0, (batch_size + 1) * q_seqlen, step=q_seqlen, dtype=torch.int32, device=q.device)
cu_seqlens_k = torch.arange(0, (batch_size + 1) * kv_seqlen, step=kv_seqlen, dtype=torch.int32, device=kv.device)
out = flash_attn_unpadded_kvpacked_func(q,
kv,
cu_seqlens_q,
cu_seqlens_k,
q_seqlen,
kv_seqlen,
dropout_p,
sm_scale,
causal)
cu_seqlens_k = torch.arange(0, (batch_size + 1) * kv_seqlen,
step=kv_seqlen,
dtype=torch.int32,
device=kv.device)
out = flash_attn_unpadded_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, q_seqlen, kv_seqlen, dropout_p,
sm_scale, causal)
return out
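A corresponding sketch for the packed-kv variant, with separate query and key/value lengths (again illustrative fp16 CUDA tensors):

import math
import torch
from colossalai.kernel.cuda_native.flash_attention import flash_attention_q_kv

batch_size, q_seqlen, kv_seqlen, n_heads, head_dim = 2, 128, 256, 8, 64
q = torch.randn(batch_size * q_seqlen, n_heads, head_dim, dtype=torch.float16, device='cuda')
kv = torch.randn(batch_size * kv_seqlen, 2, n_heads, head_dim, dtype=torch.float16, device='cuda')
out = flash_attention_q_kv(q, kv, 1.0 / math.sqrt(head_dim), batch_size, q_seqlen, kv_seqlen)
# out: (batch_size * q_seqlen, n_heads, head_dim)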
def flash_attention_q_k_v(q, k, v, sm_scale, batch_size, q_seqlen, kv_seqlen, dropout_p=0., causal=False):
"""
Arguments:
@@ -476,14 +491,9 @@ if HAS_FLASH_ATTN:
out: (total, nheads, headdim).
"""
cu_seqlens_q = torch.arange(0, (batch_size + 1) * q_seqlen, step=q_seqlen, dtype=torch.int32, device=q.device)
cu_seqlens_kv = torch.arange(0, (batch_size + 1) * kv_seqlen, step=kv_seqlen, dtype=torch.int32, device=k.device)
return flash_attn_unpadded_func(q,
k,
v,
cu_seqlens_q,
cu_seqlens_kv,
q_seqlen,
kv_seqlen,
dropout_p,
sm_scale,
cu_seqlens_kv = torch.arange(0, (batch_size + 1) * kv_seqlen,
step=kv_seqlen,
dtype=torch.int32,
device=k.device)
return flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, q_seqlen, kv_seqlen, dropout_p, sm_scale,
causal)
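And for the fully unpacked variant with separate q, k, v tensors (same assumptions on dtype and device):

import math
import torch
from colossalai.kernel.cuda_native.flash_attention import flash_attention_q_k_v

batch_size, q_seqlen, kv_seqlen, n_heads, head_dim = 2, 128, 256, 8, 64
q = torch.randn(batch_size * q_seqlen, n_heads, head_dim, dtype=torch.float16, device='cuda')
k = torch.randn(batch_size * kv_seqlen, n_heads, head_dim, dtype=torch.float16, device='cuda')
v = torch.randn(batch_size * kv_seqlen, n_heads, head_dim, dtype=torch.float16, device='cuda')
out = flash_attention_q_k_v(q, k, v, 1.0 / math.sqrt(head_dim), batch_size, q_seqlen, kv_seqlen)
# out: (batch_size * q_seqlen, n_heads, head_dim)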


@@ -6,7 +6,11 @@ from colossalai.kernel.cuda_native.flash_attention import HAS_FLASH_ATTN, HAS_TRITON
if HAS_FLASH_ATTN:
from colossalai.kernel.cuda_native.flash_attention import (
flash_attention_q_k_v, flash_attention_q_kv, flash_attention_qkv)
MaskedFlashAttention,
flash_attention_q_k_v,
flash_attention_q_kv,
flash_attention_qkv,
)
if HAS_TRITON:
from colossalai.kernel.cuda_native.flash_attention import triton_flash_attention
@@ -87,17 +91,17 @@ def test_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
if i == 0:
tri_dq, tri_dk, tri_dv, = torch.autograd.grad(tri_out, (q, k, v), dout)
tri_out, tri_dq, tri_dk, tri_dv = map(lambda x: rearrange(x, '(z n) h d -> z h n d', z=Z),
(tri_out, tri_dq, tri_dk, tri_dv))
(tri_out, tri_dq, tri_dk, tri_dv))
elif i == 1:
tri_dq, tri_dkv, = torch.autograd.grad(tri_out, (q, kv), dout)
tri_dk, tri_dv = torch.chunk(tri_dkv, 2, dim=1)
tri_out, tri_dq, tri_dk, tri_dv = map(lambda x: rearrange(x, '(z n) h d -> z h n d', z=Z),
(tri_out, tri_dq, tri_dk.squeeze(1), tri_dv.squeeze(1)))
(tri_out, tri_dq, tri_dk.squeeze(1), tri_dv.squeeze(1)))
else:
tri_dqkv, = torch.autograd.grad(tri_out, (qkv), dout)
tri_dq, tri_dk, tri_dv = torch.chunk(tri_dqkv, 3, dim=1)
tri_out, tri_dq, tri_dk, tri_dv = map(lambda x: rearrange(x, '(z n) h d -> z h n d', z=Z),
(tri_out, tri_dq.squeeze(1), tri_dk.squeeze(1), tri_dv.squeeze(1)))
(tri_out, tri_dq.squeeze(1), tri_dk.squeeze(1), tri_dv.squeeze(1)))
# compare
assert torch.allclose(ref_out, tri_out, atol=1e-3)
@@ -106,5 +110,19 @@ def test_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
assert torch.allclose(ref_dq, tri_dq, atol=1e-3)
@pytest.mark.skipif(HAS_FLASH_ATTN == False, reason="flash is not available")
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 4, 2, 16)])
def test_masked_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
attn = MaskedFlashAttention(N_CTX, D_HEAD, 0.1)
qkv = torch.randn((Z, H, 3 * N_CTX * D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
attention_mask = torch.randint(2, (Z, H)).cuda().bool()
out = attn(qkv, attention_mask)
dout = torch.rand_like(out)
out.backward(dout)
if __name__ == '__main__':
test_flash_attention(3, 4, 2, 16)