import math

import pytest
import torch
from einops import rearrange

from colossalai.kernel.cuda_native.mha.flash_attn_2 import HAS_FLASH_ATTN
from colossalai.kernel.cuda_native.mha.mem_eff_attn import HAS_MEM_EFF_ATTN
from colossalai.testing import clear_cache_before_run, parameterize

if HAS_MEM_EFF_ATTN or HAS_FLASH_ATTN:
    from colossalai.kernel.cuda_native import ColoAttention
    from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType

DTYPE = [torch.float16, torch.bfloat16, torch.float32]


def attention_ref(q, k, v, attn_mask=None, causal=False):
    """
    Naive attention implementation used as the control group.

    Computes softmax(q @ k^T / sqrt(d)) @ v with optional padding and causal
    masks, and zeroes the output at padded positions.
    """
    dtype_og = q.dtype
    seqlen_q, seqlen_k = q.shape[1], k.shape[1]
    d = q.shape[-1]
    scale = 1.0 / math.sqrt(d)
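    # einsum index convention: b = batch, t = query length, s = key length, h = heads, d = head dimension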
    scores = torch.einsum("bthd,bshd->bhts", q * scale, k)

    if attn_mask is not None:
        scores.masked_fill_(rearrange(~attn_mask, "b s -> b 1 1 s"), float("-inf"))
    if causal:
        causal_mask = torch.triu(torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=q.device), 1)
        scores.masked_fill_(causal_mask, float("-inf"))
    attention = torch.softmax(scores, dim=-1)

    output = torch.einsum("bhts,bshd->bthd", attention, v)
    output = rearrange(output, "b s h d -> b s (h d)")

    # zero the output at padded (masked-out) positions
    if attn_mask is not None:
        output.masked_fill_(rearrange(~attn_mask, "b s -> b s 1"), 0.0)

    return output.to(dtype=dtype_og)


@pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="xformers and flash-attn are not available")
@clear_cache_before_run()
@parameterize("proj_shape", [(6, 8, 4, 16)])
@parameterize("dtype", DTYPE)
@parameterize("dropout", [0.0])
def test_attention_gpt(proj_shape, dtype, dropout):
    (B, S, H, D_HEAD) = proj_shape
    D = H * D_HEAD

    q = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
    k = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
    v = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
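
    # per-sample key-padding masks: sample i keeps its first S - i positions;
    # pad_sequence right-pads with False, giving a bool mask of shape [B, S]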
    mask = [torch.ones(S - i, dtype=torch.bool, device="cuda") for i in range(B)]
    mask = torch.nn.utils.rnn.pad_sequence(mask, batch_first=True)

    attn = ColoAttention(D, H, dropout=dropout)
    y = attn(q, k, v, attn_mask=mask, attn_mask_type=AttnMaskType.paddedcausal)

    assert list(y.shape) == [B, S, D]

    out_ref = attention_ref(q, k, v, mask, causal=True)

    # check gradients against the reference implementation
    dy = torch.rand_like(y)
    grad_q, grad_k, grad_v = torch.autograd.grad(y, (q, k, v), dy)
    grad_ref_q, grad_ref_k, grad_ref_v = torch.autograd.grad(out_ref, (q, k, v), dy)

    assert torch.allclose(y, out_ref, atol=1e-7), f"{(y - out_ref).abs().max()}"
    assert torch.allclose(grad_q, grad_ref_q, atol=1e-7), f"{(grad_q - grad_ref_q).abs().max()}"
    assert torch.allclose(grad_k, grad_ref_k, atol=1e-7), f"{(grad_k - grad_ref_k).abs().max()}"
    assert torch.allclose(grad_v, grad_ref_v, atol=1e-7), f"{(grad_v - grad_ref_v).abs().max()}"


@pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="xformers and flash-attn are not available")
@clear_cache_before_run()
@parameterize("proj_shape", [(6, 8, 4, 16)])
@parameterize("dtype", DTYPE)
@parameterize("dropout", [0.0])
def test_attention_bert(proj_shape, dtype, dropout):
    (B, S, H, D_HEAD) = proj_shape
    D = H * D_HEAD

    q = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
    k = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
    v = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)

    # random key-padding mask of shape [B, S]; False marks a padded position
    mask = torch.randint(0, 2, (B, S), dtype=torch.bool, device="cuda")

    attn = ColoAttention(D, H, dropout=dropout)
    y = attn(q, k, v, attn_mask=mask, attn_mask_type=AttnMaskType.padding)

    assert list(y.shape) == [B, S, D]

    out_ref = attention_ref(q, k, v, mask, causal=False)

    # check gradients against the reference implementation
    dy = torch.rand_like(y)
    grad_q, grad_k, grad_v = torch.autograd.grad(y, (q, k, v), dy)
    grad_ref_q, grad_ref_k, grad_ref_v = torch.autograd.grad(out_ref, (q, k, v), dy)

    assert torch.allclose(y, out_ref, atol=1e-7), f"{(y - out_ref).abs().max()}"
    assert torch.allclose(grad_q, grad_ref_q, atol=1e-7), f"{(grad_q - grad_ref_q).abs().max()}"
    assert torch.allclose(grad_k, grad_ref_k, atol=1e-7), f"{(grad_k - grad_ref_k).abs().max()}"
    assert torch.allclose(grad_v, grad_ref_v, atol=1e-7), f"{(grad_v - grad_ref_v).abs().max()}"


@pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="xformers and flash-attn are not available")
@clear_cache_before_run()
@parameterize("proj_shape", [(6, 8, 4, 16)])
@parameterize("dtype", DTYPE)
@parameterize("dropout", [0.0])
def test_attention_no_mask(proj_shape, dtype, dropout):
    (B, S, H, D_HEAD) = proj_shape
    D = H * D_HEAD

    q = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
    k = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
    v = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
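
    # no attention mask and no causal constraint: full bidirectional self-attention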
    attn = ColoAttention(D, H, dropout=dropout)
    y = attn(q, k, v)

    assert list(y.shape) == [B, S, D]

    out_ref = attention_ref(q, k, v, None, causal=False)

    # check gradients against the reference implementation
    dy = torch.rand_like(y)
    grad_q, grad_k, grad_v = torch.autograd.grad(y, (q, k, v), dy)
    grad_ref_q, grad_ref_k, grad_ref_v = torch.autograd.grad(out_ref, (q, k, v), dy)

    assert torch.allclose(y, out_ref, atol=1e-7), f"{(y - out_ref).abs().max()}"
    assert torch.allclose(grad_q, grad_ref_q, atol=1e-7), f"{(grad_q - grad_ref_q).abs().max()}"
    assert torch.allclose(grad_k, grad_ref_k, atol=1e-7), f"{(grad_k - grad_ref_k).abs().max()}"
    assert torch.allclose(grad_v, grad_ref_v, atol=1e-7), f"{(grad_v - grad_ref_v).abs().max()}"


@pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="xformers and flash-attn are not available")
@clear_cache_before_run()
@parameterize("proj_shape", [(6, 24, 8, 4, 16)])
@parameterize("dtype", DTYPE)
@parameterize("dropout", [0.0])
def test_cross_attention(proj_shape, dtype, dropout):
    (B, S, T, H, D_HEAD) = proj_shape
    D = H * D_HEAD

    q = torch.randn((B, T, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
    k = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
    v = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
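
    # cross attention: T query positions attend to S key/value positions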
    attn = ColoAttention(D, H, dropout=dropout)
    y = attn(q, k, v, attn_mask_type=AttnMaskType.causal)

    assert list(y.shape) == [B, T, D]

    out_ref = attention_ref(q, k, v, None, causal=True)

    # check gradients against the reference implementation
    dy = torch.rand_like(y)
    grad_q, grad_k, grad_v = torch.autograd.grad(y, (q, k, v), dy)
    grad_ref_q, grad_ref_k, grad_ref_v = torch.autograd.grad(out_ref, (q, k, v), dy)

    assert torch.allclose(y, out_ref, atol=1e-7), f"{(y - out_ref).abs().max()}"
    assert torch.allclose(grad_q, grad_ref_q, atol=1e-7), f"{(grad_q - grad_ref_q).abs().max()}"
    assert torch.allclose(grad_k, grad_ref_k, atol=1e-7), f"{(grad_k - grad_ref_k).abs().max()}"
    assert torch.allclose(grad_v, grad_ref_v, atol=1e-7), f"{(grad_v - grad_ref_v).abs().max()}"
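

if __name__ == "__main__":
    # Convenience entry point for running the checks directly (outside pytest),
    # mirroring the `if __name__ == "__main__"` pattern used in other ColossalAI tests.
    # Assumes a CUDA device and at least one attention backend are available, and that
    # @parameterize supplies the swept arguments when the functions are called without them.
    test_attention_gpt()
    test_attention_bert()
    test_attention_no_mask()
    test_cross_attention()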