[Ring Attention] Improve comments (#6085)

* improve comments

* improve comments

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Wenxuan Tan 2024-10-15 22:23:35 -05:00 committed by GitHub
parent dcd41d0973
commit 62c13e7969
2 changed files with 80 additions and 51 deletions


@@ -422,13 +422,18 @@ class RingAttention(torch.autograd.Function):
ATTN_DONE: torch.cuda.Event = None
SP_STREAM: torch.cuda.Stream = None
SP_GROUP: dist.ProcessGroup = None
-# duplicate process group for concurrent NCCL streams
-# while both PyTorch and NCCL warns(https://github.com/pytorch/pytorch/commit/2dbe5cb979f674f0052a8eea1f7b6c3c0ba441d7)
-# against this, in practice it seems to work fine.
+# NOTE: Duplicating PGs for concurrent NCCL streams is a risky hack -- while it may increase throughput,
+# both PyTorch and NCCL warn against this. (https://github.com/pytorch/pytorch/commit/2dbe5cb979f674f0052a8eea1f7b6c3c0ba441d7)
+# LoongTrain's original double ring impl. uses concurrent PGs
+# (https://github.com/InternLM/InternEvo/blob/e52f2ffc9acf818e8f2b1f97dfc69ceb2f06e154/internlm/model/ops/ring_flash_attn/zigzag_ring_flash_attn_with_sliding_window.py#L192)
+# but I confirmed with PyTorch developers that this can cause obscure "Software caused connection abort" errors.
+# (https://github.com/pytorch/pytorch/issues/132852)
+# NOTE: In general, a smarter idea is to put as many P2P calls as possible into one `batch_isend_irecv`.
INNER_RING_GROUP: dist.ProcessGroup = None
-INNER_RING_GROUP_COPY: dist.ProcessGroup = None
+# INNER_RING_GROUP_COPY: dist.ProcessGroup = None
INTER_RING_GROUP: dist.ProcessGroup = None
-INTER_RING_GROUP_COPY: dist.ProcessGroup = None
+# INTER_RING_GROUP_COPY: dist.ProcessGroup = None
@staticmethod
def get_double_ring_groups(sp_axis, pg_mesh, inner_ring_size=None):
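The `batch_isend_irecv` idea mentioned in the new comment can be illustrated with a small sketch (an assumed helper, not code from this PR): queue the ring's send and receive as P2P ops and launch them together, so no duplicate process group is needed for concurrency.

import torch
import torch.distributed as dist

def ring_send_recv_kv(send_kv: torch.Tensor, recv_kv: torch.Tensor, group: dist.ProcessGroup):
    # Post the ring's send and recv as one batched P2P launch (hypothetical helper).
    rank = dist.get_rank(group)
    world = dist.get_world_size(group)
    send_peer = dist.get_global_rank(group, (rank + 1) % world)
    recv_peer = dist.get_global_rank(group, (rank - 1) % world)
    ops = [
        dist.P2POp(dist.isend, send_kv, send_peer, group),
        dist.P2POp(dist.irecv, recv_kv, recv_peer, group),
    ]
    # Returns a list of requests; wait on all of them before touching recv_kv.
    return dist.batch_isend_irecv(ops)

Compared with duplicating process groups, this keeps a single NCCL communicator per ring and lets NCCL batch the transfers.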
@@ -626,7 +631,13 @@ class RingAttention(torch.autograd.Function):
inner_ring_group: Optional[dist.ProcessGroup] = None,
inter_ring_group: Optional[dist.ProcessGroup] = None,
):
+"""
+Forward pass supporting both packed (varlen) and batched (fixed-length, no padding) sequences.
+There is no separate version for batched sequences (hard to maintain), which incurs
+some overhead in sequence splitting due to Python for loops.
+Uses two CUDA streams to overlap softmax denominator correction with the next flash attn kernel
+(see comments below).
+"""
cu_seqlens_q = cu_seqlens_kv = cu_seqlens
max_seqlen_q = max_seqlen_kv = max_seqlen
cu_seqlens_half = cu_seqlens // 2
@@ -668,7 +679,8 @@ class RingAttention(torch.autograd.Function):
sp_size = dist.get_world_size(sp_group)
sp_rank = dist.get_rank(sp_group)
-# Attempt to achieve concurrent comm in the two-stream forward
+# Create communicators corresponding to two CUDA streams
local_kv_comms = [RingComm(inner_ring_group) for _ in range(2)]
inter_ring_comm = RingComm(inter_ring_group)
local_sp_size = dist.get_world_size(inner_ring_group)
@@ -676,7 +688,7 @@ class RingAttention(torch.autograd.Function):
inter_ring_rank = dist.get_rank(inter_ring_group) if inter_ring_group is not sp_group else 0
num_rings = dist.get_world_size(inter_ring_group) if inter_ring_group is not sp_group else 1
-# Non-contiguous indexing copies to a new contiguous tensor,
+# Any type of indexing (but not slicing) copies to a new contiguous tensor,
# so only do it once
if sp_rank != sp_size - 1:
q1 = q[half_idx_back]
@@ -693,6 +705,7 @@ class RingAttention(torch.autograd.Function):
rng_states = [None for _ in range(sp_size)]
sp_streams = [torch.cuda.current_stream(), sp_stream]
+# Helper to pass args to FA
def _forward(q, k, v, causal):
(
_,
@@ -723,6 +736,7 @@ class RingAttention(torch.autograd.Function):
if i < local_sp_size - 1:
local_kv_comms[i % 2].send_recv(kv_buffers[i % 2], kv_buffers[(i + 1) % 2])
+# Forward within a node
def _local_ring_forward():
# (Hopefully) overlap output correction with next flash attn
for i in range(local_sp_size):
@@ -731,6 +745,8 @@ class RingAttention(torch.autograd.Function):
# NOTE: waiting outside the current stream will NOT correctly synchronize.
if i > 0:
local_kv_comms[(i + 1) % 2].wait()
+# Prefetch
if i == 0:
_kv_comm(i)
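The alternating-stream pattern used by `_local_ring_forward` can be sketched in isolation (a simplified single-GPU toy with a matmul standing in for the flash-attention kernel; not the PR's code):

import torch

streams = [torch.cuda.current_stream(), torch.cuda.Stream()]
x = torch.randn(4096, 4096, device="cuda")
streams[1].wait_stream(streams[0])        # make x visible to the side stream
block_out = [None, None]

for i in range(4):
    with torch.cuda.stream(streams[i % 2]):
        # Stand-in for the flash-attention kernel of step i; steps alternate streams,
        # so in the real code the next KV send/recv (prefetch) and the cheap element-wise
        # output/LSE correction can overlap with the next step's attention kernel.
        block_out[i % 2] = (x @ x).mul_(0.5)

# Join back before consuming results produced on the side stream,
# mirroring the wait_stream(sp_stream) call in the source.
streams[0].wait_stream(streams[1])
total = block_out[0] + block_out[1]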
@@ -764,15 +780,22 @@ class RingAttention(torch.autograd.Function):
) = _forward(q_block, kv_block[0], kv_block[1], causal=False)
RingAttention.ATTN_DONE.record()
# Pipeline the next KV comm with output correction instead of the next flash attn
-# to minimize idle time when comm takes longer than attn.
+# kernel, to minimize the bubble when comm takes longer than attn.
_kv_comm(i + 1)
block_softmax_lse[i % 2] = (
block_softmax_lse[i % 2].transpose(0, 1).unsqueeze(-1).contiguous().float()
) # (H, T) -> (T, H, 1)
assert block_out[i % 2].shape[:-1] == block_softmax_lse[i % 2].shape[:-1]
-# Output and log sum exp correction. Ideally overlap this with the next flash attn kernel.
-# In reality this always finishes before next flash attn; no need for extra sync.
+# Output and log sum exp correction.
+# Ideally overlap this with the next flash attn kernel,
+# since attn uses Tensor Cores while the rescale is element-wise, memory-bound and uses CUDA cores.
+# (NOTE: this is the same as the ping-pong scheduling idea in FA3.)
+# TODO: however, sometimes even though the GPU has scheduled the next kernel,
+# it's reluctant to launch it in overlap. Some potential causes:
+# 1. need lower-level CUDA scheduling; 2. further benchmark against Megatron-LM;
+# 3. register spilling by the FA kernel.
if i == 0:
out = block_out[0]
softmax_lse = block_softmax_lse[0]
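For reference, the output / log-sum-exp correction being overlapped here amounts to the usual online-softmax merge of two partial attention results (a minimal sketch with an assumed helper name, not the repo's implementation):

import torch

def merge_attn_blocks(out, lse, block_out, block_lse):
    # out, block_out: (T, H, D); lse, block_lse: (T, H, 1), kept in float32.
    new_lse = lse + torch.log1p(torch.exp(block_lse - lse))   # log(exp(lse) + exp(block_lse))
    new_out = torch.exp(lse - new_lse) * out + torch.exp(block_lse - new_lse) * block_out
    return new_out, new_lse

Since this merge is purely element-wise, it can in principle run on CUDA cores while the next block's attention occupies the Tensor Cores.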
@@ -788,15 +811,17 @@ class RingAttention(torch.autograd.Function):
torch.cuda.current_stream().wait_stream(sp_stream)
return out, softmax_lse
+# Forward for inter-node (the outer ring in a 2D ring)
def _other_ring_forward(ring_num_idx, out, softmax_lse):
# Loop through the inner ring after receiving
-# all new KVs from the previous inner ring
+# all new KVs from another ring
for i in range(local_sp_size):
with torch.cuda.stream(sp_streams[i % 2]):
# Send & recv KV
if i > 0:
local_kv_comms[(i + 1) % 2].wait()
+# Prefetch
if i == 0:
_kv_comm(i)
@@ -893,7 +918,8 @@ class RingAttention(torch.autograd.Function):
def backward(ctx, dout, _):
"""
During backward, we accumulate q grads on each rank locally, but iterate kv and their grads
-over all ranks for accumulation.
+over all ranks for accumulation. We avoid using two streams here because backward uses doubled
+buffers and is more communication-heavy.
"""
(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_kv, half_idx_front, half_idx_back) = ctx.saved_tensors[:9]
rng_states = ctx.saved_tensors[9:]
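The accumulation scheme described in the docstring can be shown with a CPU toy (dummy tensors in place of real flash-attention gradients; purely illustrative): dq is summed in place on the local rank, while the (dk, dv) buffer is what travels around the ring and collects a contribution from every rank.

import torch

sp_size, T, H, D = 4, 8, 2, 16
dq = torch.zeros(T, H, D)                 # stays on this rank
dkv = torch.zeros(2, T, H, D)             # traveling (dk, dv) accumulator, fp32

for step in range(sp_size):
    dq_step = torch.randn(T, H, D)        # stand-in for this step's flash-attn backward outputs
    dkv_step = torch.randn(2, T, H, D)
    dq += dq_step                         # local accumulation of q grads
    dkv += dkv_step                       # accumulated, then sent on to the next rank each step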
@@ -925,7 +951,7 @@ class RingAttention(torch.autograd.Function):
local_sp_rank = dist.get_rank(sp_group)
sp_size = dist.get_world_size(sp_group)
-# Using separate streams (pg) for concurrent kv and dkv comm may
+# NOTE: Using separate streams (PG) for concurrent kv and dkv comm may
# cause NCCL "software caused connection abort" here...
local_kv_comm = RingComm(local_kv_group)
local_dkv_comm = RingComm(local_kv_group)
@@ -957,6 +983,7 @@ class RingAttention(torch.autograd.Function):
dkv_buffers = [torch.empty_like(kv, dtype=torch.float32) for kv in kv_buffers] # (T, H, D)
del k, v
+# Helper to pass args to FA
def _backward(dout, q, k, v, out, softmax_lse, dq, dk, dv, rng_state, causal):
_flash_attn_backward(
dout,
@@ -977,8 +1004,7 @@ class RingAttention(torch.autograd.Function):
**misc_kwargs,
)
-# NOTE: We avoid using two streams due to doubled buffers
-# and that backward is more communication intensive.
+# Backward within a node
def _local_ring_backward():
for i in range(local_sp_size):
if i > 0:
@@ -1041,6 +1067,7 @@ class RingAttention(torch.autograd.Function):
dkv_send = dkv_buffers[(local_sp_size - 1) % 2]
return dq, dkv_recv, dkv_send
+# Backward for inter-node (the outer ring in a 2D ring)
def _other_ring_backward(ring_num_idx, dq):
if ring_num_idx > inter_ring_rank:
# Indexing is expensive
@@ -1125,34 +1152,34 @@ class RingAttention(torch.autograd.Function):
@staticmethod
def prepare_varlen_batch(
-attention_mask: torch.Tensor,
+padding_mask: torch.Tensor,
sp_group: dist.ProcessGroup,
inputs_embeds: torch.Tensor = None,
position_ids: Optional[torch.Tensor] = None,
is_label: bool = False,
-is_2d: bool = True,
+is_batched_seq: bool = True,
):
+# TODO: support setting a batch dim (fixed packing length) for packed mode, so that
+# DP can be used (needs to modify the dataloader too)
"""
Preprocess a batch of padded sequence by splitting input sequence by sp_size
-sequence-wise and packing them into one sequence. Updates the mask info accordingly.
+seq-wise and packing them into one sequence. Updates the mask info accordingly.
Args:
-attention_mask (torch.Tensor): Contains the mask [B, Sq], where True means the token is NOT masked.
+padding_mask (torch.Tensor): Contains the mask [B, Sq], where True means the token is NOT masked.
sp_group (dist.ProcessGroup): Process group for sequence parallelism
inputs_embeds (torch.Tensor): Input embeddings. Shape should be [B, Sq, ...]
position_ids (Optional[torch.Tensor], optional): Position ids of shape [Sq] or [1, Sq]. Defaults to None.
is_label (bool, optional): Whether inputs_embeds is instead a label tensor. If True, mask out the first
token of each sequence.
-is_2d (bool, optional): Whether to return 2D outputs padded to max_seqlen // sp_size or flatten
-the batch dim to a packed 1d sequence. Contingent on model forward shape definitions.
+is_batched_seq (bool, optional): If True, then the input is a batch of (potentially padded) sequences
+of shape [B, Sq, ...]; else a packed sequence of shape [T, ...].
Returns:
-torch.Tensor:
-Packed input embeddings of shape [B, Sq // sp_size, ...].
+inputs_embeds (torch.Tensor):
+Packed input embeddings of shape [B, Sq // sp_size, ...] if is_batched_seq, else [T, ...].
-Dict[str, Any]:
+mask_info (Dict[str, Any]):
A dictionary containing mask info.
-torch.Tensor:
+position_ids (torch.Tensor):
Packed position ids of shape [..., Sq // sp_size].
"""
@@ -1160,12 +1187,11 @@ class RingAttention(torch.autograd.Function):
sp_size = dist.get_world_size(group=sp_group)
sp_rank = dist.get_rank(group=sp_group)
mask_info = {}
-mask_info["max_seqlen"], mask_info["cu_seqlens"] = get_pad_info(attention_mask, return_indices=False)
+mask_info["max_seqlen"], mask_info["cu_seqlens"] = get_pad_info(padding_mask, return_indices=False)
-# Unpad, split seq-wise, then pad back to (B, max_seqlen // sp_size)
-# Split mask to compute local nonzero position indices
+# Unpad, split seq-wise, then pad to (B, max_seqlen // sp_size)
# (B, Sq) -> (B, max_seqlen // sp_size)
-attention_mask = attention_mask[:, : mask_info["max_seqlen"]]
+padding_mask = padding_mask[:, : mask_info["max_seqlen"]]
if inputs_embeds is not None:
inputs_embeds = inputs_embeds[:, : mask_info["max_seqlen"]]
inputs_embeds = split_varlen_zigzag(
@@ -1173,11 +1199,12 @@ class RingAttention(torch.autograd.Function):
mask_info["cu_seqlens"],
sp_group,
mask_info["max_seqlen"],
-is_2d=is_2d,
+is_batched_seq=is_batched_seq,
is_label=is_label,
)
-attention_mask = split_varlen_zigzag(
-attention_mask, mask_info["cu_seqlens"], sp_group, mask_info["max_seqlen"], is_2d=is_2d
+# Split mask to get local nonzero seq positions
+padding_mask = split_varlen_zigzag(
+padding_mask, mask_info["cu_seqlens"], sp_group, mask_info["max_seqlen"], is_batched_seq=is_batched_seq
)
if position_ids is not None:
@@ -1190,7 +1217,7 @@ class RingAttention(torch.autograd.Function):
)
mask_info["max_seqlen"] //= sp_size
-mask_info["valid_indices"] = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
+mask_info["valid_indices"] = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten()
mask_info["cu_seqlens"] //= sp_size
mask_info["attention_mask_type"] = AttnMaskType.PADDED_CAUSAL
return inputs_embeds, mask_info, position_ids
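To make the mask bookkeeping above concrete, here is a hedged sketch of how max_seqlen and cu_seqlens can be derived from a padding mask (the helper name is illustrative; get_pad_info in the repo is only described here as returning these two values):

import torch
import torch.nn.functional as F

def pad_info_from_mask(padding_mask: torch.Tensor):
    # padding_mask: (B, Sq), True for real tokens, False for padding.
    seqlens = padding_mask.sum(dim=-1, dtype=torch.int32)                 # (B,)
    max_seqlen = int(seqlens.max())
    cu_seqlens = F.pad(seqlens.cumsum(0, dtype=torch.int32), (1, 0))      # (B + 1,)
    return max_seqlen, cu_seqlens

mask = torch.tensor([[1, 1, 1, 0], [1, 1, 1, 1]], dtype=torch.bool)
print(pad_info_from_mask(mask))   # (4, tensor([0, 3, 7], dtype=torch.int32))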


@@ -295,8 +295,8 @@ def split_batch_zigzag(
batch: Union[torch.Tensor, List[torch.Tensor]], sp_group: ProcessGroup, seq_dim: int = 1, is_label: bool = False
) -> Union[torch.Tensor, List[torch.Tensor]]:
"""
-Split the input along the sequence dimension for Ring Attention. Naively spliting the attention mask
-in the causal setting will result in the preceding ranks having much less workload.
+Split the input sequence batch. Naively splitting the attention mask in the causal setting
+will result in the preceding ranks having much less workload.
We split after "folding" the 2D attention mask in half (https://github.com/zhuzilin/ring-flash-attention/issues/2).
For example, for sp_size = 4 and seq_len = 8, we get | s0, s7 | s1, s6 | s2, s5 | s3, s4 |.
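The folding described above can be reproduced with a tiny standalone toy (an illustration only, not the repo's split_batch_zigzag):

import torch

def zigzag_split(seq: torch.Tensor, sp_size: int, sp_rank: int, seq_dim: int = 0) -> torch.Tensor:
    # Cut the sequence into 2 * sp_size chunks and give each rank one chunk from the
    # front plus its mirror from the back, so causal-attention work is balanced.
    chunks = seq.chunk(2 * sp_size, dim=seq_dim)
    return torch.cat([chunks[sp_rank], chunks[2 * sp_size - 1 - sp_rank]], dim=seq_dim)

seq = torch.arange(8)  # tokens s0..s7, sp_size = 4
print([zigzag_split(seq, 4, r).tolist() for r in range(4)])
# [[0, 7], [1, 6], [2, 5], [3, 4]]  ->  | s0, s7 | s1, s6 | s2, s5 | s3, s4 |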
@@ -346,40 +346,42 @@ def split_varlen_zigzag(
cu_seqlens: torch.Tensor,
sp_group: ProcessGroup,
max_seqlen: int = 0,
-is_2d: bool = False,
+is_batched_seq: bool = False,
is_label: bool = False,
) -> Union[List[torch.Tensor], torch.Tensor]:
-"""Split each sequence in a batch of packed sequences in a zigzag fashion.
-For each tensor in batch, return packed sequences if is_2d is False;
-else return a padded batch of sequences.
+"""Split a packed seq/batch of padded sequences in a zigzag fashion.
+Different from split_batch_zigzag, inputs here have variable sequence lengths.
Args:
-batch (List[torch.Tensor]): Packed sequences of shape (B * Sq, ...), or (B, Sq, ...) if is_2d.
+batch (List[torch.Tensor]): Packed sequences of shape (T, ...), or (B, Sq, ...) if is_batched_seq,
+where T is the total number of tokens.
cu_seqlens (torch.Tensor): Cumulative sequence lengths of shape (B + 1) before splitting.
sp_group (ProcessGroup): The process group for sequence parallelism.
max_seqlen (int): The maximum sequence length in the batch before splitting.
-is_2d (bool): If True, then input has batch size and sequence length split into two dimensions.
+is_batched_seq (bool): If True, then the input is a batch of sequences padded to the same length.
is_label (bool): If True, mask out the first token in each sequence (<Start of Sentence>).
Returns:
-batch (List[torch.Tensor]): Packed sequences of shape (B * max_seqlen // sp_size)
-or (B, max_seqlen // sp_size, ...) if is_2d
+batch (List[torch.Tensor]): Packed sequences of shape (T, ...)
+or (B, max_seqlen // sp_size, ...) if is_batched_seq
"""
sp_size = dist.get_world_size(sp_group)
sp_rank = dist.get_rank(sp_group)
if sp_size == 1:
return batch
-if is_2d:
+if is_batched_seq:
assert max_seqlen > 0, "max_seqlen must be provided for 2D input"
if isinstance(batch, torch.Tensor):
batch = [batch]
-# seq: (B, Sq, h, n)
-# seq = seq[:, :rank * (seqlen // sp_size), ...]
for i, packed_seq in enumerate(batch):
device = packed_seq.device
dtype = packed_seq.dtype
-if is_2d:
+if is_batched_seq:
assert max_seqlen % (sp_size * 2) == 0
# Recreate a padded tensor with the new max seqlen
shape = (packed_seq.shape[0], max_seqlen // sp_size, *packed_seq.shape[2:])
@@ -398,7 +400,7 @@ def split_varlen_zigzag(
seqlen % (2 * sp_size) == 0
), f"batch {i} seq {j}'s length ({seqlen}) must be divisible by 2 * sp_size = {2 * sp_size} for splitting"
-if is_2d:
+if is_batched_seq:
seq = packed_seq[j][:seqlen]
if is_label:
# Shift one position to the right for next token prediction
@@ -415,7 +417,7 @@ def split_varlen_zigzag(
seq = seq.chunk(sp_size * 2)
local_seq.extend([seq[sp_rank], seq[2 * sp_size - 1 - sp_rank]])
-if is_2d:
+if is_batched_seq:
batch[i] = local_seq.contiguous()
else:
batch[i] = torch.cat(local_seq, dim=0)