You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
ColossalAI/tests/test_utils/test_flash_attention.py

131 lines
4.3 KiB

import random
import pytest
import torch
from einops import rearrange
from colossalai.kernel.cuda_native.mha.flash_attn_2 import HAS_FLASH_ATTN
from colossalai.kernel.cuda_native.mha.mem_eff_attn import HAS_MEM_EFF_ATTN
from colossalai.testing import clear_cache_before_run, parameterize
# The attention wrapper and its mask-type enum are imported only when at
# least one of the fused attention backends reports itself as available.
if HAS_FLASH_ATTN or HAS_MEM_EFF_ATTN:
    from colossalai.kernel.cuda_native.mha.mha import ColoAttention
    from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType

# Element types passed to the `dtype` parameterization of the tests in this file.
DTYPE = [
    torch.float16,
    torch.bfloat16,
    torch.float32,
]
def baseline_attention(Z, N_CTX, H, q, k, v, sm_scale):
    """Compute causal scaled dot-product attention with plain PyTorch ops.

    Args:
        Z: Size of the leading dimension of ``q``/``k``/``v`` (presumably the
            batch size). The mask broadcasts over the leading dimensions, so
            the value is not needed; it is kept to preserve the signature.
        N_CTX: Sequence length; the causal mask is ``N_CTX x N_CTX``.
        H: Size of the second dimension (presumably the head count). Unused,
            kept for the same reason as ``Z``.
        q: Query tensor whose last two dims are ``(N_CTX, head_dim)``.
        k: Key tensor; dims 2 and 3 are transposed before the matmul, so it
            must have at least four dimensions.
        v: Value tensor, multiplied by the attention probabilities.
        sm_scale: Factor applied to the raw ``q @ k^T`` scores before softmax.

    Returns:
        ``softmax(causal_mask(q @ k^T * sm_scale)) @ v`` in the dtype of ``v``.
    """
    # True where a query position would look at a later key position.
    # Built on the inputs' device so the reference also runs without CUDA.
    future = torch.tril(torch.ones((N_CTX, N_CTX), device=q.device)) == 0

    scores = torch.matmul(q, k.transpose(2, 3)) * sm_scale
    # One broadcast fill covers every (batch, head) pair; softmax runs in
    # float32 for numerical stability.
    scores = scores.float().masked_fill(future, float("-inf"))

    # Cast back to the value dtype (fp16 inputs still yield fp16 probabilities)
    # so the final matmul never mixes dtypes.
    probs = torch.softmax(scores, dim=-1).to(v.dtype)
    return torch.matmul(probs, v)
@pytest.mark.skipif(not (HAS_MEM_EFF_ATTN or HAS_FLASH_ATTN), reason="xformers is not available")
@clear_cache_before_run()
@parameterize('proj_shape', [(1, 8, 4, 16)])
@parameterize('dtype', DTYPE)
def test_attention_gpt(proj_shape, dtype):
    """Run ColoAttention forward and backward with a padded-causal mask.

    Only the output shape is asserted; the backward call just has to complete.
    """
    # TODO check output value
    batch, seq_len, n_heads, head_dim = proj_shape
    hidden = n_heads * head_dim

    qkv_proj = torch.nn.Linear(hidden, 3 * hidden, dtype=dtype, device="cuda")
    attention = ColoAttention(hidden, n_heads, dropout=0.1)

    inputs = torch.randn((batch, seq_len, hidden), dtype=dtype, device="cuda")
    # The q/k/v axis is moved to the front so the tensor unpacks into three.
    stacked = rearrange(qkv_proj(inputs), 'b s (n h d) -> n b s h d', n=3, h=n_heads)
    q, k, v = stacked

    # Row i holds seq_len - i ones; pad_sequence zero-fills the remainder.
    row_lengths = range(seq_len, seq_len - batch, -1)
    padding_mask = torch.nn.utils.rnn.pad_sequence(
        [torch.ones(n, dtype=dtype, device="cuda") for n in row_lengths],
        batch_first=True,
    )

    out = attention(q, k, v, attn_mask=padding_mask, attn_mask_type=AttnMaskType.paddedcausal)
    assert list(out.shape) == [batch, seq_len, hidden]

    out.backward(torch.rand_like(out))
@pytest.mark.skipif(not (HAS_MEM_EFF_ATTN or HAS_FLASH_ATTN), reason="xformers is not available")
@clear_cache_before_run()
@parameterize('proj_shape', [(6, 8, 4, 16)])
@parameterize('dtype', DTYPE)
def test_attention_bert(proj_shape, dtype):
    """Run ColoAttention forward and backward with a padding mask.

    Only the output shape is asserted; the backward call just has to complete.
    """
    batch, seq_len, n_heads, head_dim = proj_shape
    hidden = n_heads * head_dim

    qkv_proj = torch.nn.Linear(hidden, 3 * hidden, dtype=dtype, device="cuda")
    attention = ColoAttention(hidden, n_heads, dropout=0.1)
    inputs = torch.randn((batch, seq_len, hidden), dtype=dtype, device="cuda")

    # Attention mask of shape [batch, seq_len]: row i holds seq_len - i ones
    # and is zero-padded up to the longest row.
    row_lengths = range(seq_len, seq_len - batch, -1)
    padding_mask = torch.nn.utils.rnn.pad_sequence(
        [torch.ones(n, dtype=dtype, device="cuda") for n in row_lengths],
        batch_first=True,
    )

    # Split the fused projection along its q/k/v axis.
    fused = rearrange(qkv_proj(inputs), 'b s (n h d) -> b s n h d', n=3, h=n_heads)
    q, k, v = fused.unbind(dim=2)

    out = attention(q, k, v, attn_mask=padding_mask, attn_mask_type=AttnMaskType.padding)
    assert list(out.shape) == [batch, seq_len, hidden]

    out.backward(torch.rand_like(out))
@pytest.mark.skipif(not (HAS_MEM_EFF_ATTN or HAS_FLASH_ATTN), reason="xformers is not available")
@clear_cache_before_run()
@parameterize('proj_shape', [(6, 8, 4, 16)])
@parameterize('dtype', DTYPE)
def test_attention_no_mask(proj_shape, dtype):
    """Run ColoAttention forward and backward without passing any mask.

    Only the output shape is asserted; the backward call just has to complete.
    """
    batch, seq_len, n_heads, head_dim = proj_shape
    hidden = n_heads * head_dim

    qkv_proj = torch.nn.Linear(hidden, 3 * hidden, dtype=dtype, device="cuda")
    attention = ColoAttention(hidden, n_heads, dropout=0.1)
    inputs = torch.randn((batch, seq_len, hidden), dtype=dtype, device="cuda")

    # Split the fused projection along its q/k/v axis.
    fused = rearrange(qkv_proj(inputs), 'b s (n h d) -> b s n h d', n=3, h=n_heads)
    q, k, v = fused.unbind(dim=2)

    out = attention(q, k, v)
    assert list(out.shape) == [batch, seq_len, hidden]

    out.backward(torch.rand_like(out))
@pytest.mark.skipif(not (HAS_MEM_EFF_ATTN or HAS_FLASH_ATTN), reason="xformers is not available")
@clear_cache_before_run()
@parameterize('proj_shape', [(6, 24, 8, 4, 16)])
@parameterize('dtype', DTYPE)
def test_cross_attention(proj_shape, dtype):
    """Run ColoAttention with queries and keys/values from different inputs.

    Queries are projected from the target input and keys/values from the
    source input, with the causal mask type. Only the output shape is
    asserted; the backward call just has to complete.
    """
    batch, src_len, tgt_len, n_heads, head_dim = proj_shape
    hidden = n_heads * head_dim

    q_proj = torch.nn.Linear(hidden, hidden, dtype=dtype, device="cuda")
    kv_proj = torch.nn.Linear(hidden, 2 * hidden, dtype=dtype, device="cuda")
    attention = ColoAttention(hidden, n_heads, dropout=0.1)

    source = torch.randn((batch, src_len, hidden), dtype=dtype, device="cuda")
    target = torch.randn((batch, tgt_len, hidden), dtype=dtype, device="cuda")

    q_flat = q_proj(target)
    kv_flat = kv_proj(source)

    q = rearrange(q_flat, 'b s (h d) -> b s h d', h=n_heads)
    # Split the fused key/value projection along its k/v axis.
    kv = rearrange(kv_flat, 'b s (n h d) -> b s n h d', n=2, h=n_heads)
    k, v = kv.unbind(dim=2)

    out = attention(q, k, v, attn_mask_type=AttnMaskType.causal)
    # The output follows the query (target) sequence length.
    assert list(out.shape) == [batch, tgt_len, hidden]

    out.backward(torch.rand_like(out))