import random

import pytest
import torch
from einops import rearrange

from colossalai.kernel.cuda_native.flash_attention import HAS_FLASH_ATTN, HAS_MEM_EFF_ATTN, HAS_TRITON

if HAS_FLASH_ATTN:
    from colossalai.kernel.cuda_native.flash_attention import (
        MaskedFlashAttention,
        flash_attention_q_k_v,
        flash_attention_q_kv,
        flash_attention_qkv,
    )

if HAS_TRITON:
    from colossalai.kernel.cuda_native.flash_attention import triton_flash_attention

if HAS_MEM_EFF_ATTN:
    from colossalai.kernel.cuda_native.flash_attention import LowerTriangularMask, MemoryEfficientAttention
    from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention
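
# HAS_FLASH_ATTN / HAS_TRITON / HAS_MEM_EFF_ATTN are feature flags exported by the
# kernel module; the tests below are skipped when the corresponding backend
# (flash-attn, triton, xformers) is not installed.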


def baseline_attention(Z, N_CTX, H, q, k, v, sm_scale):
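    # NOTE: the body of this reference helper is not shown in this excerpt.
    # Minimal sketch, assuming the usual causal softmax(Q @ K^T * scale) @ V
    # baseline; reconstructed here only so the tests below are runnable.
    M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
    p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
    p[:, :, M == 0] = float("-inf")
    p = torch.softmax(p.float(), dim=-1).to(q.dtype)
    ref_out = torch.matmul(p, v)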
    return ref_out


@pytest.mark.skipif(HAS_TRITON == False, reason="triton is not available")
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 4, 2, 16)])
def test_triton_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
    torch.manual_seed(20)
    q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
    k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
    v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
    sm_scale = 0.3
    dout = torch.randn_like(q)

    # reference implementation
    ref_out = baseline_attention(Z, N_CTX, H, q, k, v, sm_scale)
    ref_out.backward(dout)
    ref_dv, v.grad = v.grad.clone(), None
    ref_dk, k.grad = k.grad.clone(), None
    ref_dq, q.grad = q.grad.clone(), None

    # triton implementation
    tri_out = triton_flash_attention(q, k, v, sm_scale)
    tri_out.backward(dout)
    tri_dv, v.grad = v.grad.clone(), None
    tri_dk, k.grad = k.grad.clone(), None
    tri_dq, q.grad = q.grad.clone(), None

    # compare
    assert torch.allclose(ref_out, tri_out, atol=1e-3)
    assert torch.allclose(ref_dv, tri_dv, atol=1e-3)
    assert torch.allclose(ref_dk, tri_dk, atol=1e-3)
    assert torch.allclose(ref_dq, tri_dq, atol=1e-3)


@pytest.mark.skipif(HAS_FLASH_ATTN == False, reason="flash is not available")
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 4, 2, 16)])
def test_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
    torch.manual_seed(20)
    q = torch.randn((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
    k = torch.randn((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
    v = torch.randn((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
    sm_scale = 0.3
    dout = torch.randn_like(q)

    # reference implementation
    ref_out = baseline_attention(Z, N_CTX, H, q, k, v, sm_scale)
    ref_out.backward(dout)
    ref_dv, v.grad = v.grad.clone(), None
    ref_dk, k.grad = k.grad.clone(), None
    ref_dq, q.grad = q.grad.clone(), None

    # flash implementation
    q, k, v = map(lambda x: rearrange(x, 'z h n d -> (z n) h d'), [q, k, v])
    dout = rearrange(dout, 'z h n d -> (z n) h d').detach()
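    # The loop below exercises the three call signatures exposed by the flash
    # kernels: separate q/k/v, q plus a packed kv tensor, and a fully packed
    # qkv tensor. Each iteration recomputes the same attention and gradients.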
    for i in range(3):
        if i == 0:
            tri_out = flash_attention_q_k_v(q, k, v, sm_scale, Z, N_CTX, N_CTX, causal=True)
        elif i == 1:
            kv = torch.cat((k.unsqueeze(1), v.unsqueeze(1)), dim=1)
            tri_out = flash_attention_q_kv(q, kv, sm_scale, Z, N_CTX, N_CTX, causal=True)
        else:
            qkv = torch.cat((q.unsqueeze(1), k.unsqueeze(1), v.unsqueeze(1)), dim=1)
            tri_out = flash_attention_qkv(qkv, sm_scale, Z, N_CTX, causal=True)

        tri_out.backward(dout, retain_graph=True)

        if i == 0:
            tri_dq, tri_dk, tri_dv, = torch.autograd.grad(tri_out, (q, k, v), dout)
            tri_out, tri_dq, tri_dk, tri_dv = map(lambda x: rearrange(x, '(z n) h d -> z h n d', z=Z),
                                                  (tri_out, tri_dq, tri_dk, tri_dv))
        elif i == 1:
            tri_dq, tri_dkv, = torch.autograd.grad(tri_out, (q, kv), dout)
            tri_dk, tri_dv = torch.chunk(tri_dkv, 2, dim=1)
            tri_out, tri_dq, tri_dk, tri_dv = map(lambda x: rearrange(x, '(z n) h d -> z h n d', z=Z),
                                                  (tri_out, tri_dq, tri_dk.squeeze(1), tri_dv.squeeze(1)))
        else:
            tri_dqkv, = torch.autograd.grad(tri_out, (qkv), dout)
            tri_dq, tri_dk, tri_dv = torch.chunk(tri_dqkv, 3, dim=1)
            tri_out, tri_dq, tri_dk, tri_dv = map(lambda x: rearrange(x, '(z n) h d -> z h n d', z=Z),
                                                  (tri_out, tri_dq.squeeze(1), tri_dk.squeeze(1), tri_dv.squeeze(1)))

        # compare
        assert torch.allclose(ref_out, tri_out, atol=1e-3)
        assert torch.allclose(ref_dv, tri_dv, atol=1e-3)
        assert torch.allclose(ref_dk, tri_dk, atol=1e-3)
        assert torch.allclose(ref_dq, tri_dq, atol=1e-3)


@pytest.mark.skipif(HAS_FLASH_ATTN == False, reason="flash is not available")
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 4, 2, 16)])
def test_masked_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
    attn = MaskedFlashAttention(N_CTX, D_HEAD, 0.1)

    qkv = torch.randn((Z, H, 3 * N_CTX * D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
    attention_mask = torch.randint(2, (Z, H)).cuda().bool()

    out = attn(qkv, attention_mask)

    dout = torch.rand_like(out)
    out.backward(dout)


@pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available")
@pytest.mark.parametrize('B, S, H, D_HEAD', [(6, 8, 4, 16)])
def test_attention_gpt(B, S, H, D_HEAD, dtype=torch.float16):
    D = H * D_HEAD

    c_attn = torch.nn.Linear(D, 3 * D, dtype=dtype, device="cuda")
    attn = ColoAttention(D, H, dropout=0.1)

    x = torch.randn((B, S, D), dtype=dtype, device="cuda")

    qkv = c_attn(x)
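    # split the fused projection into three (B, S, H, D_HEAD) tensors; the
    # leading 'n' axis produced by the rearrange is unpacked into q, k, v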
    q, k, v = rearrange(qkv, 'b s (n h d) -> n b s h d', n=3, h=H)
    y = attn(q, k, v, attn_mask_type=AttnMaskType.causal)

    assert list(y.shape) == [B, S, D]

    dy = torch.rand_like(y)
    y.backward(dy)


@pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available")
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(6, 8, 4, 16)])
def test_memory_efficient_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
    attn = MemoryEfficientAttention(N_CTX * D_HEAD, N_CTX, 0.1)

    q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
    k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
    v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()

    out = attn(q, k, v, attention_mask=LowerTriangularMask())

    dout = torch.rand_like(out)
    out.backward(dout)


@pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available")
@pytest.mark.parametrize('B, S, H, D_HEAD', [(6, 8, 4, 16)])
def test_attention_bert(B, S, H, D_HEAD, dtype=torch.float16):
    D = H * D_HEAD

    c_attn = torch.nn.Linear(D, 3 * D, dtype=dtype, device="cuda")
    attn = ColoAttention(D, H, dropout=0.1)

    x = torch.randn((B, S, D), dtype=dtype, device="cuda")
    # attention mask of shape [B, S] with zero padding to max length S
    mask = [torch.ones(S - i, dtype=dtype, device="cuda") for i in range(B)]
    mask = torch.nn.utils.rnn.pad_sequence(mask, batch_first=True)

    qkv = c_attn(x)
    q, k, v = rearrange(qkv, 'b s (n h d) -> b s n h d', n=3, h=H).unbind(dim=2)
    y = attn(q, k, v, attn_mask=mask, attn_mask_type=AttnMaskType.padding)

    assert list(y.shape) == [B, S, D]

    dy = torch.rand_like(y)
    y.backward(dy)


@pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available")
@pytest.mark.parametrize('B, S, H, D_HEAD', [(6, 8, 4, 16)])
def test_attention_no_mask(B, S, H, D_HEAD, dtype=torch.float16):
    D = H * D_HEAD

    c_attn = torch.nn.Linear(D, 3 * D, dtype=dtype, device="cuda")
    attn = ColoAttention(D, H, dropout=0.1)

    x = torch.randn((B, S, D), dtype=dtype, device="cuda")
    qkv = c_attn(x)
    q, k, v = rearrange(qkv, 'b s (n h d) -> b s n h d', n=3, h=H).unbind(dim=2)
    y = attn(q, k, v)

    assert list(y.shape) == [B, S, D]

    dy = torch.rand_like(y)
    y.backward(dy)


@pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available")
@pytest.mark.parametrize('B, S, T, H, D_HEAD', [(6, 24, 8, 4, 16)])
def test_cross_attention(B, S, T, H, D_HEAD, dtype=torch.float16):
    D = H * D_HEAD

    q_attn = torch.nn.Linear(D, D, dtype=dtype, device="cuda")
    kv_attn = torch.nn.Linear(D, 2 * D, dtype=dtype, device="cuda")

    attn = ColoAttention(D, H, dropout=0.1)

    src = torch.randn((B, S, D), dtype=dtype, device="cuda")
    tgt = torch.randn((B, T, D), dtype=dtype, device="cuda")

    # queries come from the target sequence, keys/values from the source
    q = q_attn(tgt)
    kv = kv_attn(src)
    q = rearrange(q, 'b s (h d) -> b s h d', h=H)
    k, v = rearrange(kv, 'b s (n h d) -> b s n h d', n=2, h=H).unbind(dim=2)
    y = attn(q, k, v, attn_mask_type=AttnMaskType.causal)

    assert list(y.shape) == [B, T, D]

    dy = torch.rand_like(y)
    y.backward(dy)


if __name__ == '__main__':
    test_flash_attention(3, 4, 2, 16)