@@ -7,6 +7,7 @@ import torch.distributed as dist
 import torch.nn.functional as F
 from einops import rearrange
 
+from colossalai.cluster import ProcessGroupMesh
 from colossalai.kernel.kernel_loader import (
     FlashAttentionDaoLoader,
     FlashAttentionForFloatAndCustomMaskLoader,
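Note: the newly imported `ProcessGroupMesh` is what replaces the manual `dist.new_group` bookkeeping in `get_double_ring_groups` below. A rough standalone sketch of how it is used (hypothetical 8-rank `torchrun` job split into 2 rings of 4 ranks each; not part of this patch):

```python
# Hypothetical demo, e.g. `torchrun --nproc_per_node=8 mesh_demo.py`.
# Builds the same 2-axis (inter-ring x inner-ring) mesh the patched code uses.
import torch.distributed as dist
from colossalai.cluster import ProcessGroupMesh

dist.init_process_group(backend="gloo")  # assumed backend for the demo
pg_mesh = ProcessGroupMesh(2, 4)  # axis 0: inter-ring, axis 1: inner-ring
inner_group = pg_mesh.get_group_along_axis(1)  # the 4 ranks sharing an inner ring
inter_group = pg_mesh.get_group_along_axis(0)  # the 2 peer ranks across rings
print(dist.get_rank(), dist.get_world_size(inner_group), dist.get_world_size(inter_group))
```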
@@ -431,7 +432,7 @@ class RingAttention(torch.autograd.Function):
     INTER_RING_GROUP_COPY: dist.ProcessGroup = None
 
     @staticmethod
-    def get_double_ring_groups(sp_group, tp_group, inner_ring_size=None):
+    def get_double_ring_groups(sp_group, inner_ring_size=None):
         """
         Get 2D ring groups for the given process group. Generally, to avoid congestion, the inner ring size
         shouldn't be larger than the number of NICs on each node.
@@ -442,7 +443,6 @@ class RingAttention(torch.autograd.Function):
             Tuple[dist.ProcessGroup, dist.ProcessGroup]: Inner-ring process group and inter-ring process group.
         """
         sp_size = dist.get_world_size(sp_group)
-        tp_size = dist.get_world_size(tp_group) if tp_group is not None else 1
         sp_rank = dist.get_rank(sp_group)
 
         assert inner_ring_size is not None
@@ -465,45 +465,22 @@ class RingAttention(torch.autograd.Function):
         inner_ring_group = None
         inter_ring_group = None
 
-        world_size = dist.get_world_size()
-        rank = dist.get_rank()
-        num_ring_size = world_size // num_rings
-
-        if tp_size > 1:
-            # Create inner ring groups
-            ranks = []
-            for i in range(num_rings):
-                start = i * num_ring_size
-                end = (i + 1) * num_ring_size
-                for idx in range(start, end):
-                    inner_rank = [idx + k * tp_size for k in range(inner_ring_size) if idx + k * tp_size < end]
-                    if len(inner_rank) == inner_ring_size and inner_rank not in ranks:
-                        ranks.append(inner_rank)
-                        group = dist.new_group(inner_rank)
-                        if rank in inner_rank:
-                            inner_ring_group = group
-            # Create inter ring groups
-            for i in range(num_ring_size):
-                inter_rank = [i + j * num_ring_size for j in range(num_rings)]
-                group = dist.new_group(inter_rank)
-                if rank in inter_rank:
-                    inter_ring_group = group
-        else:
-            # Create inner ring groups
-            for i in range(inner_ring_size):
-                ranks = list(range(i * inner_ring_size, (i + 1) * inner_ring_size))
-                group = dist.new_group(ranks)
-                if sp_rank in ranks:
-                    inner_ring_group = group
-
-            # Create inter ring groups
-            for i in range(num_rings):
-                ranks = list(range(i, sp_size, num_rings))
-                group = dist.new_group(ranks)
-                if sp_rank in ranks:
-                    inter_ring_group = group
+        inter_axis, inner_axis = 0, 1
+        pg_mesh = ProcessGroupMesh(num_rings, inner_ring_size)
+
+        # Create inner ring groups
+        for i in range(inner_ring_size):
+            ranks = list(range(i * inner_ring_size, (i + 1) * inner_ring_size))
+            group = pg_mesh.get_group_along_axis(inner_axis)
+            if sp_rank in ranks:
+                inner_ring_group = group
+
+        # Create inter ring groups
+        for i in range(num_rings):
+            ranks = list(range(i, sp_size, num_rings))
+            group = pg_mesh.get_group_along_axis(inter_axis)
+            if sp_rank in ranks:
+                inter_ring_group = group
 
         return inner_ring_group, inter_ring_group
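For reference (not part of the patch), a minimal pure-Python sketch of the rank layout the new grouping arithmetic produces, assuming a hypothetical sequence-parallel size of 16 with `inner_ring_size = 4`, i.e. `num_rings = 4`:

```python
# Mirrors the rank arithmetic in the added loops above, with assumed sizes.
sp_size = 16  # dist.get_world_size(sp_group) in the real code
inner_ring_size = 4  # ranks per inner ring (kept <= NICs per node)
num_rings = sp_size // inner_ring_size

# Inner rings: contiguous blocks, as in range(i * inner_ring_size, (i + 1) * inner_ring_size)
inner_rings = [list(range(i * inner_ring_size, (i + 1) * inner_ring_size)) for i in range(inner_ring_size)]
# Inter rings: strided across the blocks, as in range(i, sp_size, num_rings)
inter_rings = [list(range(i, sp_size, num_rings)) for i in range(num_rings)]

print(inner_rings)  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]
print(inter_rings)  # [[0, 4, 8, 12], [1, 5, 9, 13], [2, 6, 10, 14], [3, 7, 11, 15]]
```

Each inter ring links the i-th rank of every inner ring, which matches the congestion-avoidance idea stated in the docstring: intra-node traffic stays inside an inner ring while cross-node traffic is spread over the inter rings.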
@@ -522,7 +499,6 @@ class RingAttention(torch.autograd.Function):
         deterministic=False,
         return_softmax=False,
         inner_ring_size=None,
-        tp_group=None,
         **kwargs,
     ):
         """
@@ -570,9 +546,7 @@ class RingAttention(torch.autograd.Function):
 
         if inner_ring_size != None:
             RingAttention.SP_GROUP = sp_group
-            inner_ring_group, inter_ring_group = RingAttention.get_double_ring_groups(
-                sp_group, tp_group, inner_ring_size
-            )
+            inner_ring_group, inter_ring_group = RingAttention.get_double_ring_groups(sp_group, inner_ring_size)
             RingAttention.INNER_RING_GROUP = inner_ring_group
             RingAttention.INTER_RING_GROUP = inter_ring_group
         else:
@@ -619,7 +593,6 @@ class RingAttention(torch.autograd.Function):
             attention_mask_type == AttnMaskType.PADDED_CAUSAL,
             inner_ring_group,
             inter_ring_group,
-            tp_group,
         )
 
         if attention_mask_type == AttnMaskType.PADDED_CAUSAL: