diff --git a/requirements/requirements-test.txt b/requirements/requirements-test.txt
index 9f6580c72..e65271621 100644
--- a/requirements/requirements-test.txt
+++ b/requirements/requirements-test.txt
@@ -13,6 +13,7 @@ torchrec==0.2.0
 contexttimer
 einops
 triton==2.0.0.dev20221202
-git+https://github.com/HazyResearch/flash-attention.git@c422fee3776eb3ea24e011ef641fd5fbeb212623#egg=flash_attn
 requests==2.27.1 # downgrade to avoid huggingface error https://github.com/huggingface/transformers/issues/17611
 SentencePiece
+ninja
+flash_attn>=2.0
diff --git a/requirements/requirements.txt b/requirements/requirements.txt
index f6be6a624..65eecce2c 100644
--- a/requirements/requirements.txt
+++ b/requirements/requirements.txt
@@ -10,4 +10,5 @@ contexttimer
 ninja
 torch>=1.11
 safetensors
+flash_attn>=2.0
 einops
diff --git a/tests/test_utils/test_flash_attention.py b/tests/test_utils/test_flash_attention.py
index fbcc45265..e1c7446f4 100644
--- a/tests/test_utils/test_flash_attention.py
+++ b/tests/test_utils/test_flash_attention.py
@@ -1,4 +1,4 @@
-import random
+import math
 
 import pytest
 import torch
@@ -13,118 +13,158 @@ if HAS_MEM_EFF_ATTN or HAS_FLASH_ATTN:
     from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType
 
 DTYPE = [torch.float16, torch.bfloat16, torch.float32]
+FLASH_DTYPE = [torch.float16, torch.bfloat16]
 
 
-def baseline_attention(Z, N_CTX, H, q, k, v, sm_scale):
-    M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
-    p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
-    for z in range(Z):
-        for h in range(H):
-            p[:, :, M == 0] = float("-inf")
-    p = torch.softmax(p.float(), dim=-1).half()
-    ref_out = torch.matmul(p, v)
-    return ref_out
+def attention_ref(q, k, v, attn_mask=None, causal=False):
+    """
+    Reference attention output used as the control group.
+    """
+    dtype_og = q.dtype
+    seqlen_q, seqlen_k = q.shape[1], k.shape[1]
+    d = q.shape[-1]
+    scale = 1.0 / math.sqrt(d)
+    scores = torch.einsum('bthd,bshd->bhts', q * scale, k)
+
+    if attn_mask is not None:
+        scores.masked_fill_(rearrange(~attn_mask, 'b s -> b 1 1 s'), float('-inf'))
+    if causal:
+        causal_mask = torch.triu(torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=q.device), 1)
+        scores.masked_fill_(causal_mask, float('-inf'))
+    attention = torch.softmax(scores, dim=-1)
+
+    output = torch.einsum('bhts,bshd->bthd', attention, v)
+    output = rearrange(output, "b s h d -> b s (h d)")
+
+    # Zero the output at masked (padding) positions
+    if attn_mask is not None:
+        output.masked_fill_(rearrange(~attn_mask, 'b s -> b s 1'), 0.0)
+
+    return output.to(dtype=dtype_og)
 
 
 @pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="xformers is not available")
 @clear_cache_before_run()
-@parameterize('proj_shape', [(1, 8, 4, 16)])
+@parameterize('proj_shape', [(6, 8, 4, 16)])
 @parameterize('dtype', DTYPE)
-def test_attention_gpt(proj_shape, dtype):
-    # TODO check output value
+@parameterize('dropout', [0.0])
+def test_attention_gpt(proj_shape, dtype, dropout):
     (B, S, H, D_HEAD) = proj_shape
     D = H * D_HEAD
 
-    c_attn = torch.nn.Linear(D, 3 * D, dtype=dtype, device="cuda")
-    attn = ColoAttention(D, H, dropout=0.1)
+    q = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
+    k = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
+    v = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
 
-    x = torch.randn((B, S, D), dtype=dtype, device="cuda")
-
-    qkv = c_attn(x)
-    q, k, v = rearrange(qkv, 'b s (n h d) -> n b s h d', n=3, h=H)
-
-    mask = [torch.ones(S - i, dtype=dtype, device="cuda") for i in range(B)]
+    mask = [torch.ones(S - i, dtype=torch.bool, device="cuda") for i in range(B)]
     mask = torch.nn.utils.rnn.pad_sequence(mask, batch_first=True)
 
+    attn = ColoAttention(D, H, dropout=dropout)
     y = attn(q, k, v, attn_mask=mask, attn_mask_type=AttnMaskType.paddedcausal)
 
     assert list(y.shape) == [B, S, D]
 
+    out_ref = attention_ref(q, k, v, mask, causal=True)
+
+    # check gradients
     dy = torch.rand_like(y)
-    y.backward(dy)
+    grad_q, grad_k, grad_v = torch.autograd.grad(y, (q, k, v), dy)
+    grad_ref_q, grad_ref_k, grad_ref_v = torch.autograd.grad(out_ref, (q, k, v), dy)
+
+    assert torch.allclose(y, out_ref, atol=1e-7), f"{(y - out_ref).abs().max()}"
+    assert torch.allclose(grad_q, grad_ref_q, atol=1e-7), f"{(grad_q - grad_ref_q).abs().max()}"
+    assert torch.allclose(grad_k, grad_ref_k, atol=1e-7), f"{(grad_k - grad_ref_k).abs().max()}"
+    assert torch.allclose(grad_v, grad_ref_v, atol=1e-7), f"{(grad_v - grad_ref_v).abs().max()}"
 
 
 @pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="xformers is not available")
 @clear_cache_before_run()
 @parameterize('proj_shape', [(6, 8, 4, 16)])
 @parameterize('dtype', DTYPE)
-def test_attention_bert(proj_shape, dtype):
+@parameterize('dropout', [0.0])
+def test_attention_bert(proj_shape, dtype, dropout):
     (B, S, H, D_HEAD) = proj_shape
     D = H * D_HEAD
 
-    c_attn = torch.nn.Linear(D, 3 * D, dtype=dtype, device="cuda")
-    attn = ColoAttention(D, H, dropout=0.1)
+    q = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
+    k = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
+    v = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
 
-    x = torch.randn((B, S, D), dtype=dtype, device="cuda")
 
     # attention mask of shape [B, S] with zero padding to max length S
-    mask = [torch.ones(S - i, dtype=dtype, device="cuda") for i in range(B)]
-    mask = torch.nn.utils.rnn.pad_sequence(mask, batch_first=True)
+    mask = torch.randint(0, 2, (B, S), dtype=torch.bool, device="cuda")
 
-    qkv = c_attn(x)
-    q, k, v = rearrange(qkv, 'b s (n h d) -> b s n h d', n=3, h=H).unbind(dim=2)
+    attn = ColoAttention(D, H, dropout=dropout)
     y = attn(q, k, v, attn_mask=mask, attn_mask_type=AttnMaskType.padding)
 
     assert list(y.shape) == [B, S, D]
 
+    out_ref = attention_ref(q, k, v, mask, causal=False)
+
     dy = torch.rand_like(y)
-    y.backward(dy)
+    grad_q, grad_k, grad_v = torch.autograd.grad(y, (q, k, v), dy)
+    grad_ref_q, grad_ref_k, grad_ref_v = torch.autograd.grad(out_ref, (q, k, v), dy)
+
+    assert torch.allclose(y, out_ref, atol=1e-7), f"{(y - out_ref).abs().max()}"
+    assert torch.allclose(grad_q, grad_ref_q, atol=1e-7), f"{(grad_q - grad_ref_q).abs().max()}"
+    assert torch.allclose(grad_k, grad_ref_k, atol=1e-7), f"{(grad_k - grad_ref_k).abs().max()}"
+    assert torch.allclose(grad_v, grad_ref_v, atol=1e-7), f"{(grad_v - grad_ref_v).abs().max()}"
 
 
 @pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="xformers is not available")
 @clear_cache_before_run()
 @parameterize('proj_shape', [(6, 8, 4, 16)])
 @parameterize('dtype', DTYPE)
-def test_attention_no_mask(proj_shape, dtype):
+@parameterize('dropout', [0.0])
+def test_attention_no_mask(proj_shape, dtype, dropout):
     (B, S, H, D_HEAD) = proj_shape
     D = H * D_HEAD
 
-    c_attn = torch.nn.Linear(D, 3 * D, dtype=dtype, device="cuda")
-    attn = ColoAttention(D, H, dropout=0.1)
+    q = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
+    k = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
+    v = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
 
-    x = torch.randn((B, S, D), dtype=dtype, device="cuda")
-    qkv = c_attn(x)
-    q, k, v = rearrange(qkv, 'b s (n h d) -> b s n h d', n=3, h=H).unbind(dim=2)
+    attn = ColoAttention(D, H, dropout=dropout)
     y = attn(q, k, v)
 
     assert list(y.shape) == [B, S, D]
 
+    out_ref = attention_ref(q, k, v, None, causal=False)
+
     dy = torch.rand_like(y)
-    y.backward(dy)
+    grad_q, grad_k, grad_v = torch.autograd.grad(y, (q, k, v), dy)
+    grad_ref_q, grad_ref_k, grad_ref_v = torch.autograd.grad(out_ref, (q, k, v), dy)
+
+    assert torch.allclose(y, out_ref, atol=1e-7), f"{(y - out_ref).abs().max()}"
+    assert torch.allclose(grad_q, grad_ref_q, atol=1e-7), f"{(grad_q - grad_ref_q).abs().max()}"
+    assert torch.allclose(grad_k, grad_ref_k, atol=1e-7), f"{(grad_k - grad_ref_k).abs().max()}"
+    assert torch.allclose(grad_v, grad_ref_v, atol=1e-7), f"{(grad_v - grad_ref_v).abs().max()}"
 
 
 @pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="xformers is not available")
 @clear_cache_before_run()
 @parameterize('proj_shape', [(6, 24, 8, 4, 16)])
 @parameterize('dtype', DTYPE)
-def test_cross_attention(proj_shape, dtype):
+@parameterize('dropout', [0.0])
def test_cross_attention(proj_shape, dtype, dropout):
     (B, S, T, H, D_HEAD) = proj_shape
     D = H * D_HEAD
 
-    q_attn = torch.nn.Linear(D, D, dtype=dtype, device="cuda")
-    kv_attn = torch.nn.Linear(D, 2 * D, dtype=dtype, device="cuda")
+    q = torch.randn((B, T, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
+    k = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
+    v = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
 
-    attn = ColoAttention(D, H, dropout=0.1)
-
-    src = torch.randn((B, S, D), dtype=dtype, device="cuda")
-    tgt = torch.randn((B, T, D), dtype=dtype, device="cuda")
-
-    q = q_attn(tgt)
-    kv = kv_attn(src)
-    q = rearrange(q, 'b s (h d) -> b s h d', h=H)
-    k, v = rearrange(kv, 'b s (n h d) -> b s n h d', n=2, h=H).unbind(dim=2)
+    attn = ColoAttention(D, H, dropout=dropout)
     y = attn(q, k, v, attn_mask_type=AttnMaskType.causal)
 
     assert list(y.shape) == [B, T, D]
 
+    out_ref = attention_ref(q, k, v, None, causal=True)
+
     dy = torch.rand_like(y)
-    y.backward(dy)
+    grad_q, grad_k, grad_v = torch.autograd.grad(y, (q, k, v), dy)
+    grad_ref_q, grad_ref_k, grad_ref_v = torch.autograd.grad(out_ref, (q, k, v), dy)
+
+    assert torch.allclose(y, out_ref, atol=1e-7), f"{(y - out_ref).abs().max()}"
+    assert torch.allclose(grad_q, grad_ref_q, atol=1e-7), f"{(grad_q - grad_ref_q).abs().max()}"
+    assert torch.allclose(grad_k, grad_ref_k, atol=1e-7), f"{(grad_k - grad_ref_k).abs().max()}"
+    assert torch.allclose(grad_v, grad_ref_v, atol=1e-7), f"{(grad_v - grad_ref_v).abs().max()}"
\ No newline at end of file
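
Side note on the new control-group implementation: the einsum-based attention_ref above computes plain softmax(Q K^T / sqrt(d)) V attention in (batch, seq, heads, head_dim) layout. Below is a minimal standalone sanity check, not part of the patch, assuming PyTorch >= 2.0 so that torch.nn.functional.scaled_dot_product_attention is available; it verifies that the reference agrees with PyTorch's built-in SDPA for the no-mask causal case, up to the layout difference (SDPA expects (batch, heads, seq, head_dim)).

import math

import torch
from einops import rearrange


def attention_ref(q, k, v, causal=False):
    # Same reference as in the patch, trimmed to the no-mask path.
    dtype_og = q.dtype
    seqlen_q, seqlen_k = q.shape[1], k.shape[1]
    scale = 1.0 / math.sqrt(q.shape[-1])
    scores = torch.einsum('bthd,bshd->bhts', q * scale, k)
    if causal:
        causal_mask = torch.triu(torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=q.device), 1)
        scores.masked_fill_(causal_mask, float('-inf'))
    attention = torch.softmax(scores, dim=-1)
    output = torch.einsum('bhts,bshd->bthd', attention, v)
    return rearrange(output, 'b s h d -> b s (h d)').to(dtype=dtype_og)


# Same (B, S, H, D_HEAD) shape as the tests use, but on CPU in float32.
B, S, H, D_HEAD = 6, 8, 4, 16
q, k, v = (torch.randn(B, S, H, D_HEAD) for _ in range(3))

out_ref = attention_ref(q, k, v, causal=True)

# PyTorch's built-in scaled dot-product attention as an independent oracle.
out_sdpa = torch.nn.functional.scaled_dot_product_attention(
    rearrange(q, 'b s h d -> b h s d'),
    rearrange(k, 'b s h d -> b h s d'),
    rearrange(v, 'b s h d -> b h s d'),
    is_causal=True,
)
out_sdpa = rearrange(out_sdpa, 'b h s d -> b s (h d)')

assert torch.allclose(out_ref, out_sdpa, atol=1e-5), (out_ref - out_sdpa).abs().max()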