[hotfix] polish flash attention (#1802)

pull/1796/head^2
oahzxl 2 years ago committed by GitHub
parent 218c75fd9d
commit 501a9e9cd2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -10,20 +10,6 @@ import subprocess
import torch
try:
import triton
import triton.language as tl
HAS_TRITON = True
except ImportError:
print('please install triton from https://github.com/openai/triton')
HAS_TRITON = False
try:
from flash_attn.flash_attn_interface import flash_attn_unpadded_func
HAS_FLASH_ATTN = True
except ImportError:
HAS_FLASH_ATTN = False
print('please install flash_attn from https://github.com/HazyResearch/flash-attention')
def triton_check():
cuda_home = os.getenv("CUDA_HOME", default="/usr/local/cuda")
@ -38,9 +24,26 @@ def triton_check():
return False
TRITON_AVALIABLE = triton_check()
try:
import triton
import triton.language as tl
if triton_check():
HAS_TRITON = True
else:
print("triton requires cuda >= 11.4")
HAS_TRITON = False
except ImportError:
print('please install triton from https://github.com/openai/triton')
HAS_TRITON = False
try:
from flash_attn.flash_attn_interface import flash_attn_unpadded_func
HAS_FLASH_ATTN = True
except ImportError:
HAS_FLASH_ATTN = False
print('please install flash_attn from https://github.com/HazyResearch/flash-attention')
if TRITON_AVALIABLE:
if HAS_TRITON:
@triton.jit
def _fwd_kernel(
@ -394,7 +397,7 @@ if TRITON_AVALIABLE:
Return:
out: (batch, nheads, seq, headdim)
"""
if TRITON_AVALIABLE:
if HAS_TRITON:
return _TritonFlashAttention.apply(q, k, v, sm_scale)
else:
raise RuntimeError("Triton kernel requires CUDA 11.4+!")

@ -2,7 +2,7 @@ import pytest
import torch
from einops import rearrange
from colossalai.kernel.cuda_native.flash_attention import HAS_FLASH_ATTN, HAS_TRITON, TRITON_AVALIABLE
from colossalai.kernel.cuda_native.flash_attention import HAS_FLASH_ATTN, HAS_TRITON
if HAS_FLASH_ATTN:
from colossalai.kernel.cuda_native.flash_attention import flash_attention
@ -22,7 +22,7 @@ def baseline_attention(Z, N_CTX, H, q, k, v, sm_scale):
return ref_out
@pytest.mark.skipif(HAS_FLASH_ATTN == False, reason="triton is not available")
@pytest.mark.skipif(HAS_FLASH_ATTN == False, reason="flash is not available")
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 2, 16, 8)])
def test_triton_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
torch.manual_seed(20)
@ -39,7 +39,7 @@ def test_triton_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
ref_dq, q.grad = q.grad.clone(), None
# triton implementation
if TRITON_AVALIABLE:
if HAS_TRITON:
tri_out = triton_flash_attention(q, k, v, sm_scale)
tri_out.backward(dout)
tri_dv, v.grad = v.grad.clone(), None
@ -59,7 +59,7 @@ def test_triton_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
raise TypeError("Error type not match!")
@pytest.mark.skipif(HAS_FLASH_ATTN == False, reason="triton is not available")
@pytest.mark.skipif(HAS_FLASH_ATTN == False, reason="flash is not available")
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 2, 16, 8)])
def test_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
torch.manual_seed(20)

Loading…
Cancel
Save