pull/6071/head
wangbluo 2024-10-15 14:50:27 +08:00
parent bc7eeade33
commit 83cf2f84fb
1 changed files with 2 additions and 2 deletions

View File

@ -470,14 +470,14 @@ class RingAttention(torch.autograd.Function):
# Create inner ring groups
for i in range(inner_ring_size):
ranks = list(range(i * inner_ring_size, (i + 1) * inner_ring_size))
group = pg_mesh.get_group_along_axis(2, ranks)
group = pg_mesh.get_group_along_axis(sp_axis, ranks)
if sp_rank in ranks:
inner_ring_group = group
# Create inter ring groups
for i in range(num_rings):
ranks = list(range(i, sp_size, num_rings))
group = pg_mesh.get_group_along_axis(2, ranks)
group = pg_mesh.get_group_along_axis(sp_axis, ranks)
if sp_rank in ranks:
inter_ring_group = group