@@ -431,7 +431,7 @@ class RingAttention(torch.autograd.Function):
     INTER_RING_GROUP_COPY: dist.ProcessGroup = None

     @staticmethod
-    def get_double_ring_groups(sp_group, pg_mesh, inner_ring_size=None):
+    def get_double_ring_groups(sp_axis, pg_mesh, inner_ring_size=None):
         """
         Get 2D ring groups for the given process group. Generally, to avoid congestion, the inner ring size
         shouldn't be larger than the number of NICs on each node.
@@ -441,6 +441,9 @@ class RingAttention(torch.autograd.Function):
         Returns:
             Tuple[dist.ProcessGroup, dist.ProcessGroup]: Inner-ring process group and inter-ring process group.
         """
+        assert pg_mesh is not None, f"Error: The pg mesh is None! please check the process group initialization."
+
+        sp_group = pg_mesh.get_group_along_axis(sp_axis)
         sp_size = dist.get_world_size(sp_group)
         sp_rank = dist.get_rank(sp_group)

@@ -496,6 +499,7 @@ class RingAttention(torch.autograd.Function):
         return_softmax=False,
         inner_ring_size=None,
         pg_mesh=None,
+        sp_axis=None,
         **kwargs,
     ):
         """
@@ -506,7 +510,7 @@ class RingAttention(torch.autograd.Function):
             q (torch.Tensor): Query tensor. Shape should be [B, nHeads, Sq, D]
             k (torch.Tensor): Key tensor. Shape should be [B, nHeads, Sq, Sq, D]
             v (torch.Tensor): Value tensor. Shape should be [B, nHeads, Sq, Sq, D]
-            sp_group (Optional[dist.ProcessGroup]): Process group for sequence parallelism
+            sp_axis (Optional[int]): Sp axis for the global pg mesh.
             sp_tream (torch.cuda.Stream): An different stream for output correction.
             cu_seqlens (Optional[torch.Tensor], optional): The cumulative sequence lengths
                 of the sequences in the batch, used to index into q.
@@ -539,13 +543,13 @@ class RingAttention(torch.autograd.Function):
             attention_mask_type in RingAttention.SUPPORTED_MASK_TYPES
         ), f"Mask type {attention_mask_type} is not supported yet."

+        assert pg_mesh is not None, f"Error: The pg mesh is None! please check the process group initialization."
+
         clone_pg = lambda pg: dist.new_group(dist.get_process_group_ranks(pg))

         if inner_ring_size != None:
             RingAttention.SP_GROUP = sp_group
-            inner_ring_group, inter_ring_group = RingAttention.get_double_ring_groups(
-                sp_group, pg_mesh, inner_ring_size
-            )
+            inner_ring_group, inter_ring_group = RingAttention.get_double_ring_groups(sp_axis, pg_mesh, inner_ring_size)
             RingAttention.INNER_RING_GROUP = inner_ring_group
             RingAttention.INTER_RING_GROUP = inter_ring_group
         else:
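
Reviewer note, not part of the patch: the change above only re-parameterizes how the sequence-parallel group is obtained (pg_mesh.get_group_along_axis(sp_axis) instead of passing sp_group directly); the underlying 2D ring decomposition is unchanged. The sketch below is a minimal illustration of that decomposition under stated assumptions: a flat list of sequence-parallel ranks and a caller-chosen inner_ring_size. The helper name make_double_ring_groups and its signature are hypothetical and do not appear in the codebase.

import torch.distributed as dist

def make_double_ring_groups(sp_ranks, inner_ring_size):
    # Hypothetical sketch, not RingAttention.get_double_ring_groups:
    # split the sequence-parallel ranks into inner rings (contiguous blocks,
    # ideally no larger than the number of NICs per node) and inter rings
    # (one rank taken from each block).
    assert len(sp_ranks) % inner_ring_size == 0, "sp size must be divisible by inner ring size"
    num_rings = len(sp_ranks) // inner_ring_size

    # dist.new_group is a collective call, so every rank builds every group.
    inner_groups = [
        dist.new_group(sp_ranks[i * inner_ring_size : (i + 1) * inner_ring_size])
        for i in range(num_rings)
    ]
    inter_groups = [dist.new_group(sp_ranks[j::inner_ring_size]) for j in range(inner_ring_size)]

    # Each rank belongs to exactly one inner ring and one inter ring; return the
    # pair containing the calling rank (assumes this rank is listed in sp_ranks).
    idx = sp_ranks.index(dist.get_rank())
    return inner_groups[idx // inner_ring_size], inter_groups[idx % inner_ring_size]

For example, with 8 sequence-parallel ranks and inner_ring_size=4 this yields two intra-node rings of 4 ranks and four cross-node rings of 2 ranks each, which is the traffic pattern the docstring's NIC guidance refers to.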