From 077a66dd819e372f8eb55cd24fc024cac994065f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E3=82=A2=E3=83=9E=E3=83=87=E3=82=A6=E3=82=B9?=
Date: Fri, 16 Dec 2022 10:54:03 +0800
Subject: [PATCH] updated attention kernel (#2133)

---
 .../kernel/cuda_native/flash_attention.py | 26 +++++++++++++++++++
 tests/test_utils/test_flash_attention.py  | 20 +++++++++++++-
 2 files changed, 45 insertions(+), 1 deletion(-)

diff --git a/colossalai/kernel/cuda_native/flash_attention.py b/colossalai/kernel/cuda_native/flash_attention.py
index 2b86763f1..7bd646d39 100644
--- a/colossalai/kernel/cuda_native/flash_attention.py
+++ b/colossalai/kernel/cuda_native/flash_attention.py
@@ -48,6 +48,13 @@ except ImportError:
     HAS_FLASH_ATTN = False
     print('please install flash_attn from https://github.com/HazyResearch/flash-attention')
 
+try:
+    from xformers.ops.fmha import memory_efficient_attention
+    HAS_MEM_EFF_ATTN = True
+except ImportError:
+    HAS_MEM_EFF_ATTN = False
+    print('please install xformers from https://github.com/facebookresearch/xformers')
+
 if HAS_TRITON:
 
     @triton.jit
@@ -497,3 +504,22 @@ if HAS_FLASH_ATTN:
                                             device=k.device)
         return flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, q_seqlen, kv_seqlen, dropout_p,
                                         sm_scale, causal)
+
+
+if HAS_MEM_EFF_ATTN:
+
+    from einops import rearrange
+    from xformers.ops.fmha import LowerTriangularMask
+
+    class MemoryEfficientAttention(torch.nn.Module):
+
+        def __init__(self, hidden_size: int, num_attention_heads: int, attention_dropout: float = 0.0):
+            super().__init__()
+            attention_head_size = hidden_size // num_attention_heads
+            self.scale = 1 / attention_head_size**0.5
+            self.dropout = attention_dropout
+
+        def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attention_mask: torch.Tensor):
+            context = memory_efficient_attention(query, key, value, attention_mask, self.dropout, self.scale)
+            context = rearrange(context, 'b s h d -> b s (h d)')
+            return context
diff --git a/tests/test_utils/test_flash_attention.py b/tests/test_utils/test_flash_attention.py
index 9d2ee8a18..58e3b21d9 100644
--- a/tests/test_utils/test_flash_attention.py
+++ b/tests/test_utils/test_flash_attention.py
@@ -2,7 +2,7 @@ import pytest
 import torch
 from einops import rearrange
 
-from colossalai.kernel.cuda_native.flash_attention import HAS_FLASH_ATTN, HAS_TRITON
+from colossalai.kernel.cuda_native.flash_attention import HAS_FLASH_ATTN, HAS_MEM_EFF_ATTN, HAS_TRITON
 
 if HAS_FLASH_ATTN:
     from colossalai.kernel.cuda_native.flash_attention import (
@@ -15,6 +15,9 @@ if HAS_FLASH_ATTN:
 if HAS_TRITON:
     from colossalai.kernel.cuda_native.flash_attention import triton_flash_attention
 
+if HAS_MEM_EFF_ATTN:
+    from colossalai.kernel.cuda_native.flash_attention import LowerTriangularMask, MemoryEfficientAttention
+
 
 def baseline_attention(Z, N_CTX, H, q, k, v, sm_scale):
     M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
@@ -124,5 +127,20 @@ def test_masked_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
     out.backward(dout)
 
 
+@pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available")
+@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(6, 8, 4, 16)])
+def test_memory_efficient_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
+    attn = MemoryEfficientAttention(N_CTX * D_HEAD, N_CTX, 0.1)
+
+    q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
+    k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
+    v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
+
+    out = attn(q, k, v, attention_mask=LowerTriangularMask())
+
+    dout = torch.rand_like(out)
+    out.backward(dout)
+
+
 if __name__ == '__main__':
     test_flash_attention(3, 4, 2, 16)
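
For reference, below is a minimal usage sketch of the MemoryEfficientAttention module introduced by this patch. It is not part of the diff: the batch, sequence, and head sizes are illustrative, and it assumes a CUDA device with xformers installed.

# Illustrative only: exercises the MemoryEfficientAttention module added above.
# Inputs follow the (batch, seq, heads, head_dim) layout implied by the
# 'b s h d -> b s (h d)' rearrange in forward(); the sizes below are arbitrary.
import torch

from colossalai.kernel.cuda_native.flash_attention import MemoryEfficientAttention
from xformers.ops.fmha import LowerTriangularMask

batch, seq_len, num_heads, head_dim = 2, 128, 8, 64
hidden_size = num_heads * head_dim

attn = MemoryEfficientAttention(hidden_size, num_heads, attention_dropout=0.1)

q = torch.randn(batch, seq_len, num_heads, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

# Causal (lower-triangular) masking, as in the new unit test.
out = attn(q, k, v, attention_mask=LowerTriangularMask())
print(out.shape)  # torch.Size([2, 128, 512]): heads merged back into hidden_size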