|
|
|
@ -443,6 +443,7 @@ class RingAttention(torch.autograd.Function):
|
|
|
|
|
""" |
|
|
|
|
sp_size = dist.get_world_size(sp_group) |
|
|
|
|
tp_size = dist.get_world_size(tp_group) |
|
|
|
|
sp_rank = dist.get_rank(sp_group) |
|
|
|
|
|
|
|
|
|
if inner_ring_size is None: |
|
|
|
|
if torch.cuda.device_count() >= dist.get_world_size(): |
|
|
|
@ -467,46 +468,44 @@ class RingAttention(torch.autograd.Function):
|
|
|
|
|
f"Using 2D Ring Attention with inner ring size {inner_ring_size} to maximze NIC util for inter-node comm. Cross your fingers for speed-ups!", |
|
|
|
|
ranks=[0], |
|
|
|
|
) |
|
|
|
|
num_rings = sp_size // inner_ring_size |
|
|
|
|
|
|
|
|
|
inner_ring_group = None |
|
|
|
|
inter_ring_group = None |
|
|
|
|
|
|
|
|
|
world_size = dist.get_world_size() |
|
|
|
|
rank = dist.get_rank() |
|
|
|
|
groups = int(world_size / sp_size) |
|
|
|
|
|
|
|
|
|
inner_rings = world_size // sp_size |
|
|
|
|
num_rings = sp_size // inner_ring_size |
|
|
|
|
|
|
|
|
|
if tp_size > 1: |
|
|
|
|
for group_id in range(groups): |
|
|
|
|
for i in range(inner_ring_size): |
|
|
|
|
ranks = list(range(i + (group_id * sp_size), (1 + group_id) * sp_size, inner_ring_size)) |
|
|
|
|
for i in range(inner_rings): |
|
|
|
|
for j in range(sp_size // tp_size): |
|
|
|
|
# find inner ring group in one sp group |
|
|
|
|
ranks = list(range(j + i * sp_size, j + (i + 1) * sp_size, tp_size)) |
|
|
|
|
group = dist.new_group(ranks) |
|
|
|
|
if rank in ranks: |
|
|
|
|
inner_ring_group = group |
|
|
|
|
for group_id in range(groups): |
|
|
|
|
for i in range(num_rings): |
|
|
|
|
ranks = list(range(i + group_id * num_rings, world_size, sp_size)) |
|
|
|
|
for i in range(inner_rings): |
|
|
|
|
for j in range(sp_size // tp_size): |
|
|
|
|
ranks = list(range(j + i * (sp_size // tp_size), inner_rings + (i + 1) * sp_size, sp_size)) |
|
|
|
|
group = dist.new_group(ranks) |
|
|
|
|
if rank in ranks: |
|
|
|
|
inter_ring_group = group |
|
|
|
|
else: |
|
|
|
|
for i in range(sp_size // 2): |
|
|
|
|
ranks = list(range((i) * num_rings, (i + 1) * num_rings, 1)) |
|
|
|
|
if rank in ranks: |
|
|
|
|
print( |
|
|
|
|
"rank:", |
|
|
|
|
rank, |
|
|
|
|
"inner ranks:", |
|
|
|
|
ranks, |
|
|
|
|
) |
|
|
|
|
group = dist.new_group(ranks) |
|
|
|
|
# Create inner ring groups |
|
|
|
|
for i in range(inner_ring_size): |
|
|
|
|
ranks = list(range(i * inner_ring_size, (i + 1) * inner_ring_size)) |
|
|
|
|
group = dist.new_group(ranks) |
|
|
|
|
if sp_rank in ranks: |
|
|
|
|
inner_ring_group = group |
|
|
|
|
for group_id in range(num_rings): |
|
|
|
|
for i in range(num_rings): |
|
|
|
|
ranks = list(range(i + group_id * num_rings, world_size, inner_ring_size)) |
|
|
|
|
ranks = [0, 1, 4, 5] if rank == 0 or rank == 1 or rank == 4 or rank == 5 else [2, 3, 6, 7] |
|
|
|
|
if rank in ranks: |
|
|
|
|
group = dist.new_group(ranks) |
|
|
|
|
inter_ring_group = group |
|
|
|
|
|
|
|
|
|
# Create inter ring groups |
|
|
|
|
for i in range(num_rings): |
|
|
|
|
ranks = list(range(i, sp_size, num_rings)) |
|
|
|
|
group = dist.new_group(ranks) |
|
|
|
|
if sp_rank in ranks: |
|
|
|
|
inter_ring_group = group |
|
|
|
|
|
|
|
|
|
return inner_ring_group, inter_ring_group |
|
|
|
|
|
|
|
|
|