From 3fab92166e9547359296d29ebda5b4ce03437502 Mon Sep 17 00:00:00 2001 From: wangbluo <2538539015@qq.com> Date: Thu, 26 Sep 2024 18:03:09 +0800 Subject: [PATCH] fix --- colossalai/shardformer/layer/attn.py | 46 +++++++++++++++++++--------- 1 file changed, 31 insertions(+), 15 deletions(-) diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py index b7738c7e2..21dbe718e 100644 --- a/colossalai/shardformer/layer/attn.py +++ b/colossalai/shardformer/layer/attn.py @@ -442,8 +442,7 @@ class RingAttention(torch.autograd.Function): Tuple[dist.ProcessGroup, dist.ProcessGroup]: Inner-ring process group and inter-ring process group. """ sp_size = dist.get_world_size(sp_group) - dist.get_rank(sp_group) - dist.get_world_size(tp_group) + tp_size = dist.get_world_size(tp_group) if inner_ring_size is None: if torch.cuda.device_count() >= dist.get_world_size(): @@ -476,21 +475,38 @@ class RingAttention(torch.autograd.Function): rank = dist.get_rank() groups = int(world_size / sp_size) - # Create inner ring groups - for group_id in range(groups): - for i in range(inner_ring_size): - ranks = list(range(i + (group_id * sp_size), (1 + group_id) * sp_size, inner_ring_size)) - group = dist.new_group(ranks) + if tp_size > 1: + for group_id in range(groups): + for i in range(inner_ring_size): + ranks = list(range(i + (group_id * sp_size), (1 + group_id) * sp_size, inner_ring_size)) + group = dist.new_group(ranks) + if rank in ranks: + inner_ring_group = group + for group_id in range(groups): + for i in range(num_rings): + ranks = list(range(i + group_id * num_rings, world_size, sp_size)) + group = dist.new_group(ranks) + if rank in ranks: + inter_ring_group = group + else: + for i in range(sp_size // 2): + ranks = list(range((i) * num_rings, (i + 1) * num_rings, 1)) if rank in ranks: + print( + "rank:", + rank, + "inner ranks:", + ranks, + ) + group = dist.new_group(ranks) inner_ring_group = group - - # Create inter ring groups - for group_id in range(groups): - for i in range(num_rings): - ranks = list(range(i + group_id * num_rings, world_size, sp_size)) - group = dist.new_group(ranks) - if rank in ranks: - inter_ring_group = group + for group_id in range(num_rings): + for i in range(num_rings): + ranks = list(range(i + group_id * num_rings, world_size, inner_ring_size)) + ranks = [0, 1, 4, 5] if rank == 0 or rank == 1 or rank == 4 or rank == 5 else [2, 3, 6, 7] + if rank in ranks: + group = dist.new_group(ranks) + inter_ring_group = group return inner_ring_group, inter_ring_group