@@ -431,7 +431,7 @@ class RingAttention(torch.autograd.Function):
     INTER_RING_GROUP_COPY: dist.ProcessGroup = None
 
     @staticmethod
-    def get_double_ring_groups(sp_group, inner_ring_size=None):
+    def get_double_ring_groups(sp_axis, pg_mesh, inner_ring_size=None):
         """
         Get 2D ring groups for the given process group. Generally, to avoid congestion, the inner ring size
         shouldn't be larger than the number of NICs on each node.
@@ -441,21 +441,17 @@ class RingAttention(torch.autograd.Function):
         Returns:
             Tuple[dist.ProcessGroup, dist.ProcessGroup]: Inner-ring process group and inter-ring process group.
         """
+        assert pg_mesh is not None, f"Error: The pg mesh is None! please check the process group initialization."
+
+        sp_group = pg_mesh.get_group_along_axis(sp_axis)
         sp_size = dist.get_world_size(sp_group)
         sp_rank = dist.get_rank(sp_group)
 
-        if inner_ring_size is None:
-            if torch.cuda.device_count() >= dist.get_world_size():
-                # single node, no need to consider NICs
-                return sp_group, sp_group
-            if sp_size <= 4:
-                inner_ring_size = min(2, sp_size)
-            else:
-                inner_ring_size = min(4, sp_size)
-        else:
-            assert (
-                inner_ring_size <= sp_size and sp_size % inner_ring_size == 0
-            ), f"Error: sp_size {sp_size} should be divisible by inner_ring_size {inner_ring_size}"
+        assert inner_ring_size is not None
+
+        assert (
+            inner_ring_size <= sp_size and sp_size % inner_ring_size == 0
+        ), f"Error: sp_size {sp_size} should be divisible by inner_ring_size {inner_ring_size}"
 
         if inner_ring_size == sp_size:
             return sp_group, sp_group
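
With this hunk, inner_ring_size is no longer inferred inside get_double_ring_groups; the caller must pass a value that divides sp_size (and, per the docstring above, preferably no larger than the per-node NIC count). A minimal standalone sketch of the constraint being enforced; the helper name check_inner_ring_size and the sizes are illustrative assumptions, not from the PR:

# Sketch only: mirrors the two assertions above with made-up sizes.
def check_inner_ring_size(sp_size: int, inner_ring_size) -> int:
    assert inner_ring_size is not None
    assert inner_ring_size <= sp_size and sp_size % inner_ring_size == 0
    return sp_size // inner_ring_size  # number of inter-node rings

num_rings = check_inner_ring_size(sp_size=16, inner_ring_size=4)  # -> 4 rings
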
@@ -474,14 +470,14 @@ class RingAttention(torch.autograd.Function):
         # Create inner ring groups
         for i in range(inner_ring_size):
             ranks = list(range(i * inner_ring_size, (i + 1) * inner_ring_size))
-            group = dist.new_group(ranks)
+            group = pg_mesh.get_group_along_axis(sp_axis, ranks)
             if sp_rank in ranks:
                 inner_ring_group = group
 
         # Create inter ring groups
         for i in range(num_rings):
             ranks = list(range(i, sp_size, num_rings))
-            group = dist.new_group(ranks)
+            group = pg_mesh.get_group_along_axis(sp_axis, ranks)
             if sp_rank in ranks:
                 inter_ring_group = group
 
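
For intuition, the two loops above partition the sp ranks into inner (intra-node) rings and inter (cross-node) rings. A standalone sketch in plain Python, no process groups needed; sp_size=16 with inner_ring_size=4 and the helper name ring_rank_layout are illustrative assumptions, not from the PR:

# Sketch only: reproduces the `ranks` lists computed by the two loops above.
# Sizes are chosen so that inner_ring_size == num_rings, matching the loop bounds above.
def ring_rank_layout(sp_size: int, inner_ring_size: int):
    num_rings = sp_size // inner_ring_size
    inner = [list(range(i * inner_ring_size, (i + 1) * inner_ring_size)) for i in range(inner_ring_size)]
    inter = [list(range(i, sp_size, num_rings)) for i in range(num_rings)]
    return inner, inter

inner, inter = ring_rank_layout(16, 4)
# inner == [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]
# inter == [[0, 4, 8, 12], [1, 5, 9, 13], [2, 6, 10, 14], [3, 7, 11, 15]]
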
@@ -492,7 +488,7 @@ class RingAttention(torch.autograd.Function):
         q,  # (B, H, Sq, D)
         k,
         v,
-        sp_group,
+        sp_axis,
         attention_mask_type,
         cu_seqlens=None,
         max_seqlen=None,
@@ -502,6 +498,7 @@ class RingAttention(torch.autograd.Function):
         deterministic=False,
         return_softmax=False,
         inner_ring_size=None,
+        pg_mesh=None,
         **kwargs,
     ):
         """
@@ -512,7 +509,7 @@ class RingAttention(torch.autograd.Function):
             q (torch.Tensor): Query tensor. Shape should be [B, nHeads, Sq, D]
             k (torch.Tensor): Key tensor. Shape should be [B, nHeads, Sq, D]
             v (torch.Tensor): Value tensor. Shape should be [B, nHeads, Sq, D]
-            sp_group (Optional[dist.ProcessGroup]): Process group for sequence parallelism
+            sp_axis (Optional[int]): Sp axis for the global pg mesh.
             sp_stream (torch.cuda.Stream): A different stream for output correction.
             cu_seqlens (Optional[torch.Tensor], optional): The cumulative sequence lengths
                 of the sequences in the batch, used to index into q.
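
Putting the signature changes together (sp_axis plus pg_mesh instead of a raw sp_group), a caller-side sketch might look like the following. It assumes the RingAttention.attention entry point and the ProcessGroupMesh/AttnMaskType import paths below; the mesh shape, axis index, tensor sizes, and the helper name run_ring_attention are illustrative assumptions, not part of the PR.

# Sketch only: the new sp_axis/pg_mesh calling convention, inside an
# already-initialized torch.distributed job. All sizes are made up.
import torch
import torch.distributed as dist
from colossalai.cluster import ProcessGroupMesh
from colossalai.shardformer.layer.attn import AttnMaskType, RingAttention

def run_ring_attention():
    sp_size = dist.get_world_size()          # every rank sits on the sp axis
    pg_mesh = ProcessGroupMesh(1, sp_size)   # axes: (dp, sp)
    sp_axis = 1                              # index of the sp axis in the mesh

    b, h, sq, d = 2, 8, 4096 // sp_size, 64  # per-rank sequence shard
    q, k, v = (
        torch.randn(b, h, sq, d, device="cuda", dtype=torch.bfloat16, requires_grad=True)
        for _ in range(3)
    )
    return RingAttention.attention(
        q, k, v,
        sp_axis,                              # was: sp_group
        AttnMaskType.CAUSAL,
        inner_ring_size=sp_size,              # any divisor of sp_size; sp_size keeps a single flat ring
        pg_mesh=pg_mesh,                      # new: the global process-group mesh
    )
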
@@ -537,7 +534,6 @@ class RingAttention(torch.autograd.Function):
             RingAttention.ATTN_DONE = torch.cuda.Event()
         if RingAttention.SP_STREAM is None:
             RingAttention.SP_STREAM = torch.cuda.Stream()
-
         assert (
             q.shape[2] == k.shape[2]
         ), "Q, K and V having different sequence lengths (inference or cross-attn)\
@@ -546,11 +542,13 @@ class RingAttention(torch.autograd.Function):
             attention_mask_type in RingAttention.SUPPORTED_MASK_TYPES
         ), f"Mask type {attention_mask_type} is not supported yet."
 
-        clone_pg = lambda pg: dist.new_group(dist.get_process_group_ranks(pg))
+        assert pg_mesh is not None, f"Error: The pg mesh is None! please check the process group initialization."
 
-        if RingAttention.SP_GROUP is not sp_group:
+        clone_pg = lambda pg: dist.new_group(dist.get_process_group_ranks(pg))
+        sp_group = pg_mesh.get_group_along_axis(sp_axis)
+        if inner_ring_size != None:
             RingAttention.SP_GROUP = sp_group
-            inner_ring_group, inter_ring_group = RingAttention.get_double_ring_groups(sp_group, inner_ring_size)
+            inner_ring_group, inter_ring_group = RingAttention.get_double_ring_groups(sp_axis, pg_mesh, inner_ring_size)
             RingAttention.INNER_RING_GROUP = inner_ring_group
             RingAttention.INTER_RING_GROUP = inter_ring_group
         else: