|
|
@ -478,6 +478,7 @@ class RingAttention(torch.autograd.Function):
|
|
|
|
num_ring_size = world_size // num_rings
|
|
|
|
num_ring_size = world_size // num_rings
|
|
|
|
|
|
|
|
|
|
|
|
if tp_size > 1:
|
|
|
|
if tp_size > 1:
|
|
|
|
|
|
|
|
# Create inner ring groups
|
|
|
|
ranks = []
|
|
|
|
ranks = []
|
|
|
|
for i in range(num_rings):
|
|
|
|
for i in range(num_rings):
|
|
|
|
start = i * num_ring_size
|
|
|
|
start = i * num_ring_size
|
|
|
@ -494,7 +495,7 @@ class RingAttention(torch.autograd.Function):
|
|
|
|
group = dist.new_group(inner_rank)
|
|
|
|
group = dist.new_group(inner_rank)
|
|
|
|
if rank in inner_rank:
|
|
|
|
if rank in inner_rank:
|
|
|
|
inner_ring_group = group
|
|
|
|
inner_ring_group = group
|
|
|
|
|
|
|
|
# Create inter ring groups
|
|
|
|
for i in range(num_ring_size):
|
|
|
|
for i in range(num_ring_size):
|
|
|
|
inter_rank = [i + j * num_ring_size for j in range(num_rings)]
|
|
|
|
inter_rank = [i + j * num_ring_size for j in range(num_rings)]
|
|
|
|
group = dist.new_group(inter_rank)
|
|
|
|
group = dist.new_group(inter_rank)
|
|
|
|