import torch
import torch.distributed as dist
import torch.nn.functional as F
from flash_attn import flash_attn_qkvpacked_func, flash_attn_varlen_qkvpacked_func
from torch.testing import assert_close

import colossalai
from colossalai.cluster import ProcessGroupMesh
from colossalai.shardformer.layer.attn import AttnMaskType, RingAttention
from colossalai.shardformer.layer.utils import split_batch_zigzag, split_varlen_zigzag
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device


@parameterize("seq_len", [4096])
@parameterize("bs", [2])
@parameterize("nheads", [5])
@parameterize("d", [128])
@parameterize("dtype", [torch.bfloat16, torch.float16])
def check_ring_attn(seq_len, bs, nheads, d, dtype, inner_ring_size):
    torch.cuda.manual_seed(2)
    device = get_current_device()
    sp_group = dist.group.WORLD
    dp_size, pp_size, tp_size = 1, 1, 1
    sp_size = dist.get_world_size()
    sp_axis = 2
    pg_mesh = ProcessGroupMesh(dp_size, pp_size, sp_size, tp_size)
    # Some outliers may seem large, but our errors are still lower than
    # Megatron-LM context parallel's
    # (https://github.com/NVIDIA/TransformerEngine/blob/33a3d02f81c56e6f7b542c09bfa86657078d57fb/tests/pytorch/fused_attn/run_fused_attn_with_cp.py#L215)
    # and the original zigzag implementation's (https://github.com/zhuzilin/ring-flash-attention/tree/main)
    atol = rtol = 7e-3

    # Set up inputs
    qkv = torch.randn(bs, seq_len, 3, nheads, d, device=device, dtype=dtype, requires_grad=True)
    # Zigzag-split the sequence dim so each rank gets a balanced share of the causal workload
    local_qkv = split_batch_zigzag(qkv, sp_group)
    q, k, v = local_qkv.unbind(dim=-3)
    q, k, v = [x.squeeze(2).detach().clone().transpose(1, 2) for x in (q, k, v)]  # (B, nHeads, Sq, D)
    q.requires_grad = k.requires_grad = v.requires_grad = True

    # Ring attention vs single GPU
    ring_out, ring_lse = RingAttention.attention(
        q,
        k,
        v,
        sp_axis,
        AttnMaskType.CAUSAL,
        return_softmax=True,
        inner_ring_size=inner_ring_size,
        pg_mesh=pg_mesh,
    )
    ring_out = ring_out.transpose(1, 2)
    out, lse, _ = flash_attn_qkvpacked_func(
        qkv, dropout_p=0.0, causal=True, window_size=(-1, -1), alibi_slopes=None, return_attn_probs=True
    )

    # Check output and softmax denominator (log-sum-exp)
    local_out = split_batch_zigzag(out, sp_group)
    local_lse = split_batch_zigzag(lse, sp_group, seq_dim=-1)
    local_lse = local_lse.transpose(1, 2).contiguous().view(-1, ring_lse.shape[-1])  # (B, nHeads, Sq) -> (T, nHeads)
    assert_close(ring_lse, local_lse, atol=atol, rtol=rtol)
    assert_close(ring_out, local_out, atol=atol, rtol=rtol)

    # Check grads
    ring_out.sum().backward()
    out.sum().backward()
    ring_dq, ring_dk, ring_dv = [x.transpose(1, 2) for x in (q.grad, k.grad, v.grad)]
    dqkv = qkv.grad
    local_dqkv = split_batch_zigzag(dqkv, sp_group)

    assert_close(ring_dq, local_dqkv[:, :, 0], atol=atol, rtol=rtol)
    assert_close(ring_dk, local_dqkv[:, :, 1], atol=atol, rtol=rtol)
    assert_close(ring_dv, local_dqkv[:, :, 2], atol=atol, rtol=rtol)
    if dist.get_rank() == 0:
        print(
            f"sp_size {dist.get_world_size()}, inner ring size {dist.get_world_size(RingAttention.INNER_RING_GROUP)} passed."
        )


@parameterize("seqlen", [4096])
@parameterize("bs", [2])
@parameterize("nheads", [5])
@parameterize("d", [128])
@parameterize("dtype", [torch.bfloat16, torch.float16])
def check_packed_seq(seqlen, bs, nheads, d, dtype):
    device = get_current_device()
    sp_group = dist.group.WORLD
    sp_size = dist.get_world_size()
    sp_axis = 2
    atol = rtol = 7e-3
    torch.cuda.manual_seed(2)

    # Prepare varlen attention mask
    padding_mask = torch.ones((bs, seqlen), dtype=torch.int, device=device)
    padding_mask[: bs // 2, (seqlen // 4) * 3 :] = 0
    padding_mask[:, seqlen // 2 :] = 0

    input_embeds = torch.randn(bs, seqlen, nheads, d, device=device, dtype=dtype, requires_grad=True)

    # Forward
    # out = ColoAttention.attention(q, k, v, **mask_info)
    flat_input = input_embeds.view(-1, nheads, d)[padding_mask.flatten().nonzero().squeeze()]
    qkv = torch.stack([flat_input] * 3, dim=1)
    qkv.retain_grad()

    input_embeds, mask_info, _ = RingAttention.prepare_varlen_batch(padding_mask, sp_group, input_embeds)
    out, lse, _ = flash_attn_varlen_qkvpacked_func(
        qkv,
        mask_info["cu_seqlens"] * sp_size,
        mask_info["max_seqlen"] * sp_size,
        return_attn_probs=True,
        causal=True,
        # deterministic=True
    )
    # Test the splitting function
    local_input = split_varlen_zigzag(
        flat_input, mask_info["cu_seqlens"] * sp_size, sp_group, mask_info["max_seqlen"] * sp_size
    )
    assert (local_input == input_embeds.view(-1, nheads, d)[mask_info["valid_indices"]]).all()
    del local_input, flat_input

    q_ring, k_ring, v_ring = [input_embeds.clone().transpose(1, 2) for _ in range(3)]
    q_ring.retain_grad()
    k_ring.retain_grad()
    v_ring.retain_grad()

    ring_out, ring_lse = RingAttention.attention(
        q_ring,
        k_ring,
        v_ring,
        sp_axis,
        **mask_info,
        pad_output=False,
        return_softmax=True,
        pg_mesh=ProcessGroupMesh(1, 1, sp_size, 1),
        # deterministic=True
    )
    ring_out = ring_out.transpose(1, 2).reshape(-1, nheads, d)

    # Check output
    lse = lse.transpose(0, 1)
    out, lse = split_varlen_zigzag(
        [out, lse], mask_info["cu_seqlens"] * sp_size, sp_group, mask_info["max_seqlen"] * sp_size
    )
    assert_close(lse, ring_lse, atol=atol, rtol=rtol)
    assert_close(out, ring_out, atol=atol, rtol=rtol)

    # Check grads
    labels = torch.ones(out.shape[0], dtype=dtype, device=device)
    F.mse_loss(out.sum((-2, -1)), labels).backward()
    F.mse_loss(ring_out.sum((-2, -1)), labels[: ring_out.shape[0]]).backward()
    dq, dk, dv = [
        split_varlen_zigzag(
            qkv.grad[:, i], mask_info["cu_seqlens"] * sp_size, sp_group, mask_info["max_seqlen"] * sp_size
        )
        for i in range(3)
    ]
    dq_ring, dk_ring, dv_ring = [
        x.transpose(1, 2).reshape(-1, nheads, d)[mask_info["valid_indices"]]
        for x in (q_ring.grad, k_ring.grad, v_ring.grad)
    ]

    assert_close(dq, dq_ring, atol=atol, rtol=rtol)
    assert_close(dk, dk_ring, atol=atol, rtol=rtol)
    assert_close(dv, dv_ring, atol=atol, rtol=rtol)


def launch_single_ring(rank, world_size, port):
    colossalai.launch(rank, world_size, "localhost", port)
    check_packed_seq()
    check_ring_attn(inner_ring_size=None)


def launch_double_ring(rank, world_size, port):
    colossalai.launch(rank, world_size, "localhost", port)
    check_ring_attn(inner_ring_size=2)


@rerun_if_address_is_in_use()
@parameterize("world_size", [2])
def test_ring_attn(world_size):
    spawn(launch_single_ring, nprocs=world_size)


@rerun_if_address_is_in_use()
@parameterize("world_size", [4])
def test_double_ring(world_size):
    spawn(launch_double_ring, nprocs=world_size)


if __name__ == "__main__":
    test_ring_attn()
    test_double_ring()
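
# NOTE: `spawn` launches one process per rank and each process grabs a CUDA
# device, so running this file directly (`python test_ring_attn.py`) should
# require at least 2 GPUs for test_ring_attn and 4 for test_double_ring
# (the double-ring case splits world_size=4 into inner rings of size 2).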