mirror of https://github.com/hpcaitech/ColossalAI
import pytest
import torch
from einops import rearrange

from colossalai.kernel.cuda_native.flash_attention import HAS_MEM_EFF_ATTN
from colossalai.testing import clear_cache_before_run, parameterize

# ColoAttention and its mask types are only importable when the xformers
# memory-efficient attention kernel is available; every test below is skipped
# otherwise, and all of them require a CUDA device.
if HAS_MEM_EFF_ATTN:
    from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention


def baseline_attention(Z, N_CTX, H, q, k, v, sm_scale):
    # Plain PyTorch reference: causally masked scaled dot-product attention.
    # q, k and v are expected in (Z, H, N_CTX, D) layout.
    M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
    p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
    # The (N_CTX, N_CTX) causal mask broadcasts over the batch (Z) and head (H)
    # dimensions, so a single masked assignment covers every batch and head.
    p[:, :, M == 0] = float("-inf")
    p = torch.softmax(p.float(), dim=-1).half()
    ref_out = torch.matmul(p, v)
    return ref_out


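# Illustrative sketch, not part of the original test file: one way the otherwise
# unused baseline_attention() reference above could be compared against
# ColoAttention. It assumes ColoAttention scales scores by 1/sqrt(D_HEAD) and
# applies no output projection; dropout is set to 0 so the fused kernel is
# deterministic. The difference is printed rather than asserted because those
# assumptions may not hold exactly for every kernel version.
def _compare_against_baseline(B=6, S=8, H=4, D_HEAD=16, dtype=torch.float16):
    if not HAS_MEM_EFF_ATTN:
        return
    D = H * D_HEAD
    attn = ColoAttention(D, H, dropout=0.0)
    q, k, v = (torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda") for _ in range(3))
    y = attn(q, k, v, attn_mask_type=AttnMaskType.causal)    # (B, S, D)
    # baseline_attention expects and returns tensors in (B, H, S, D_HEAD) layout
    ref = baseline_attention(B, S, H, q.transpose(1, 2), k.transpose(1, 2),
                             v.transpose(1, 2), sm_scale=1.0 / D_HEAD**0.5)
    ref = rearrange(ref, 'b h s d -> b s (h d)')
    print("max abs difference vs. baseline:", (y - ref).abs().max().item())

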
@pytest.mark.skipif(not HAS_MEM_EFF_ATTN, reason="xformers is not available")
@clear_cache_before_run()
@parameterize('B, S, H, D_HEAD', [(6, 8, 4, 16)])
def test_attention_gpt(B, S, H, D_HEAD, dtype=torch.float16):
    D = H * D_HEAD

    # GPT-style self-attention: a single fused QKV projection and a causal mask.
    c_attn = torch.nn.Linear(D, 3 * D, dtype=dtype, device="cuda")
    attn = ColoAttention(D, H, dropout=0.1)

    x = torch.randn((B, S, D), dtype=dtype, device="cuda")

    qkv = c_attn(x)
    q, k, v = rearrange(qkv, 'b s (n h d) -> n b s h d', n=3, h=H)
    y = attn(q, k, v, attn_mask_type=AttnMaskType.causal)

    assert list(y.shape) == [B, S, D]

    # Make sure the fused kernel also supports the backward pass.
    dy = torch.rand_like(y)
    y.backward(dy)


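# Illustrative shape check, not part of the original tests: the fused QKV
# projection above has shape (B, S, 3 * D); the einops pattern
# 'b s (n h d) -> n b s h d' stacks Q, K and V along a new leading axis,
# which then unpacks into three tensors of shape (B, S, H, D_HEAD).
def _qkv_split_shapes(B=2, S=4, H=4, D_HEAD=16):
    qkv = torch.randn(B, S, 3 * H * D_HEAD)
    q, k, v = rearrange(qkv, 'b s (n h d) -> n b s h d', n=3, h=H)
    assert q.shape == k.shape == v.shape == (B, S, H, D_HEAD)

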
@pytest.mark.skipif(not HAS_MEM_EFF_ATTN, reason="xformers is not available")
@clear_cache_before_run()
@parameterize('B, S, H, D_HEAD', [(6, 8, 4, 16)])
def test_attention_bert(B, S, H, D_HEAD, dtype=torch.float16):
    D = H * D_HEAD

    c_attn = torch.nn.Linear(D, 3 * D, dtype=dtype, device="cuda")
    attn = ColoAttention(D, H, dropout=0.1)

    x = torch.randn((B, S, D), dtype=dtype, device="cuda")
    # Attention mask of shape [B, S]: sample i keeps its first (S - i) positions
    # and is zero-padded up to the max length S.
    mask = [torch.ones(S - i, dtype=dtype, device="cuda") for i in range(B)]
    mask = torch.nn.utils.rnn.pad_sequence(mask, batch_first=True)

    qkv = c_attn(x)
    q, k, v = rearrange(qkv, 'b s (n h d) -> b s n h d', n=3, h=H).unbind(dim=2)
    y = attn(q, k, v, attn_mask=mask, attn_mask_type=AttnMaskType.padding)

    assert list(y.shape) == [B, S, D]

    dy = torch.rand_like(y)
    y.backward(dy)


@pytest.mark.skipif(not HAS_MEM_EFF_ATTN, reason="xformers is not available")
@clear_cache_before_run()
@parameterize('B, S, H, D_HEAD', [(6, 8, 4, 16)])
def test_attention_no_mask(B, S, H, D_HEAD, dtype=torch.float16):
    D = H * D_HEAD

    c_attn = torch.nn.Linear(D, 3 * D, dtype=dtype, device="cuda")
    attn = ColoAttention(D, H, dropout=0.1)

    # No mask at all: every position may attend to every other position.
    x = torch.randn((B, S, D), dtype=dtype, device="cuda")
    qkv = c_attn(x)
    q, k, v = rearrange(qkv, 'b s (n h d) -> b s n h d', n=3, h=H).unbind(dim=2)
    y = attn(q, k, v)

    assert list(y.shape) == [B, S, D]

    dy = torch.rand_like(y)
    y.backward(dy)


@pytest.mark.skipif(not HAS_MEM_EFF_ATTN, reason="xformers is not available")
@clear_cache_before_run()
@parameterize('B, S, T, H, D_HEAD', [(6, 24, 8, 4, 16)])
def test_cross_attention(B, S, T, H, D_HEAD, dtype=torch.float16):
    D = H * D_HEAD

    # Cross-attention: queries come from the target sequence (length T),
    # keys and values from the source sequence (length S).
    q_attn = torch.nn.Linear(D, D, dtype=dtype, device="cuda")
    kv_attn = torch.nn.Linear(D, 2 * D, dtype=dtype, device="cuda")

    attn = ColoAttention(D, H, dropout=0.1)

    src = torch.randn((B, S, D), dtype=dtype, device="cuda")
    tgt = torch.randn((B, T, D), dtype=dtype, device="cuda")

    q = q_attn(tgt)
    kv = kv_attn(src)
    q = rearrange(q, 'b s (h d) -> b s h d', h=H)
    k, v = rearrange(kv, 'b s (n h d) -> b s n h d', n=2, h=H).unbind(dim=2)
    y = attn(q, k, v, attn_mask_type=AttnMaskType.causal)

    # The output follows the query (target) sequence length.
    assert list(y.shape) == [B, T, D]

    dy = torch.rand_like(y)
    y.backward(dy)


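# Not part of the original file: a convenience entry point so the tests can be
# run directly with `python` instead of through the pytest CLI.
if __name__ == '__main__':
    import sys
    sys.exit(pytest.main([__file__, '-v']))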