mirror of https://github.com/hpcaitech/ColossalAI
fix

parent efe3042bb2
commit 0002ae5956
@@ -476,24 +476,31 @@ class RingAttention(torch.autograd.Function):
         rank = dist.get_rank()
 
         num_ring_size = world_size // num_rings
-        num_inner_group = num_ring_size // inner_ring_size
 
         if tp_size > 1:
+            ranks = []
             for i in range(num_rings):
-                for j in range(num_inner_group):
-                    # find inner ring group in one sp groups
-                    start = j + i * num_ring_size
-                    ranks = list(range(start, start + tp_size * inner_ring_size, tp_size))
-                    group = dist.new_group(ranks)
-                    if rank in ranks:
-                        inner_ring_group = group
-            for i in range(num_rings):
-                for j in range(num_inner_group):
-                    start = j + (i * num_inner_group)
-                    ranks = list(range(start, start + num_ring_size + 1, num_ring_size))
-                    group = dist.new_group(ranks)
-                    if rank in ranks:
-                        inter_ring_group = group
+                start = i * num_ring_size
+                end = (i + 1) * num_ring_size
+                for idx in range(start, end):
+                    inner_rank = []
+                    for k in range(inner_ring_size):
+                        current_num = idx + k * tp_size
+                        if current_num >= end:
+                            break
+                        inner_rank.append(current_num)
+                    if len(inner_rank) == inner_ring_size and inner_rank not in ranks:
+                        ranks.append(inner_rank)
+                        group = dist.new_group(inner_rank)
+                        if rank in inner_rank:
+                            inner_ring_group = group
+
+            for i in range(num_ring_size):
+                inter_rank = [i + j * num_ring_size for j in range(num_rings)]
+                group = dist.new_group(inter_rank)
+                if rank in inter_rank:
+                    inter_ring_group = group
 
         else:
             # Create inner ring groups
             for i in range(inner_ring_size):
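When tp_size > 1, the ranks of one sequence-parallel group are spaced tp_size apart (both versions stride by tp_size). The rewritten path enumerates every stride-tp_size candidate window inside each ring, keeps only complete groups it has not seen before, and then builds one inter-ring group per in-ring position. The following is a minimal standalone sketch of that arithmetic, assuming a layout where num_ring_size = inner_ring_size * tp_size; build_ring_groups and the sample sizes are illustrative, not ColossalAI API, and plain lists stand in for dist.new_group so the resulting layout can be inspected without initializing torch.distributed.

# Standalone sketch: mirrors the rank arithmetic of the fixed tp_size > 1
# path above. build_ring_groups is a hypothetical helper, not part of
# ColossalAI; lists replace dist.new_group calls.

def build_ring_groups(world_size, num_rings, inner_ring_size, tp_size):
    num_ring_size = world_size // num_rings

    # Inner rings: walk each ring [start, end) and collect ranks strided
    # by tp_size, keeping only complete, previously unseen groups.
    inner_groups = []
    for i in range(num_rings):
        start = i * num_ring_size
        end = (i + 1) * num_ring_size
        for idx in range(start, end):
            inner_rank = []
            for k in range(inner_ring_size):
                current_num = idx + k * tp_size
                if current_num >= end:
                    break
                inner_rank.append(current_num)
            if len(inner_rank) == inner_ring_size and inner_rank not in inner_groups:
                inner_groups.append(inner_rank)

    # Inter rings: one group per position within a ring, spanning all rings.
    inter_groups = [
        [i + j * num_ring_size for j in range(num_rings)]
        for i in range(num_ring_size)
    ]
    return inner_groups, inter_groups


# Example: 16 ranks, 2 rings of 8, inner rings of 4, tensor-parallel size 2.
inner, inter = build_ring_groups(world_size=16, num_rings=2, inner_ring_size=4, tp_size=2)
print(inner)  # [[0, 2, 4, 6], [1, 3, 5, 7], [8, 10, 12, 14], [9, 11, 13, 15]]
print(inter)  # [[0, 8], [1, 9], [2, 10], [3, 11], [4, 12], [5, 13], [6, 14], [7, 15]]

In this example every rank lands in exactly one inner ring and one inter ring, which is the invariant the dedup check (inner_rank not in ranks) appears to enforce.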