diff --git a/colossalai/kernel/cuda_native/__init__.py b/colossalai/kernel/cuda_native/__init__.py
index a35158b72..8f857ff5d 100644
--- a/colossalai/kernel/cuda_native/__init__.py
+++ b/colossalai/kernel/cuda_native/__init__.py
@@ -1,3 +1,3 @@
 from .layer_norm import MixedFusedLayerNorm as LayerNorm
-from .scaled_softmax import FusedScaleMaskSoftmax
 from .multihead_attention import MultiHeadAttention
+from .scaled_softmax import FusedScaleMaskSoftmax
diff --git a/colossalai/kernel/cuda_native/flash_attention.py b/colossalai/kernel/cuda_native/flash_attention.py
index 33380b8fc..2b86763f1 100644
--- a/colossalai/kernel/cuda_native/flash_attention.py
+++ b/colossalai/kernel/cuda_native/flash_attention.py
@@ -5,6 +5,7 @@ This is a Triton implementation of the Flash Attention algorithm
 (see: Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf; Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf; Triton https://github.com/openai/triton)
 """
+import math
 import os
 import subprocess
@@ -36,17 +37,17 @@ except ImportError:
     print('please install triton from https://github.com/openai/triton')
     HAS_TRITON = False
 try:
+    from flash_attn.flash_attention import FlashAttention
     from flash_attn.flash_attn_interface import (
         flash_attn_unpadded_func,
         flash_attn_unpadded_kvpacked_func,
-        flash_attn_unpadded_qkvpacked_func,
+        flash_attn_unpadded_qkvpacked_func,
     )
     HAS_FLASH_ATTN = True
 except ImportError:
     HAS_FLASH_ATTN = False
     print('please install flash_attn from https://github.com/HazyResearch/flash-attention')
-

 if HAS_TRITON:

     @triton.jit
@@ -409,6 +410,25 @@ if HAS_FLASH_ATTN:
+    from einops import rearrange
+
+    class MaskedFlashAttention(torch.nn.Module):
+
+        def __init__(self, num_attention_heads: int, attention_head_size: int, attention_dropout: float) -> None:
+            super().__init__()
+            self.num_attention_heads = num_attention_heads
+            self.attention_head_size = attention_head_size
+            self.attention_func = FlashAttention(softmax_scale=math.sqrt(attention_head_size),
+                                                 attention_dropout=attention_dropout)
+
+        def forward(self, query_key_value: torch.Tensor, attention_mask: torch.Tensor, causal=False):
+            if attention_mask.dtype is not torch.bool:
+                attention_mask = attention_mask.bool()
+            qkv = rearrange(query_key_value, 'b s (three h d) -> b s three h d', three=3, h=self.num_attention_heads)
+            context, _ = self.attention_func(qkv, key_padding_mask=attention_mask, causal=causal)
+            context = rearrange(context, 'b s h d -> b s (h d)')
+            return context
+
     def flash_attention_qkv(qkv, sm_scale, batch_size, seq_len, dropout_p=0., causal=False):
         """
         Arguments:
@@ -423,15 +443,15 @@ if HAS_FLASH_ATTN:
             out: (total, nheads, headdim).
         """
         max_s = seq_len
-        cu_seqlens = torch.arange(0, (batch_size + 1) * seq_len, step=seq_len, dtype=torch.int32,
-                                  device=qkv.device)
-        out = flash_attn_unpadded_qkvpacked_func(
-            qkv, cu_seqlens, max_s, dropout_p,
-            softmax_scale=sm_scale, causal=causal
-        )
+        cu_seqlens = torch.arange(0, (batch_size + 1) * seq_len, step=seq_len, dtype=torch.int32, device=qkv.device)
+        out = flash_attn_unpadded_qkvpacked_func(qkv,
+                                                 cu_seqlens,
+                                                 max_s,
+                                                 dropout_p,
+                                                 softmax_scale=sm_scale,
+                                                 causal=causal)
         return out
-

     def flash_attention_q_kv(q, kv, sm_scale, batch_size, q_seqlen, kv_seqlen, dropout_p=0., causal=False):
         """
         Arguments:
@@ -447,19 +467,14 @@ if HAS_FLASH_ATTN:
             out: (total, nheads, headdim).
         """
         cu_seqlens_q = torch.arange(0, (batch_size + 1) * q_seqlen, step=q_seqlen, dtype=torch.int32, device=q.device)
-        cu_seqlens_k = torch.arange(0, (batch_size + 1) * kv_seqlen, step=kv_seqlen, dtype=torch.int32, device=kv.device)
-        out = flash_attn_unpadded_kvpacked_func(q,
-                                                kv,
-                                                cu_seqlens_q,
-                                                cu_seqlens_k,
-                                                q_seqlen,
-                                                kv_seqlen,
-                                                dropout_p,
-                                                sm_scale,
-                                                causal)
+        cu_seqlens_k = torch.arange(0, (batch_size + 1) * kv_seqlen,
+                                    step=kv_seqlen,
+                                    dtype=torch.int32,
+                                    device=kv.device)
+        out = flash_attn_unpadded_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, q_seqlen, kv_seqlen, dropout_p,
+                                                sm_scale, causal)
         return out
-
-
+
     def flash_attention_q_k_v(q, k, v, sm_scale, batch_size, q_seqlen, kv_seqlen, dropout_p=0., causal=False):
         """
         Arguments:
@@ -476,14 +491,9 @@ if HAS_FLASH_ATTN:
             out: (total, nheads, headdim).
         """
         cu_seqlens_q = torch.arange(0, (batch_size + 1) * q_seqlen, step=q_seqlen, dtype=torch.int32, device=q.device)
-        cu_seqlens_kv = torch.arange(0, (batch_size + 1) * kv_seqlen, step=kv_seqlen, dtype=torch.int32, device=k.device)
-        return flash_attn_unpadded_func(q,
-                                        k,
-                                        v,
-                                        cu_seqlens_q,
-                                        cu_seqlens_kv,
-                                        q_seqlen,
-                                        kv_seqlen,
-                                        dropout_p,
-                                        sm_scale,
+        cu_seqlens_kv = torch.arange(0, (batch_size + 1) * kv_seqlen,
+                                     step=kv_seqlen,
+                                     dtype=torch.int32,
+                                     device=k.device)
+        return flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, q_seqlen, kv_seqlen, dropout_p, sm_scale,
                                         causal)
diff --git a/tests/test_utils/test_flash_attention.py b/tests/test_utils/test_flash_attention.py
index d2409fc62..9d2ee8a18 100644
--- a/tests/test_utils/test_flash_attention.py
+++ b/tests/test_utils/test_flash_attention.py
@@ -6,7 +6,11 @@ from colossalai.kernel.cuda_native.flash_attention import HAS_FLASH_ATTN, HAS_TR

 if HAS_FLASH_ATTN:
     from colossalai.kernel.cuda_native.flash_attention import (
-        flash_attention_q_k_v, flash_attention_q_kv, flash_attention_qkv)
+        MaskedFlashAttention,
+        flash_attention_q_k_v,
+        flash_attention_q_kv,
+        flash_attention_qkv,
+    )

 if HAS_TRITON:
     from colossalai.kernel.cuda_native.flash_attention import triton_flash_attention
@@ -87,17 +91,17 @@ def test_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
         if i == 0:
             tri_dq, tri_dk, tri_dv, = torch.autograd.grad(tri_out, (q, k, v), dout)
             tri_out, tri_dq, tri_dk, tri_dv = map(lambda x: rearrange(x, '(z n) h d -> z h n d', z=Z),
-                                                  (tri_out, tri_dq, tri_dk, tri_dv))
+                                                  (tri_out, tri_dq, tri_dk, tri_dv))
         elif i == 1:
             tri_dq, tri_dkv, = torch.autograd.grad(tri_out, (q, kv), dout)
             tri_dk, tri_dv = torch.chunk(tri_dkv, 2, dim=1)
             tri_out, tri_dq, tri_dk, tri_dv = map(lambda x: rearrange(x, '(z n) h d -> z h n d', z=Z),
-                                                  (tri_out, tri_dq, tri_dk.squeeze(1), tri_dv.squeeze(1)))
+                                                  (tri_out, tri_dq, tri_dk.squeeze(1), tri_dv.squeeze(1)))
         else:
             tri_dqkv, = torch.autograd.grad(tri_out, (qkv), dout)
             tri_dq, tri_dk, tri_dv = torch.chunk(tri_dqkv, 3, dim=1)
             tri_out, tri_dq, tri_dk, tri_dv = map(lambda x: rearrange(x, '(z n) h d -> z h n d', z=Z),
-                                                  (tri_out, tri_dq.squeeze(1), tri_dk.squeeze(1), tri_dv.squeeze(1)))
+                                                  (tri_out, tri_dq.squeeze(1), tri_dk.squeeze(1), tri_dv.squeeze(1)))

         # compare
         assert torch.allclose(ref_out, tri_out, atol=1e-3)
@@ -106,5 +110,19 @@ def test_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
         assert torch.allclose(ref_dq, tri_dq, atol=1e-3)


+@pytest.mark.skipif(HAS_FLASH_ATTN == False, reason="flash is not available")
+@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 4, 2, 16)])
+def test_masked_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
+    attn = MaskedFlashAttention(N_CTX, D_HEAD, 0.1)
+
+    qkv = torch.randn((Z, H, 3 * N_CTX * D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
+    attention_mask = torch.randint(2, (Z, H)).cuda().bool()
+
+    out = attn(qkv, attention_mask)
+
+    dout = torch.rand_like(out)
+    out.backward(dout)
+
+
 if __name__ == '__main__':
     test_flash_attention(3, 4, 2, 16)
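
For review context, here is a minimal usage sketch (not part of the patch) of the new MaskedFlashAttention module: it takes a packed QKV tensor of shape (batch, seq, 3 * num_heads * head_dim) plus a boolean key-padding mask of shape (batch, seq), and returns a context tensor of shape (batch, seq, num_heads * head_dim). The sizes below are illustrative; the sketch assumes flash_attn is installed and a CUDA device is available, and uses fp16 because the flash-attention kernels do not accept fp32 inputs.

# Usage sketch only, not part of the patch. Assumes flash_attn is installed,
# a CUDA device is available, and the sizes below are illustrative.
import torch

from colossalai.kernel.cuda_native.flash_attention import HAS_FLASH_ATTN

if HAS_FLASH_ATTN:
    from colossalai.kernel.cuda_native.flash_attention import MaskedFlashAttention

    batch_size, seq_len, num_heads, head_dim = 4, 128, 8, 64
    attn = MaskedFlashAttention(num_heads, head_dim, attention_dropout=0.1)

    # Packed QKV projection output: (batch, seq, 3 * num_heads * head_dim), fp16 on GPU.
    qkv = torch.randn(batch_size, seq_len, 3 * num_heads * head_dim,
                      dtype=torch.float16, device='cuda', requires_grad=True)

    # Key-padding mask of shape (batch, seq): True marks real tokens, False marks padding.
    attention_mask = torch.ones(batch_size, seq_len, dtype=torch.bool, device='cuda')
    attention_mask[:, seq_len // 2:] = False

    context = attn(qkv, attention_mask)    # (batch, seq, num_heads * head_dim)
    context.float().sum().backward()       # gradients flow back to the packed QKV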