pull/6071/head
wangbluo 2024-10-09 10:57:19 +08:00
parent 3fab92166e
commit 3532f77b90
1 changed files with 24 additions and 25 deletions

View File

@ -443,6 +443,7 @@ class RingAttention(torch.autograd.Function):
""" """
sp_size = dist.get_world_size(sp_group) sp_size = dist.get_world_size(sp_group)
tp_size = dist.get_world_size(tp_group) tp_size = dist.get_world_size(tp_group)
sp_rank = dist.get_rank(sp_group)
if inner_ring_size is None: if inner_ring_size is None:
if torch.cuda.device_count() >= dist.get_world_size(): if torch.cuda.device_count() >= dist.get_world_size():
@ -467,46 +468,44 @@ class RingAttention(torch.autograd.Function):
f"Using 2D Ring Attention with inner ring size {inner_ring_size} to maximze NIC util for inter-node comm. Cross your fingers for speed-ups!", f"Using 2D Ring Attention with inner ring size {inner_ring_size} to maximze NIC util for inter-node comm. Cross your fingers for speed-ups!",
ranks=[0], ranks=[0],
) )
num_rings = sp_size // inner_ring_size
inner_ring_group = None inner_ring_group = None
inter_ring_group = None inter_ring_group = None
world_size = dist.get_world_size() world_size = dist.get_world_size()
rank = dist.get_rank() rank = dist.get_rank()
groups = int(world_size / sp_size)
inner_rings = world_size // sp_size
num_rings = sp_size // inner_ring_size
if tp_size > 1: if tp_size > 1:
for group_id in range(groups): for i in range(inner_rings):
for i in range(inner_ring_size): for j in range(sp_size // tp_size):
ranks = list(range(i + (group_id * sp_size), (1 + group_id) * sp_size, inner_ring_size)) # find inner ring group in one sp group
ranks = list(range(j + i * sp_size, j + (i + 1) * sp_size, tp_size))
group = dist.new_group(ranks) group = dist.new_group(ranks)
if rank in ranks: if rank in ranks:
inner_ring_group = group inner_ring_group = group
for group_id in range(groups): for i in range(inner_rings):
for i in range(num_rings): for j in range(sp_size // tp_size):
ranks = list(range(i + group_id * num_rings, world_size, sp_size)) ranks = list(range(j + i * (sp_size // tp_size), inner_rings + (i + 1) * sp_size, sp_size))
group = dist.new_group(ranks) group = dist.new_group(ranks)
if rank in ranks: if rank in ranks:
inter_ring_group = group inter_ring_group = group
else: else:
for i in range(sp_size // 2): # Create inner ring groups
ranks = list(range((i) * num_rings, (i + 1) * num_rings, 1)) for i in range(inner_ring_size):
if rank in ranks: ranks = list(range(i * inner_ring_size, (i + 1) * inner_ring_size))
print( group = dist.new_group(ranks)
"rank:", if sp_rank in ranks:
rank,
"inner ranks:",
ranks,
)
group = dist.new_group(ranks)
inner_ring_group = group inner_ring_group = group
for group_id in range(num_rings):
for i in range(num_rings): # Create inter ring groups
ranks = list(range(i + group_id * num_rings, world_size, inner_ring_size)) for i in range(num_rings):
ranks = [0, 1, 4, 5] if rank == 0 or rank == 1 or rank == 4 or rank == 5 else [2, 3, 6, 7] ranks = list(range(i, sp_size, num_rings))
if rank in ranks: group = dist.new_group(ranks)
group = dist.new_group(ranks) if sp_rank in ranks:
inter_ring_group = group inter_ring_group = group
return inner_ring_group, inter_ring_group return inner_ring_group, inter_ring_group