pull/6071/head
wangbluo 2 months ago
parent 0002ae5956
commit 1507a7528f

@ -478,6 +478,7 @@ class RingAttention(torch.autograd.Function):
num_ring_size = world_size // num_rings num_ring_size = world_size // num_rings
if tp_size > 1: if tp_size > 1:
# Create inner ring groups
ranks = [] ranks = []
for i in range(num_rings): for i in range(num_rings):
start = i * num_ring_size start = i * num_ring_size
@ -494,7 +495,7 @@ class RingAttention(torch.autograd.Function):
group = dist.new_group(inner_rank) group = dist.new_group(inner_rank)
if rank in inner_rank: if rank in inner_rank:
inner_ring_group = group inner_ring_group = group
# Create inter ring groups
for i in range(num_ring_size): for i in range(num_ring_size):
inter_rank = [i + j * num_ring_size for j in range(num_rings)] inter_rank = [i + j * num_ring_size for j in range(num_rings)]
group = dist.new_group(inter_rank) group = dist.new_group(inter_rank)

Loading…
Cancel
Save