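"""Distributed tests for RingAttention in ColossalAI
(mirror of https://github.com/hpcaitech/ColossalAI).

The checks below compare RingAttention's output, log-sum-exp, and gradients
against FlashAttention run on a single device, for both fixed-length causal
batches and packed variable-length (varlen) batches. `parameterize` here is
colossalai.testing.parameterize, which calls the decorated function once per
listed value.
"""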
import torch
import torch.distributed as dist
import torch.nn.functional as F
from flash_attn import flash_attn_qkvpacked_func, flash_attn_varlen_qkvpacked_func
from torch.testing import assert_close

import colossalai
from colossalai.cluster import ProcessGroupMesh
from colossalai.shardformer.layer.attn import AttnMaskType, RingAttention
from colossalai.shardformer.layer.utils import split_batch_zigzag, split_varlen_zigzag
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device

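# check_ring_attn compares RingAttention (zigzag ring attention / context parallelism)
# against single-device FlashAttention over the full sequence. split_batch_zigzag
# presumably interleaves sequence chunks in a zigzag pattern so each rank of the
# sequence-parallel group gets a balanced share of the causal-attention work.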
@parameterize("seq_len", [4096]) |
|
@parameterize("bs", [2]) |
|
@parameterize("nheads", [5]) |
|
@parameterize("d", [128]) |
|
@parameterize("dtype", [torch.bfloat16, torch.float16]) |
|
def check_ring_attn(seq_len, bs, nheads, d, dtype, inner_ring_size): |
|
torch.cuda.manual_seed(2) |
|
device = get_current_device() |
|
sp_group = dist.group.WORLD |
|
dp_size, pp_size, tp_size = 1, 1, 1 |
|
sp_size = dist.get_world_size() |
|
sp_axis = 2 |
|
pg_mesh = ProcessGroupMesh(dp_size, pp_size, sp_size, tp_size) |
|
# Some outliers may seem large, but our errors are still lower than |
|
# than Megatron-LM context parallel's |
|
# (https://github.com/NVIDIA/TransformerEngine/blob/33a3d02f81c56e6f7b542c09bfa86657078d57fb/tests/pytorch/fused_attn/run_fused_attn_with_cp.py#L215) |
|
# and the original zigzag implementation's (https://github.com/zhuzilin/ring-flash-attention/tree/main) |
|
atol = rtol = 7e-3 |
|
|
|
# Setup inputs |
|
qkv = torch.randn(bs, seq_len, 3, nheads, d, device=device, dtype=dtype, requires_grad=True) |
|
local_qkv = split_batch_zigzag(qkv, sp_group) |
|
q, k, v = local_qkv.unbind(dim=-3) |
|
q, k, v = [x.squeeze(2).detach().clone().transpose(1, 2) for x in (q, k, v)] # (B, nHeads, Sq, D) |
|
q.requires_grad = k.requires_grad = v.requires_grad = True |
|
|
|
# Ring attention vs single GPU |
|
ring_out, ring_lse = RingAttention.attention( |
|
q, |
|
k, |
|
v, |
|
sp_axis, |
|
AttnMaskType.CAUSAL, |
|
return_softmax=True, |
|
inner_ring_size=inner_ring_size, |
|
pg_mesh=pg_mesh, |
|
) |
|
ring_out = ring_out.transpose(1, 2) |
|
out, lse, _ = flash_attn_qkvpacked_func( |
|
qkv, dropout_p=0.0, causal=True, window_size=(-1, -1), alibi_slopes=None, return_attn_probs=True |
|
) |
|
|
|
# Checkout out and softmax denominator |
|
local_out = split_batch_zigzag(out, sp_group) |
|
local_lse = split_batch_zigzag(lse, sp_group, seq_dim=-1) |
|
local_lse = local_lse.transpose(1, 2).contiguous().view(-1, ring_lse.shape[-1]) # (B, nHeads, Sq) -> (T, nHeads) |
|
assert_close(ring_lse, local_lse, atol=atol, rtol=rtol) |
|
assert_close(ring_out, local_out, atol=atol, rtol=rtol) |
|
|
|
# Check grads |
|
ring_out.sum().backward() |
|
out.sum().backward() |
|
ring_dq, ring_dk, ring_dv = [x.transpose(1, 2) for x in (q.grad, k.grad, v.grad)] |
|
dqkv = qkv.grad |
|
local_dqkv = split_batch_zigzag(dqkv, sp_group) |
|
|
|
assert_close(ring_dq, local_dqkv[:, :, 0], atol=atol, rtol=rtol) |
|
assert_close(ring_dk, local_dqkv[:, :, 1], atol=atol, rtol=rtol) |
|
assert_close(ring_dv, local_dqkv[:, :, 2], atol=atol, rtol=rtol) |
|
if dist.get_rank() == 0: |
|
print( |
|
f"sp_size {dist.get_world_size()}, inner ring size {dist.get_world_size(RingAttention.INNER_RING_GROUP)} passed." |
|
) |
|
|
|
|
|
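# check_packed_seq exercises the packed variable-length (varlen) path: a padded batch
# is packed and split across the sequence-parallel group via
# RingAttention.prepare_varlen_batch, then RingAttention is compared against
# flash_attn_varlen_qkvpacked_func run on the full packed sequence.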
@parameterize("seqlen", [4096]) |
|
@parameterize("bs", [2]) |
|
@parameterize("nheads", [5]) |
|
@parameterize("d", [128]) |
|
@parameterize("dtype", [torch.bfloat16, torch.float16]) |
|
def check_packed_seq(seqlen, bs, nheads, d, dtype): |
|
device = get_current_device() |
|
sp_group = dist.group.WORLD |
|
sp_size = dist.get_world_size() |
|
sp_axis = 2 |
|
atol = rtol = 7e-3 |
|
torch.cuda.manual_seed(2) |
|
# Prepare varlen attention mask |
|
padding_mask = torch.ones((bs, seqlen), dtype=torch.int, device=device) |
|
padding_mask[: bs // 2, (seqlen // 4) * 3 :] = 0 |
|
padding_mask[:, seqlen // 2 :] = 0 |
|
|
|
input_embeds = torch.randn(bs, seqlen, nheads, d, device=device, dtype=dtype, requires_grad=True) |
|
|
|
# Forward |
|
# out = ColoAttention.attention(q, k, v, **mask_info) |
|
flat_input = input_embeds.view(-1, nheads, d)[padding_mask.flatten().nonzero().squeeze()] |
|
qkv = torch.stack([flat_input] * 3, dim=1) |
|
qkv.retain_grad() |
|
|
|
input_embeds, mask_info, _ = RingAttention.prepare_varlen_batch(padding_mask, sp_group, input_embeds) |
|
out, lse, _ = flash_attn_varlen_qkvpacked_func( |
|
qkv, |
|
mask_info["cu_seqlens"] * sp_size, |
|
mask_info["max_seqlen"] * sp_size, |
|
return_attn_probs=True, |
|
causal=True, |
|
# deterministic=True |
|
) |
|
# Test the splitting function |
|
local_input = split_varlen_zigzag( |
|
flat_input, mask_info["cu_seqlens"] * sp_size, sp_group, mask_info["max_seqlen"] * sp_size |
|
) |
|
assert (local_input == input_embeds.view(-1, nheads, d)[mask_info["valid_indices"]]).all() |
|
del local_input, flat_input |
|
|
|
q_ring, k_ring, v_ring = [input_embeds.clone().transpose(1, 2) for _ in range(3)] |
|
q_ring.retain_grad() |
|
k_ring.retain_grad() |
|
v_ring.retain_grad() |
|
|
|
ring_out, ring_lse = RingAttention.attention( |
|
q_ring, |
|
k_ring, |
|
v_ring, |
|
sp_axis, |
|
**mask_info, |
|
pad_output=False, |
|
return_softmax=True, |
|
pg_mesh=ProcessGroupMesh(1, 1, sp_size, 1), |
|
# deterministic=True |
|
) |
|
ring_out = ring_out.transpose(1, 2).reshape(-1, nheads, d) |
|
# Check output |
|
lse = lse.transpose(0, 1) |
|
out, lse = split_varlen_zigzag( |
|
[out, lse], mask_info["cu_seqlens"] * sp_size, sp_group, mask_info["max_seqlen"] * sp_size |
|
) |
|
assert_close(lse, ring_lse, atol=atol, rtol=rtol) |
|
assert_close(out, ring_out, atol=atol, rtol=rtol) |
|
|
|
# Check grads |
|
labels = torch.ones(out.shape[0], dtype=dtype, device=device) |
|
F.mse_loss(out.sum((-2, -1)), labels).backward() |
|
F.mse_loss(ring_out.sum((-2, -1)), labels[: ring_out.shape[0]]).backward() |
|
dq, dk, dv = [ |
|
split_varlen_zigzag( |
|
qkv.grad[:, i], mask_info["cu_seqlens"] * sp_size, sp_group, mask_info["max_seqlen"] * sp_size |
|
) |
|
for i in range(3) |
|
] |
|
dq_ring, dk_ring, dv_ring = [ |
|
x.transpose(1, 2).reshape(-1, nheads, d)[mask_info["valid_indices"]] |
|
for x in (q_ring.grad, k_ring.grad, v_ring.grad) |
|
] |
|
|
|
assert_close(dq, dq_ring, atol=atol, rtol=rtol) |
|
assert_close(dk, dk_ring, atol=atol, rtol=rtol) |
|
assert_close(dv, dv_ring, atol=atol, rtol=rtol) |
|
|
|
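# The launchers below initialize the distributed backend via colossalai.launch and run
# the checks in each spawned process: test_ring_attn covers the single-ring path on
# 2 ranks, and test_double_ring covers the double-ring path (inner_ring_size=2) on 4 ranks.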
def launch_single_ring(rank, world_size, port):
    colossalai.launch(rank, world_size, "localhost", port)
    check_packed_seq()
    check_ring_attn(inner_ring_size=None)


def launch_double_ring(rank, world_size, port):
    colossalai.launch(rank, world_size, "localhost", port)
    check_ring_attn(inner_ring_size=2)


@rerun_if_address_is_in_use()
@parameterize("world_size", [2])
def test_ring_attn(world_size):
    spawn(launch_single_ring, nprocs=world_size)


@rerun_if_address_is_in_use()
@parameterize("world_size", [4])
def test_double_ring(world_size):
    spawn(launch_double_ring, nprocs=world_size)


if __name__ == "__main__":
    test_ring_attn()
    test_double_ring()