@@ -2,7 +2,7 @@ import pytest
 import torch
 from einops import rearrange
 
-from colossalai.kernel.cuda_native.flash_attention import HAS_FLASH_ATTN, HAS_TRITON, TRITON_AVALIABLE
+from colossalai.kernel.cuda_native.flash_attention import HAS_FLASH_ATTN, HAS_TRITON
 
 if HAS_FLASH_ATTN:
     from colossalai.kernel.cuda_native.flash_attention import flash_attention
@@ -22,7 +22,7 @@ def baseline_attention(Z, N_CTX, H, q, k, v, sm_scale):
     return ref_out
 
 
-@pytest.mark.skipif(HAS_FLASH_ATTN == False, reason="triton is not available")
+@pytest.mark.skipif(HAS_FLASH_ATTN == False, reason="flash is not available")
 @pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 2, 16, 8)])
 def test_triton_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
     torch.manual_seed(20)
@@ -39,7 +39,7 @@ def test_triton_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
     ref_dq, q.grad = q.grad.clone(), None
 
     # triton implementation
-    if TRITON_AVALIABLE:
+    if HAS_TRITON:
         tri_out = triton_flash_attention(q, k, v, sm_scale)
         tri_out.backward(dout)
         tri_dv, v.grad = v.grad.clone(), None
@@ -59,7 +59,7 @@ def test_triton_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
         raise TypeError("Error type not match!")
 
 
-@pytest.mark.skipif(HAS_FLASH_ATTN == False, reason="triton is not available")
+@pytest.mark.skipif(HAS_FLASH_ATTN == False, reason="flash is not available")
 @pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 2, 16, 8)])
 def test_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
     torch.manual_seed(20)