From 501a9e9cd24a52dfa46118c54229cf5b8fa354e3 Mon Sep 17 00:00:00 2001
From: oahzxl <43881818+oahzxl@users.noreply.github.com>
Date: Mon, 7 Nov 2022 14:30:22 +0800
Subject: [PATCH] [hotfix] polish flash attention (#1802)

---
 .../kernel/cuda_native/flash_attention.py | 37 ++++++++++---------
 tests/test_utils/test_flash_attention.py  |  8 ++--
 2 files changed, 24 insertions(+), 21 deletions(-)

diff --git a/colossalai/kernel/cuda_native/flash_attention.py b/colossalai/kernel/cuda_native/flash_attention.py
index 91273622f..d037b89f8 100644
--- a/colossalai/kernel/cuda_native/flash_attention.py
+++ b/colossalai/kernel/cuda_native/flash_attention.py
@@ -10,20 +10,6 @@ import subprocess
 
 import torch
 
-try:
-    import triton
-    import triton.language as tl
-    HAS_TRITON = True
-except ImportError:
-    print('please install triton from https://github.com/openai/triton')
-    HAS_TRITON = False
-try:
-    from flash_attn.flash_attn_interface import flash_attn_unpadded_func
-    HAS_FLASH_ATTN = True
-except ImportError:
-    HAS_FLASH_ATTN = False
-    print('please install flash_attn from https://github.com/HazyResearch/flash-attention')
-
 
 def triton_check():
     cuda_home = os.getenv("CUDA_HOME", default="/usr/local/cuda")
@@ -38,9 +24,26 @@ def triton_check():
     return False
 
 
-TRITON_AVALIABLE = triton_check()
+try:
+    import triton
+    import triton.language as tl
+    if triton_check():
+        HAS_TRITON = True
+    else:
+        print("triton requires cuda >= 11.4")
+        HAS_TRITON = False
+except ImportError:
+    print('please install triton from https://github.com/openai/triton')
+    HAS_TRITON = False
+try:
+    from flash_attn.flash_attn_interface import flash_attn_unpadded_func
+    HAS_FLASH_ATTN = True
+except ImportError:
+    HAS_FLASH_ATTN = False
+    print('please install flash_attn from https://github.com/HazyResearch/flash-attention')
+
 
-if TRITON_AVALIABLE:
+if HAS_TRITON:
 
     @triton.jit
     def _fwd_kernel(
@@ -394,7 +397,7 @@ if TRITON_AVALIABLE:
         Return:
             out: (batch, nheads, seq, headdim)
         """
-        if TRITON_AVALIABLE:
+        if HAS_TRITON:
             return _TritonFlashAttention.apply(q, k, v, sm_scale)
         else:
             raise RuntimeError("Triton kernel requires CUDA 11.4+!")
diff --git a/tests/test_utils/test_flash_attention.py b/tests/test_utils/test_flash_attention.py
index 41b145c58..195de0d28 100644
--- a/tests/test_utils/test_flash_attention.py
+++ b/tests/test_utils/test_flash_attention.py
@@ -2,7 +2,7 @@ import pytest
 import torch
 from einops import rearrange
 
-from colossalai.kernel.cuda_native.flash_attention import HAS_FLASH_ATTN, HAS_TRITON, TRITON_AVALIABLE
+from colossalai.kernel.cuda_native.flash_attention import HAS_FLASH_ATTN, HAS_TRITON
 
 if HAS_FLASH_ATTN:
     from colossalai.kernel.cuda_native.flash_attention import flash_attention
@@ -22,7 +22,7 @@ def baseline_attention(Z, N_CTX, H, q, k, v, sm_scale):
     return ref_out
 
 
-@pytest.mark.skipif(HAS_FLASH_ATTN == False, reason="triton is not available")
+@pytest.mark.skipif(HAS_FLASH_ATTN == False, reason="flash is not available")
 @pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 2, 16, 8)])
 def test_triton_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
     torch.manual_seed(20)
@@ -39,7 +39,7 @@ def test_triton_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
     ref_dq, q.grad = q.grad.clone(), None
 
     # triton implementation
-    if TRITON_AVALIABLE:
+    if HAS_TRITON:
         tri_out = triton_flash_attention(q, k, v, sm_scale)
         tri_out.backward(dout)
         tri_dv, v.grad = v.grad.clone(), None
@@ -59,7 +59,7 @@ def test_triton_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
         raise TypeError("Error type not match!")
 
 
-@pytest.mark.skipif(HAS_FLASH_ATTN == False, reason="triton is not available")
+@pytest.mark.skipif(HAS_FLASH_ATTN == False, reason="flash is not available")
 @pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 2, 16, 8)])
 def test_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
     torch.manual_seed(20)
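
Note (not part of the commit): a minimal sketch of how the availability flags consolidated by this patch might be consumed by downstream code. The tensor shape (batch=3, nheads=2, seq=16, headdim=8) mirrors the test parametrization; the sm_scale value of 0.3, the fp16 dtype, and the presence of a CUDA device are illustrative assumptions. The import of triton_flash_attention is placed behind the HAS_TRITON check, since the hunk context above suggests it is only defined when the Triton path is usable.

import torch

from colossalai.kernel.cuda_native.flash_attention import HAS_TRITON

if HAS_TRITON:
    from colossalai.kernel.cuda_native.flash_attention import triton_flash_attention

    # Illustrative inputs: (batch, nheads, seq, headdim) = (3, 2, 16, 8), fp16 on a CUDA device.
    q, k, v = (torch.randn(3, 2, 16, 8, dtype=torch.float16, device="cuda") for _ in range(3))
    out = triton_flash_attention(q, k, v, 0.3)    # 0.3 is an assumed sm_scale, not taken from the patch
else:
    # Triton is missing or CUDA < 11.4: the kernel path is unavailable, so fall back elsewhere.
    print("triton flash attention unavailable")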