From 3532f77b90ccec4e94e3d8dc2b2e0ac76008cab9 Mon Sep 17 00:00:00 2001 From: wangbluo <2538539015@qq.com> Date: Wed, 9 Oct 2024 10:57:19 +0800 Subject: [PATCH] fix --- colossalai/shardformer/layer/attn.py | 49 ++++++++++++++-------------- 1 file changed, 24 insertions(+), 25 deletions(-) diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py index 21dbe718e..36a4ec963 100644 --- a/colossalai/shardformer/layer/attn.py +++ b/colossalai/shardformer/layer/attn.py @@ -443,6 +443,7 @@ class RingAttention(torch.autograd.Function): """ sp_size = dist.get_world_size(sp_group) tp_size = dist.get_world_size(tp_group) + sp_rank = dist.get_rank(sp_group) if inner_ring_size is None: if torch.cuda.device_count() >= dist.get_world_size(): @@ -467,46 +468,44 @@ class RingAttention(torch.autograd.Function): f"Using 2D Ring Attention with inner ring size {inner_ring_size} to maximze NIC util for inter-node comm. Cross your fingers for speed-ups!", ranks=[0], ) - num_rings = sp_size // inner_ring_size + inner_ring_group = None inter_ring_group = None world_size = dist.get_world_size() rank = dist.get_rank() - groups = int(world_size / sp_size) + + inner_rings = world_size // sp_size + num_rings = sp_size // inner_ring_size if tp_size > 1: - for group_id in range(groups): - for i in range(inner_ring_size): - ranks = list(range(i + (group_id * sp_size), (1 + group_id) * sp_size, inner_ring_size)) + for i in range(inner_rings): + for j in range(sp_size // tp_size): + # find inner ring group in one sp group + ranks = list(range(j + i * sp_size, j + (i + 1) * sp_size, tp_size)) group = dist.new_group(ranks) if rank in ranks: inner_ring_group = group - for group_id in range(groups): - for i in range(num_rings): - ranks = list(range(i + group_id * num_rings, world_size, sp_size)) + for i in range(inner_rings): + for j in range(sp_size // tp_size): + ranks = list(range(j + i * (sp_size // tp_size), inner_rings + (i + 1) * sp_size, sp_size)) group = dist.new_group(ranks) if rank in ranks: inter_ring_group = group else: - for i in range(sp_size // 2): - ranks = list(range((i) * num_rings, (i + 1) * num_rings, 1)) - if rank in ranks: - print( - "rank:", - rank, - "inner ranks:", - ranks, - ) - group = dist.new_group(ranks) + # Create inner ring groups + for i in range(inner_ring_size): + ranks = list(range(i * inner_ring_size, (i + 1) * inner_ring_size)) + group = dist.new_group(ranks) + if sp_rank in ranks: inner_ring_group = group - for group_id in range(num_rings): - for i in range(num_rings): - ranks = list(range(i + group_id * num_rings, world_size, inner_ring_size)) - ranks = [0, 1, 4, 5] if rank == 0 or rank == 1 or rank == 4 or rank == 5 else [2, 3, 6, 7] - if rank in ranks: - group = dist.new_group(ranks) - inter_ring_group = group + + # Create inter ring groups + for i in range(num_rings): + ranks = list(range(i, sp_size, num_rings)) + group = dist.new_group(ranks) + if sp_rank in ranks: + inter_ring_group = group return inner_ring_group, inter_ring_group