@@ -2,7 +2,7 @@ import pytest
 import torch
 from einops import rearrange
 
-from colossalai.kernel.cuda_native.flash_attention import HAS_FLASH_ATTN, HAS_TRITON
+from colossalai.kernel.cuda_native.flash_attention import HAS_FLASH_ATTN, HAS_MEM_EFF_ATTN, HAS_TRITON
 
 if HAS_FLASH_ATTN:
     from colossalai.kernel.cuda_native.flash_attention import (
@@ -15,6 +15,9 @@ if HAS_FLASH_ATTN:
 if HAS_TRITON:
     from colossalai.kernel.cuda_native.flash_attention import triton_flash_attention
 
+if HAS_MEM_EFF_ATTN:
+    from colossalai.kernel.cuda_native.flash_attention import LowerTriangularMask, MemoryEfficientAttention
+
 
 def baseline_attention(Z, N_CTX, H, q, k, v, sm_scale):
     M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
@@ -124,5 +127,20 @@ def test_masked_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
     out.backward(dout)
 
 
+@pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available")
+@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(6, 8, 4, 16)])
+def test_memory_efficient_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
+    attn = MemoryEfficientAttention(N_CTX * D_HEAD, N_CTX, 0.1)
+
+    q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
+    k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
+    v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
+
+    out = attn(q, k, v, attention_mask=LowerTriangularMask())
+
+    dout = torch.rand_like(out)
+    out.backward(dout)
+
+
 if __name__ == '__main__':
     test_flash_attention(3, 4, 2, 16)
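
For context, the new test exercises the xformers-backed memory-efficient attention path gated by HAS_MEM_EFF_ATTN. Below is a rough standalone sketch of the same computation, not part of the patch: it assumes xformers and a CUDA device are available, calls xformers.ops.memory_efficient_attention directly rather than ColossalAI's MemoryEfficientAttention wrapper, and checks the causally masked result against a plain PyTorch softmax(QK^T / sqrt(d)) V reference.

    # Sanity-check sketch (assumption: xformers + CUDA available; not the wrapper above).
    # Uses xformers' [batch, seq, heads, head_dim] layout rather than the test's
    # (Z, H, N_CTX, D_HEAD) layout.
    import torch
    import xformers.ops as xops

    B, S, H, D = 6, 4, 8, 16
    q, k, v = (torch.randn(B, S, H, D, dtype=torch.float16, device="cuda") for _ in range(3))

    # Memory-efficient attention with a causal (lower-triangular) mask.
    out = xops.memory_efficient_attention(q, k, v, attn_bias=xops.LowerTriangularMask())

    # Plain PyTorch reference: causally masked softmax(QK^T / sqrt(D)) @ V.
    qh, kh, vh = (t.permute(0, 2, 1, 3).float() for t in (q, k, v))    # [B, H, S, D]
    scores = qh @ kh.transpose(-2, -1) / D**0.5
    causal = torch.tril(torch.ones(S, S, dtype=torch.bool, device="cuda"))
    scores = scores.masked_fill(~causal, float("-inf"))
    ref = (scores.softmax(dim=-1) @ vh).permute(0, 2, 1, 3)            # back to [B, S, H, D]

    assert torch.allclose(out.float(), ref, atol=1e-2, rtol=1e-2)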