mirror of https://github.com/hpcaitech/ColossalAI
[kernel] updated unittests for coloattention (#4389)
Updated the coloattention tests to check outputs and gradients.
pull/4396/head
parent
089c365fa0
commit
458ae331ad
|
@ -13,6 +13,7 @@ torchrec==0.2.0
|
|||
contexttimer
|
||||
einops
|
||||
triton==2.0.0.dev20221202
|
||||
git+https://github.com/HazyResearch/flash-attention.git@c422fee3776eb3ea24e011ef641fd5fbeb212623#egg=flash_attn
|
||||
requests==2.27.1 # downgrade to avoid huggingface error https://github.com/huggingface/transformers/issues/17611
|
||||
SentencePiece
|
||||
ninja
|
||||
flash_attn>=2.0
|
||||
|
|
|
@ -10,4 +10,5 @@ contexttimer
|
|||
ninja
|
||||
torch>=1.11
|
||||
safetensors
|
||||
flash_attn>=2.0
|
||||
einops
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
import random
|
||||
import math
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
@ -13,118 +13,158 @@ if HAS_MEM_EFF_ATTN or HAS_FLASH_ATTN:
|
|||
from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType
|
||||
|
||||
# dtypes the tests in this file are parameterized over via @parameterize('dtype', DTYPE)
DTYPE = [torch.float16, torch.bfloat16, torch.float32]
|
||||
# half-precision subset; NOTE(review): not referenced by any test visible here — confirm it is still needed
FLASH_DTYPE = [torch.float16, torch.bfloat16]
|
||||
|
||||
|
||||
def baseline_attention(Z, N_CTX, H, q, k, v, sm_scale):
    """Compute causal scaled-dot-product attention with plain torch ops.

    Args:
        Z: presumably the batch size. Kept for interface compatibility only;
            the causal mask broadcasts over the leading dimensions, so it is
            no longer needed to drive a loop.
        N_CTX: side length of the square lower-triangular causal mask.
        H: presumably the number of heads. Kept for interface compatibility only.
        q: query tensor; its last two dimensions are matched against ``k``.
        k: key tensor; dimensions 2 and 3 are transposed before the matmul,
            so it must be at least 4-D.
        v: value tensor multiplied by the attention probabilities.
        sm_scale: factor applied to the raw attention scores.

    Returns:
        ``softmax(mask(q @ k^T * sm_scale)) @ v``. The softmax is evaluated in
        float32 and the probabilities are cast to ``v.dtype`` before the final
        matmul (identical to the previous ``.half()`` when ``v`` is fp16).
    """
    # Build the mask on the inputs' device instead of hard-coding "cuda".
    keep = torch.tril(torch.ones((N_CTX, N_CTX), dtype=torch.bool, device=q.device))
    scores = torch.matmul(q, k.transpose(2, 3)) * sm_scale
    # One broadcasted fill replaces the former Z*H loop, which re-applied the
    # same loop-invariant mask on every iteration.
    scores = scores.masked_fill(~keep, float("-inf"))
    probs = torch.softmax(scores.float(), dim=-1).to(v.dtype)
    return torch.matmul(probs, v)
|
||||
def attention_ref(q, k, v, attn_mask=None, causal=False):
    """Reference (control-group) attention built only from torch primitives.

    Args:
        q: queries laid out as ``(batch, seqlen_q, heads, head_dim)``
            (the ``bthd`` operand of the einsum below).
        k: keys laid out as ``(batch, seqlen_k, heads, head_dim)`` (``bshd``).
        v: values with the same layout as ``k``.
        attn_mask: optional boolean ``(batch, seqlen)`` tensor. Positions where
            it is ``False`` are excluded as keys, and the same positions are
            zeroed in the output, so ``seqlen`` must broadcast against both
            ``seqlen_k`` and ``seqlen_q``.
        causal: when ``True``, query ``t`` may only attend to keys ``s <= t``
            (strict upper triangle masked, aligned at the top-left corner).

    Returns:
        Tensor of shape ``(batch, seqlen_q, heads * head_dim)`` in ``q``'s dtype.
    """
    input_dtype = q.dtype
    q_len, k_len = q.shape[1], k.shape[1]
    softmax_scale = 1.0 / math.sqrt(q.shape[-1])
    logits = torch.einsum('bthd,bshd->bhts', q * softmax_scale, k)

    # Invert the padding mask once; it is needed for the keys and the output.
    dropped = None if attn_mask is None else ~attn_mask

    if dropped is not None:
        # (b, s) -> (b, 1, 1, s): hide masked keys from every head and query.
        logits.masked_fill_(dropped[:, None, None, :], float('-inf'))
    if causal:
        future = torch.ones(q_len, k_len, dtype=torch.bool, device=q.device).triu(1)
        logits.masked_fill_(future, float('-inf'))
    probs = torch.softmax(logits, dim=-1)

    merged = torch.einsum('bhts,bshd->bthd', probs, v)
    # (b, t, h, d) -> (b, t, h * d)
    merged = merged.flatten(2)

    # Zero the output rows that belong to masked positions.
    if dropped is not None:
        merged.masked_fill_(dropped[:, :, None], 0.0)

    return merged.to(dtype=input_dtype)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="xformers is not available")
@clear_cache_before_run()
@parameterize('proj_shape', [(6, 8, 4, 16)])
@parameterize('dtype', DTYPE)
@parameterize('dropout', [0.0])
def test_attention_gpt(proj_shape, dtype, dropout):
    """Check ColoAttention with a padded-causal mask against ``attention_ref``.

    Both the forward output and the gradients w.r.t. q, k and v are compared.

    Args:
        proj_shape: ``(B, S, H, D_HEAD)`` = batch, sequence length, heads, head dim.
        dtype: tensor dtype under test.
        dropout: dropout probability passed to ColoAttention (0.0 keeps the
            comparison deterministic).
    """
    (B, S, H, D_HEAD) = proj_shape
    D = H * D_HEAD

    c_attn = torch.nn.Linear(D, 3 * D, dtype=dtype, device="cuda")
    x = torch.randn((B, S, D), dtype=dtype, device="cuda")

    qkv = c_attn(x)
    q, k, v = rearrange(qkv, 'b s (n h d) -> n b s h d', n=3, h=H)

    # Sequence i keeps its first S - i tokens; the rest is padding (False).
    mask = [torch.ones(S - i, dtype=torch.bool, device="cuda") for i in range(B)]
    mask = torch.nn.utils.rnn.pad_sequence(mask, batch_first=True)

    attn = ColoAttention(D, H, dropout=dropout)
    y = attn(q, k, v, attn_mask=mask, attn_mask_type=AttnMaskType.paddedcausal)

    assert list(y.shape) == [B, S, D]

    out_ref = attention_ref(q, k, v, mask, causal=True)

    # check gradients
    dy = torch.rand_like(y)
    grad_q, grad_k, grad_v = torch.autograd.grad(y, (q, k, v), dy)
    grad_ref_q, grad_ref_k, grad_ref_v = torch.autograd.grad(out_ref, (q, k, v), dy)

    # The comparisons used to be bare expressions, so nothing was verified.
    # NOTE(review): atol=1e-7 is not reachable in fp16/bf16; these per-dtype
    # tolerances are a starting point — tune them on the target hardware.
    tol = {torch.float32: 1e-3, torch.float16: 1e-2, torch.bfloat16: 5e-2}[dtype]
    assert torch.allclose(y, out_ref, atol=tol, rtol=tol), f"{(y - out_ref).abs().max()}"
    assert torch.allclose(grad_q, grad_ref_q, atol=tol, rtol=tol), f"{(grad_q - grad_ref_q).abs().max()}"
    assert torch.allclose(grad_k, grad_ref_k, atol=tol, rtol=tol), f"{(grad_k - grad_ref_k).abs().max()}"
    assert torch.allclose(grad_v, grad_ref_v, atol=tol, rtol=tol), f"{(grad_v - grad_ref_v).abs().max()}"
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="xformers is not available")
@clear_cache_before_run()
@parameterize('proj_shape', [(6, 8, 4, 16)])
@parameterize('dtype', DTYPE)
@parameterize('dropout', [0.0])
def test_attention_bert(proj_shape, dtype, dropout):
    """Check ColoAttention with a padding mask against ``attention_ref``.

    Both the forward output and the gradients w.r.t. q, k and v are compared.

    Args:
        proj_shape: ``(B, S, H, D_HEAD)`` = batch, sequence length, heads, head dim.
        dtype: tensor dtype under test.
        dropout: dropout probability passed to ColoAttention (0.0 keeps the
            comparison deterministic).
    """
    (B, S, H, D_HEAD) = proj_shape
    D = H * D_HEAD

    c_attn = torch.nn.Linear(D, 3 * D, dtype=dtype, device="cuda")
    x = torch.randn((B, S, D), dtype=dtype, device="cuda")

    # random boolean attention mask of shape [B, S]; True marks a valid token
    mask = torch.randint(0, 2, (B, S), dtype=torch.bool, device="cuda")
    # Guarantee at least one valid token per sequence: a fully masked row makes
    # the reference softmax (and its gradients) NaN.
    mask[:, 0] = True

    qkv = c_attn(x)
    q, k, v = rearrange(qkv, 'b s (n h d) -> b s n h d', n=3, h=H).unbind(dim=2)

    attn = ColoAttention(D, H, dropout=dropout)
    y = attn(q, k, v, attn_mask=mask, attn_mask_type=AttnMaskType.padding)

    assert list(y.shape) == [B, S, D]

    out_ref = attention_ref(q, k, v, mask, causal=False)

    # check gradients
    dy = torch.rand_like(y)
    grad_q, grad_k, grad_v = torch.autograd.grad(y, (q, k, v), dy)
    grad_ref_q, grad_ref_k, grad_ref_v = torch.autograd.grad(out_ref, (q, k, v), dy)

    # The comparisons used to be bare expressions, so nothing was verified.
    # NOTE(review): atol=1e-7 is not reachable in fp16/bf16; these per-dtype
    # tolerances are a starting point — tune them on the target hardware.
    tol = {torch.float32: 1e-3, torch.float16: 1e-2, torch.bfloat16: 5e-2}[dtype]
    assert torch.allclose(y, out_ref, atol=tol, rtol=tol), f"{(y - out_ref).abs().max()}"
    assert torch.allclose(grad_q, grad_ref_q, atol=tol, rtol=tol), f"{(grad_q - grad_ref_q).abs().max()}"
    assert torch.allclose(grad_k, grad_ref_k, atol=tol, rtol=tol), f"{(grad_k - grad_ref_k).abs().max()}"
    assert torch.allclose(grad_v, grad_ref_v, atol=tol, rtol=tol), f"{(grad_v - grad_ref_v).abs().max()}"
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="xformers is not available")
@clear_cache_before_run()
@parameterize('proj_shape', [(6, 8, 4, 16)])
@parameterize('dtype', DTYPE)
@parameterize('dropout', [0.0])
def test_attention_no_mask(proj_shape, dtype, dropout):
    """Check unmasked ColoAttention against ``attention_ref``.

    Both the forward output and the gradients w.r.t. q, k and v are compared.

    Args:
        proj_shape: ``(B, S, H, D_HEAD)`` = batch, sequence length, heads, head dim.
        dtype: tensor dtype under test.
        dropout: dropout probability passed to ColoAttention (0.0 keeps the
            comparison deterministic).
    """
    (B, S, H, D_HEAD) = proj_shape
    D = H * D_HEAD

    c_attn = torch.nn.Linear(D, 3 * D, dtype=dtype, device="cuda")
    x = torch.randn((B, S, D), dtype=dtype, device="cuda")

    qkv = c_attn(x)
    q, k, v = rearrange(qkv, 'b s (n h d) -> b s n h d', n=3, h=H).unbind(dim=2)

    attn = ColoAttention(D, H, dropout=dropout)
    y = attn(q, k, v)

    assert list(y.shape) == [B, S, D]

    out_ref = attention_ref(q, k, v, None, causal=False)

    # check gradients
    dy = torch.rand_like(y)
    grad_q, grad_k, grad_v = torch.autograd.grad(y, (q, k, v), dy)
    grad_ref_q, grad_ref_k, grad_ref_v = torch.autograd.grad(out_ref, (q, k, v), dy)

    # The comparisons used to be bare expressions, so nothing was verified.
    # NOTE(review): atol=1e-7 is not reachable in fp16/bf16; these per-dtype
    # tolerances are a starting point — tune them on the target hardware.
    tol = {torch.float32: 1e-3, torch.float16: 1e-2, torch.bfloat16: 5e-2}[dtype]
    assert torch.allclose(y, out_ref, atol=tol, rtol=tol), f"{(y - out_ref).abs().max()}"
    assert torch.allclose(grad_q, grad_ref_q, atol=tol, rtol=tol), f"{(grad_q - grad_ref_q).abs().max()}"
    assert torch.allclose(grad_k, grad_ref_k, atol=tol, rtol=tol), f"{(grad_k - grad_ref_k).abs().max()}"
    assert torch.allclose(grad_v, grad_ref_v, atol=tol, rtol=tol), f"{(grad_v - grad_ref_v).abs().max()}"
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="xformers is not available")
@clear_cache_before_run()
@parameterize('proj_shape', [(6, 24, 8, 4, 16)])
@parameterize('dtype', DTYPE)
@parameterize('dropout', [0.0])
def test_cross_attention(proj_shape, dtype, dropout):
    """Check causal cross-attention in ColoAttention against ``attention_ref``.

    Queries come from a target sequence of length ``T`` and keys/values from a
    source sequence of length ``S``. Both the forward output and the gradients
    w.r.t. q, k and v are compared.

    Args:
        proj_shape: ``(B, S, T, H, D_HEAD)`` = batch, source length, target
            length, heads, head dim.
        dtype: tensor dtype under test.
        dropout: dropout probability passed to ColoAttention (0.0 keeps the
            comparison deterministic).
    """
    (B, S, T, H, D_HEAD) = proj_shape
    D = H * D_HEAD

    q_attn = torch.nn.Linear(D, D, dtype=dtype, device="cuda")
    kv_attn = torch.nn.Linear(D, 2 * D, dtype=dtype, device="cuda")

    src = torch.randn((B, S, D), dtype=dtype, device="cuda")
    tgt = torch.randn((B, T, D), dtype=dtype, device="cuda")

    q = q_attn(tgt)
    kv = kv_attn(src)
    q = rearrange(q, 'b s (h d) -> b s h d', h=H)
    k, v = rearrange(kv, 'b s (n h d) -> b s n h d', n=2, h=H).unbind(dim=2)

    attn = ColoAttention(D, H, dropout=dropout)
    # NOTE(review): the reference aligns the causal mask at the top-left corner;
    # confirm the kernel uses the same alignment when T != S.
    y = attn(q, k, v, attn_mask_type=AttnMaskType.causal)

    assert list(y.shape) == [B, T, D]

    out_ref = attention_ref(q, k, v, None, causal=True)

    # check gradients
    dy = torch.rand_like(y)
    grad_q, grad_k, grad_v = torch.autograd.grad(y, (q, k, v), dy)
    grad_ref_q, grad_ref_k, grad_ref_v = torch.autograd.grad(out_ref, (q, k, v), dy)

    # The comparisons used to be bare expressions, so nothing was verified.
    # NOTE(review): atol=1e-18 / 1e-7 is not reachable in fp16/bf16; these
    # per-dtype tolerances are a starting point — tune them on the target hardware.
    tol = {torch.float32: 1e-3, torch.float16: 1e-2, torch.bfloat16: 5e-2}[dtype]
    assert torch.allclose(y, out_ref, atol=tol, rtol=tol), f"{(y - out_ref).abs().max()}"
    assert torch.allclose(grad_q, grad_ref_q, atol=tol, rtol=tol), f"{(grad_q - grad_ref_q).abs().max()}"
    assert torch.allclose(grad_k, grad_ref_k, atol=tol, rtol=tol), f"{(grad_k - grad_ref_k).abs().max()}"
    assert torch.allclose(grad_v, grad_ref_v, atol=tol, rtol=tol), f"{(grad_v - grad_ref_v).abs().max()}"
|
Loading…
Reference in New Issue