@@ -2,7 +2,7 @@ import pytest
 import torch
 from einops import rearrange
 
-from colossalai.kernel.cuda_native.flash_attention import HAS_FLASH_ATTN, HAS_TRITON, TRITON_AVALIABLE
+from colossalai.kernel.cuda_native.flash_attention import HAS_FLASH_ATTN, HAS_TRITON
 
 if HAS_FLASH_ATTN:
     from colossalai.kernel.cuda_native.flash_attention import flash_attention
@@ -22,7 +22,7 @@ def baseline_attention(Z, N_CTX, H, q, k, v, sm_scale):
     return ref_out
 
 
-@pytest.mark.skipif(HAS_FLASH_ATTN == False, reason="triton is not available")
+@pytest.mark.skipif(HAS_FLASH_ATTN == False, reason="flash is not available")
 @pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 2, 16, 8)])
 def test_triton_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
     torch.manual_seed(20)
@@ -39,7 +39,7 @@ def test_triton_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
     ref_dq, q.grad = q.grad.clone(), None
 
     # triton implementation
-    if TRITON_AVALIABLE:
+    if HAS_TRITON:
         tri_out = triton_flash_attention(q, k, v, sm_scale)
         tri_out.backward(dout)
         tri_dv, v.grad = v.grad.clone(), None
@@ -59,7 +59,7 @@ def test_triton_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
         raise TypeError("Error type not match!")
 
 
-@pytest.mark.skipif(HAS_FLASH_ATTN == False, reason="triton is not available")
+@pytest.mark.skipif(HAS_FLASH_ATTN == False, reason="flash is not available")
 @pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 2, 16, 8)])
 def test_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
     torch.manual_seed(20)
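
For context on the import renamed in the first hunk: HAS_FLASH_ATTN and HAS_TRITON are availability flags exported by the flash_attention module, and the diff switches the guard from the misspelled TRITON_AVALIABLE to HAS_TRITON. The sketch below shows the usual try/except pattern behind such flags; it illustrates the pattern only and is not the actual contents of colossalai.kernel.cuda_native.flash_attention.

# Illustrative sketch only -- the real module defines these flags itself.
try:
    import triton  # noqa: F401
    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False

try:
    import flash_attn  # noqa: F401
    HAS_FLASH_ATTN = True
except ImportError:
    HAS_FLASH_ATTN = False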
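
The hunks above show only fragments of the test bodies (the reference backward pass, the guarded triton call, and the TypeError fallback). As a rough, self-contained sketch of the comparison pattern those fragments imply, the following checks a fused attention kernel against a naive baseline; the tensor layout, sm_scale, tolerances, and the naive baseline itself are assumptions, not the test's actual code.

# Hypothetical sketch of the reference-vs-fused check implied by the hunks
# above; shapes, sm_scale and tolerances are assumptions.
import torch


def naive_attention(q, k, v, sm_scale):
    # q, k, v: (Z, H, N_CTX, D_HEAD)
    scores = torch.matmul(q, k.transpose(-2, -1)) * sm_scale
    probs = torch.softmax(scores.float(), dim=-1).to(q.dtype)
    return torch.matmul(probs, v)


def compare_with_baseline(fused_attention, Z=3, H=2, N_CTX=16, D_HEAD=8, sm_scale=0.3):
    torch.manual_seed(20)
    shape = (Z, H, N_CTX, D_HEAD)
    q = torch.randn(shape, dtype=torch.float16, device="cuda", requires_grad=True)
    k = torch.randn(shape, dtype=torch.float16, device="cuda", requires_grad=True)
    v = torch.randn(shape, dtype=torch.float16, device="cuda", requires_grad=True)
    dout = torch.randn(shape, dtype=torch.float16, device="cuda")

    # Reference forward/backward; clone the grads and reset them so the same
    # leaf tensors can be reused for the fused pass (as in the hunks above).
    ref_out = naive_attention(q, k, v, sm_scale)
    ref_out.backward(dout)
    ref_dq, q.grad = q.grad.clone(), None
    ref_dk, k.grad = k.grad.clone(), None
    ref_dv, v.grad = v.grad.clone(), None

    # Fused forward/backward.
    tri_out = fused_attention(q, k, v, sm_scale)
    tri_out.backward(dout)

    # Compare outputs and gradients within a loose fp16 tolerance.
    for ref, tri in [(ref_out, tri_out), (ref_dq, q.grad), (ref_dk, k.grad), (ref_dv, v.grad)]:
        assert torch.allclose(ref, tri, atol=1e-2, rtol=0)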