From 8ff7d0c78048b4231e1f772a8227282ae7d5822a Mon Sep 17 00:00:00 2001
From: wangbluo <2538539015@qq.com>
Date: Mon, 14 Oct 2024 18:16:03 +0800
Subject: [PATCH] fix

---
 tests/test_shardformer/test_layer/test_ring_attn.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/tests/test_shardformer/test_layer/test_ring_attn.py b/tests/test_shardformer/test_layer/test_ring_attn.py
index bcb2c1f8a..0ffea2016 100644
--- a/tests/test_shardformer/test_layer/test_ring_attn.py
+++ b/tests/test_shardformer/test_layer/test_ring_attn.py
@@ -17,11 +17,10 @@ from colossalai.utils import get_current_device
 @parameterize("nheads", [5])
 @parameterize("d", [128])
 @parameterize("dtype", [torch.bfloat16, torch.float16])
-def check_ring_attn(seq_len, bs, nheads, d, dtype, sp_size):
+def check_ring_attn(seq_len, bs, nheads, d, dtype, inner_ring_size):
     torch.cuda.manual_seed(2)
     device = get_current_device()
     sp_group = dist.group.WORLD
-    sp_size = dist.get_world_size()
     # Some outliers may seem large, but our errors are still lower than
     # than Megatron-LM context parallel's
     # (https://github.com/NVIDIA/TransformerEngine/blob/33a3d02f81c56e6f7b542c09bfa86657078d57fb/tests/pytorch/fused_attn/run_fused_attn_with_cp.py#L215)
@@ -43,7 +42,7 @@ def check_ring_attn(seq_len, bs, nheads, d, dtype, sp_size):
         sp_group,
         AttnMaskType.CAUSAL,
         return_softmax=True,
-        inner_ring_size=max(2, sp_size // 2),
+        inner_ring_size=inner_ring_size,
     )
     ring_out = ring_out.transpose(1, 2)
     out, lse, _ = flash_attn_qkvpacked_func(
@@ -160,12 +159,12 @@ def check_packed_seq(seqlen, bs, nheads, d, dtype):
 def launch_single_ring(rank, world_size, port):
     colossalai.launch(rank, world_size, "localhost", port)
     check_packed_seq()
-    check_ring_attn(sp_size=world_size)
+    check_ring_attn(inner_ring_size=None)


 def launch_double_ring(rank, world_size, port):
     colossalai.launch(rank, world_size, "localhost", port)
-    check_ring_attn(sp_size=world_size)
+    check_ring_attn(inner_ring_size=2)


 @rerun_if_address_is_in_use()
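
For reference, a minimal sketch of how the updated launchers could be driven end to end. It assumes the usual colossalai.testing spawn harness, a world size of 4, and an importable test-module path; none of these appear in this patch. Passing inner_ring_size explicitly decouples the single-ring and double-ring cases from the spawned world size, which the old sp_size=world_size argument tied together.

# Sketch only: an assumed pytest-style driver for the launchers changed above.
# spawn() and rerun_if_address_is_in_use() come from colossalai.testing; the
# import path and nprocs=4 are assumptions, not part of the patch.
from colossalai.testing import rerun_if_address_is_in_use, spawn

from tests.test_shardformer.test_layer.test_ring_attn import (
    launch_double_ring,
    launch_single_ring,
)


@rerun_if_address_is_in_use()
def test_ring_attn_single_ring():
    # launch_single_ring passes inner_ring_size=None: one ring over all ranks.
    spawn(launch_single_ring, nprocs=4)


@rerun_if_address_is_in_use()
def test_ring_attn_double_ring():
    # launch_double_ring passes inner_ring_size=2: 4 ranks are expected to
    # form two inner rings of two ranks each.
    spawn(launch_double_ring, nprocs=4)


if __name__ == "__main__":
    test_ring_attn_single_ring()
    test_ring_attn_double_ring()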