# Mirror of https://github.com/hpcaitech/ColossalAI (83 lines, 3.5 KiB, Python)
import torch
|
||
|
import pytest
|
||
|
from einops import rearrange
|
||
|
from colossalai.kernel.cuda_native.flash_attention import flash_attention, triton_flash_attention, TRITON_AVALIABLE
|
||
|
|
||
|
|
||
|
def baseline_attention(Z, N_CTX, H, q, k, v, sm_scale):
    """Reference causal attention used to validate the fused kernels.

    Computes ``softmax(causal_mask(q @ k^T * sm_scale)) @ v``. The softmax is
    done in float32, and the result is cast back to ``v``'s dtype.

    Args:
        Z: Batch size. Not used by the computation; kept for call compatibility.
        N_CTX: Sequence length; the side of the square causal mask.
        H: Number of heads. Not used by the computation; kept for call
            compatibility.
        q, k, v: 4-D attention inputs whose last two dims are
            ``(N_CTX, D_HEAD)``; the tests pass ``(Z, H, N_CTX, D_HEAD)``.
        sm_scale: Scale applied to the attention scores before the softmax.

    Returns:
        The attention output, shaped like ``q @ k^T @ v`` and in ``v``'s dtype.
    """
    # Build the mask on the inputs' device rather than a hard-coded "cuda",
    # so the reference also runs on CPU tensors.
    causal = torch.tril(torch.ones((N_CTX, N_CTX), device=q.device))
    scores = torch.matmul(q, k.transpose(2, 3)) * sm_scale
    # The (N_CTX, N_CTX) mask broadcasts over the leading batch/head dims, so a
    # single fill masks every (z, h) slice at once.
    scores = scores.masked_fill(causal == 0, float("-inf"))
    # Softmax in float32 for stability, then return to v's dtype so the final
    # matmul operands agree (identical to .half() for float16 inputs).
    probs = torch.softmax(scores.float(), dim=-1).to(v.dtype)
    return torch.matmul(probs, v)
|
||
|
|
||
|
|
||
|
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 2, 16, 8)])
def test_triton_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
    """Compare the Triton flash-attention kernel with ``baseline_attention``.

    When ``TRITON_AVALIABLE`` is true, the forward output and the gradients
    w.r.t. ``q``, ``k`` and ``v`` must match the reference within
    ``atol=1e-3``. Otherwise, the test expects the ``flash_attention`` call
    below to raise ``RuntimeError``.
    """
    torch.manual_seed(20)
    q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
    k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
    v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
    sm_scale = 0.3
    dout = torch.randn_like(q)

    # reference implementation
    ref_out = baseline_attention(Z, N_CTX, H, q, k, v, sm_scale)
    ref_out.backward(dout)
    # Stash the reference grads and clear .grad so the next backward pass
    # does not accumulate on top of them.
    ref_dv, v.grad = v.grad.clone(), None
    ref_dk, k.grad = k.grad.clone(), None
    ref_dq, q.grad = q.grad.clone(), None

    # triton implementation
    if TRITON_AVALIABLE:
        tri_out = triton_flash_attention(q, k, v, sm_scale)
        tri_out.backward(dout)
        tri_dv, v.grad = v.grad.clone(), None
        tri_dk, k.grad = k.grad.clone(), None
        tri_dq, q.grad = q.grad.clone(), None
        # compare
        assert torch.allclose(ref_out, tri_out, atol=1e-3)
        assert torch.allclose(ref_dv, tri_dv, atol=1e-3)
        assert torch.allclose(ref_dk, tri_dk, atol=1e-3)
        assert torch.allclose(ref_dq, tri_dq, atol=1e-3)
    else:
        # NOTE(review): this branch calls `flash_attention` (not the Triton
        # kernel) with the un-flattened 4-D inputs and expects RuntimeError;
        # confirm this is the intended check.
        with pytest.raises(RuntimeError):
            flash_attention(q, k, v, sm_scale, Z, N_CTX)
|
||
|
|
||
|
|
||
|
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 2, 16, 8)])
def test_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
    """Compare ``flash_attention`` with ``baseline_attention``.

    The test passes ``flash_attention`` inputs flattened from
    ``(Z, H, N_CTX, D_HEAD)`` to ``(Z * N_CTX, H, D_HEAD)``. Its output and
    gradients are unflattened back before being compared with the reference
    within ``atol=1e-3``.
    """
    torch.manual_seed(20)
    # randn(...) followed by normal_(...) draws twice from the RNG; kept as-is
    # so the seeded test data stays unchanged.
    q = torch.randn((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
    k = torch.randn((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
    v = torch.randn((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
    sm_scale = 0.3
    dout = torch.randn_like(q)

    # reference implementation
    ref_out = baseline_attention(Z, N_CTX, H, q, k, v, sm_scale)
    ref_out.backward(dout)
    ref_dv, v.grad = v.grad.clone(), None
    ref_dk, k.grad = k.grad.clone(), None
    ref_dq, q.grad = q.grad.clone(), None

    # flash implementation on flattened views of the leaf tensors
    q_flat, k_flat, v_flat = (rearrange(x, 'z h n d -> (z n) h d') for x in (q, k, v))
    tri_out = flash_attention(q_flat, k_flat, v_flat, sm_scale, Z, N_CTX)
    dout_flat = rearrange(dout, 'z h n d -> (z n) h d').detach()
    # One autograd.grad call gives the gradients w.r.t. the flattened inputs;
    # a separate .backward() pass would only duplicate the work.
    tri_dq, tri_dk, tri_dv = torch.autograd.grad(tri_out, (q_flat, k_flat, v_flat), dout_flat)
    tri_out, tri_dq, tri_dk, tri_dv = (
        rearrange(x, '(z n) h d -> z h n d', z=Z) for x in (tri_out, tri_dq, tri_dk, tri_dv)
    )

    # compare
    assert torch.allclose(ref_out, tri_out, atol=1e-3)
    assert torch.allclose(ref_dv, tri_dv, atol=1e-3)
    assert torch.allclose(ref_dk, tri_dk, atol=1e-3)
    assert torch.allclose(ref_dq, tri_dq, atol=1e-3)
|