pull/6071/head
wangbluo 2024-10-11 14:16:21 +08:00
parent efe3042bb2
commit 0002ae5956
1 changed files with 22 additions and 15 deletions

View File

@ -476,24 +476,31 @@ class RingAttention(torch.autograd.Function):
rank = dist.get_rank() rank = dist.get_rank()
num_ring_size = world_size // num_rings num_ring_size = world_size // num_rings
num_inner_group = num_ring_size // inner_ring_size
if tp_size > 1: if tp_size > 1:
ranks = []
for i in range(num_rings): for i in range(num_rings):
for j in range(num_inner_group): start = i * num_ring_size
# find inner ring group in one sp groups end = (i + 1) * num_ring_size
start = j + i * num_ring_size for idx in range(start, end):
ranks = list(range(start, start + tp_size * inner_ring_size, tp_size)) inner_rank = []
group = dist.new_group(ranks) for k in range(inner_ring_size):
if rank in ranks: current_num = idx + k * tp_size
inner_ring_group = group if current_num >= end:
for i in range(num_rings): break
for j in range(num_inner_group): inner_rank.append(current_num)
start = j + (i * num_inner_group) if len(inner_rank) == inner_ring_size and inner_rank not in ranks:
ranks = list(range(start, start + num_ring_size + 1, num_ring_size)) ranks.append(inner_rank)
group = dist.new_group(ranks) group = dist.new_group(inner_rank)
if rank in ranks: if rank in inner_rank:
inter_ring_group = group inner_ring_group = group
for i in range(num_ring_size):
inter_rank = [i + j * num_ring_size for j in range(num_rings)]
group = dist.new_group(inter_rank)
if rank in inter_rank:
inter_ring_group = group
else: else:
# Create inner ring groups # Create inner ring groups
for i in range(inner_ring_size): for i in range(inner_ring_size):