[Ring Attention] Improve comments (#6085)

* improve comments

* improve comments

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Wenxuan Tan 2024-10-15 22:23:35 -05:00 committed by GitHub
parent dcd41d0973
commit 62c13e7969
2 changed files with 80 additions and 51 deletions

View File

@ -422,13 +422,18 @@ class RingAttention(torch.autograd.Function):
ATTN_DONE: torch.cuda.Event = None
SP_STREAM: torch.cuda.Stream = None
SP_GROUP: dist.ProcessGroup = None
# duplicate process group for concurrent NCCL streams
# while both PyTorch and NCCL warns(https://github.com/pytorch/pytorch/commit/2dbe5cb979f674f0052a8eea1f7b6c3c0ba441d7)
# against this, in practice it seems to work fine.
# NOTE: Duplicating PGs for concurrent NCCL streams is a risky hack -- while it may increase throughput,
# both PyTorch and NCCL warn against this. (https://github.com/pytorch/pytorch/commit/2dbe5cb979f674f0052a8eea1f7b6c3c0ba441d7)
# LoongTrain's original double ring impl. uses concurrent PGs
# (https://github.com/InternLM/InternEvo/blob/e52f2ffc9acf818e8f2b1f97dfc69ceb2f06e154/internlm/model/ops/ring_flash_attn/zigzag_ring_flash_attn_with_sliding_window.py#L192)
# but I confirmed with PyTorch developers that this can cause obscure "Software caused connection abort" errors.
# (https://github.com/pytorch/pytorch/issues/132852)
# NOTE: In general, a smarter idea is to put as many P2P calls as possible into one `batch_isend_irecv`.
INNER_RING_GROUP: dist.ProcessGroup = None
INNER_RING_GROUP_COPY: dist.ProcessGroup = None
# INNER_RING_GROUP_COPY: dist.ProcessGroup = None
INTER_RING_GROUP: dist.ProcessGroup = None
INTER_RING_GROUP_COPY: dist.ProcessGroup = None
# INTER_RING_GROUP_COPY: dist.ProcessGroup = None
@staticmethod
def get_double_ring_groups(sp_axis, pg_mesh, inner_ring_size=None):
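As a rough sketch of the batched P2P alternative mentioned in the NOTE above (assuming a standard torch.distributed setup; the helper name and tensor arguments are illustrative, not ColossalAI's RingComm API), the send and recv of one ring step can be fused into a single batch_isend_irecv call instead of being issued on duplicated process groups:

import torch
import torch.distributed as dist

def ring_send_recv_batched(kv_send: torch.Tensor, kv_recv: torch.Tensor, group=None):
    """Post the send and recv of one ring step as a single batched P2P call.

    Hypothetical helper: both directions go into one batch_isend_irecv so NCCL
    can schedule them together, rather than relying on duplicated PGs.
    """
    rank = dist.get_rank(group)
    world_size = dist.get_world_size(group)
    send_to = (rank + 1) % world_size
    recv_from = (rank - 1) % world_size
    ops = [
        dist.P2POp(dist.isend, kv_send, send_to, group=group),
        dist.P2POp(dist.irecv, kv_recv, recv_from, group=group),
    ]
    reqs = dist.batch_isend_irecv(ops)
    return reqs  # caller waits on these before reading kv_recv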
@ -626,7 +631,13 @@ class RingAttention(torch.autograd.Function):
inner_ring_group: Optional[dist.ProcessGroup] = None,
inter_ring_group: Optional[dist.ProcessGroup] = None,
):
"""
Forward supporting both packed (varlen) and batched (fixed-length, no padding) sequences.
There is no separate version for batched sequences (hard to maintain), which incurs
some overhead in sequence splitting due to Python for loops.
Uses two CUDA streams to overlap softmax denominator correction with the next flash attn
kernel (see comments below).
"""
cu_seqlens_q = cu_seqlens_kv = cu_seqlens
max_seqlen_q = max_seqlen_kv = max_seqlen
cu_seqlens_half = cu_seqlens // 2
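To make the two-stream idea in the docstring concrete, here is a hedged sketch of the scheduling pattern only (flash_attn_step and correct_output are placeholders, not the real kernels): the compute stream launches the next attention block while a side stream applies the element-wise correction for the block just finished.

import torch

def two_stream_schedule(blocks, flash_attn_step, correct_output):
    """Overlap output correction (side stream) with the next flash attn (compute stream).

    flash_attn_step(block) -> (out, lse) and correct_output(out, lse) stand in for
    the real kernels; only the stream/event pattern matters here.
    """
    compute = torch.cuda.current_stream()
    side = torch.cuda.Stream()
    results = []
    for block in blocks:
        out, lse = flash_attn_step(block)   # attention for this block on the compute stream
        done = torch.cuda.Event()
        done.record(compute)                # correction may start once this attn has finished
        with torch.cuda.stream(side):
            side.wait_event(done)
            results.append(correct_output(out, lse))  # element-wise rescale on the side stream
        # the next iteration's flash_attn_step is enqueued immediately on `compute`,
        # so it can overlap with the correction queued above
    compute.wait_stream(side)               # join before anyone reads the corrected outputs
    return results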
@ -668,7 +679,8 @@ class RingAttention(torch.autograd.Function):
sp_size = dist.get_world_size(sp_group)
sp_rank = dist.get_rank(sp_group)
# Attempt to achieve concurrent comm in the two-stream forward
# Create communicators corresponding to two CUDA streams
local_kv_comms = [RingComm(inner_ring_group) for _ in range(2)]
inter_ring_comm = RingComm(inter_ring_group)
local_sp_size = dist.get_world_size(inner_ring_group)
@ -676,7 +688,7 @@ class RingAttention(torch.autograd.Function):
inter_ring_rank = dist.get_rank(inter_ring_group) if inter_ring_group is not sp_group else 0
num_rings = dist.get_world_size(inter_ring_group) if inter_ring_group is not sp_group else 1
# Non-contiguous indexing copies to a new contiguous tensor,
# Any type of indexing (but not slicing) copies to a new contiguous tensor,
# so only do it once
if sp_rank != sp_size - 1:
q1 = q[half_idx_back]
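A tiny self-contained illustration of the indexing comment above (plain PyTorch, unrelated to the actual q/kv tensors): slicing returns a view, whereas indexing with an index tensor such as half_idx_back materializes a new contiguous copy, which is why it is done only once.

import torch

x = torch.arange(12).reshape(6, 2)
idx = torch.tensor([1, 3, 5])

view = x[2:5]        # slicing: returns a view that shares x's storage
fancy = x[idx]       # indexing with a tensor: copies into a new contiguous tensor

view[0, 0] = -1
assert x[2, 0] == -1             # the slice aliases x's memory
fancy[0, 0] = -2
assert x[1, 0] != -2             # the indexed copy leaves x untouched
assert fancy.is_contiguous()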
@ -693,6 +705,7 @@ class RingAttention(torch.autograd.Function):
rng_states = [None for _ in range(sp_size)]
sp_streams = [torch.cuda.current_stream(), sp_stream]
# Helper to pass args to FA
def _forward(q, k, v, causal):
(
_,
@ -723,6 +736,7 @@ class RingAttention(torch.autograd.Function):
if i < local_sp_size - 1:
local_kv_comms[i % 2].send_recv(kv_buffers[i % 2], kv_buffers[(i + 1) % 2])
# Forward within a node
def _local_ring_forward():
# (Hopefully) overlap output correction with next flash attn
for i in range(local_sp_size):
@ -731,6 +745,8 @@ class RingAttention(torch.autograd.Function):
# NOTE: waiting outside the current stream will NOT correctly synchronize.
if i > 0:
local_kv_comms[(i + 1) % 2].wait()
# Prefetch
if i == 0:
_kv_comm(i)
@ -764,15 +780,22 @@ class RingAttention(torch.autograd.Function):
) = _forward(q_block, kv_block[0], kv_block[1], causal=False)
RingAttention.ATTN_DONE.record()
# Pipeline the next KV comm with output correction instead of the next flash attn
# to minimize idle time when comm takes longer than attn.
# kernel, to minimize the bubble when comm takes longer than attn.
_kv_comm(i + 1)
block_softmax_lse[i % 2] = (
block_softmax_lse[i % 2].transpose(0, 1).unsqueeze(-1).contiguous().float()
) # (H, T) -> (T, H, 1)
assert block_out[i % 2].shape[:-1] == block_softmax_lse[i % 2].shape[:-1]
# Output and log sum exp correction. Ideally overlap this with the next flash attn kernel.
# In reality this always finishes before next flash attn; no need for extra sync.
# Output and log sum exp correction.
# Ideally overlap this with the next flash attn kernel,
# since attn uses Tensor Cores while the rescale is element-wise, memory-bound and runs on CUDA cores.
# (NOTE: this is the same idea as ping-pong scheduling in FA3.)
# TODO: sometimes, even though the GPU has scheduled the next kernel,
# it's reluctant to launch it in overlap. Some potential causes:
# 1. lower-level CUDA scheduling is needed; 2. further benchmarking against Megatron-LM;
# 3. register spilling by the FA kernel.
if i == 0:
out = block_out[0]
softmax_lse = block_softmax_lse[0]
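For reference, a hedged sketch of what the output and log-sum-exp correction computes (the standard online-softmax merge used by flash/ring attention; tensor names are illustrative): partial outputs over disjoint KV blocks are merged by rescaling with their LSE values.

import torch

def merge_attn_outputs(out1, lse1, out2, lse2):
    """Merge partial attention outputs computed over two disjoint KV blocks.

    out*: (T, H, D) partial outputs; lse*: (T, H, 1) log-sum-exp of the softmax
    denominators. This is the element-wise, memory-bound rescale the comment
    wants to overlap with the next flash attn kernel.
    """
    lse = torch.logaddexp(lse1, lse2)                 # log(exp(lse1) + exp(lse2))
    out = torch.exp(lse1 - lse) * out1 + torch.exp(lse2 - lse) * out2
    return out, lse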
@ -788,15 +811,17 @@ class RingAttention(torch.autograd.Function):
torch.cuda.current_stream().wait_stream(sp_stream)
return out, softmax_lse
# Forward for inter-node (the outer ring in the 2D ring)
def _other_ring_forward(ring_num_idx, out, softmax_lse):
# Loop through the inner ring after receiving
# all new KVs from the previous inner ring
# all new KVs from another ring
for i in range(local_sp_size):
with torch.cuda.stream(sp_streams[i % 2]):
# Send & recv KV
if i > 0:
local_kv_comms[(i + 1) % 2].wait()
# Prefetch
if i == 0:
_kv_comm(i)
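To make the inner/outer ring terminology concrete, here is a small hypothetical layout helper (plain Python, not the actual get_double_ring_groups, which builds real process groups from pg_mesh): sp_size ranks are grouped into inner rings of size inner_ring_size, and ranks at the same local position across inner rings form the outer (inter) rings.

def double_ring_layout(sp_size: int, inner_ring_size: int):
    """Return (inner_rings, inter_rings) as lists of rank lists (illustration only)."""
    assert sp_size % inner_ring_size == 0
    num_rings = sp_size // inner_ring_size
    # Ranks within one node form an inner ring
    inner_rings = [list(range(r * inner_ring_size, (r + 1) * inner_ring_size))
                   for r in range(num_rings)]
    # Ranks at the same local position across nodes form an inter (outer) ring
    inter_rings = [list(range(i, sp_size, inner_ring_size))
                   for i in range(inner_ring_size)]
    return inner_rings, inter_rings

# e.g. sp_size=8, inner_ring_size=4:
# inner rings: [0,1,2,3], [4,5,6,7]; inter rings: [0,4], [1,5], [2,6], [3,7]
print(double_ring_layout(8, 4))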
@ -893,7 +918,8 @@ class RingAttention(torch.autograd.Function):
def backward(ctx, dout, _):
"""
During backward, we accumulate q grads on each rank locally, but iterate kv and their grads
over all ranks for accumulation.
over all ranks for accumulation. We avoid using two streams here because backward uses doubled
buffers and is more communication-intensive.
"""
(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_kv, half_idx_front, half_idx_back) = ctx.saved_tensors[:9]
rng_states = ctx.saved_tensors[9:]
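A hedged, single-stream sketch of the dataflow the backward docstring describes (flash_attn_bwd_step and ring_send_recv are placeholders for the real kernel and comm calls): dq stays and accumulates locally, while each kv block and its accumulated grads travel around the ring together.

import torch

def ring_attn_backward_sketch(dout, q, kv, out, lse, sp_size, flash_attn_bwd_step, ring_send_recv):
    """dq is accumulated locally; kv and (dk, dv) circulate over all ranks."""
    dq = torch.zeros_like(q)
    dkv = (torch.zeros_like(kv[0]), torch.zeros_like(kv[1]))
    for step in range(sp_size):
        dq_i, dk_i, dv_i = flash_attn_bwd_step(dout, q, kv[0], kv[1], out, lse)
        dq += dq_i                        # local accumulation of q grads
        dkv = (dkv[0] + dk_i, dkv[1] + dv_i)
        if step < sp_size - 1:
            kv = ring_send_recv(kv)       # pass the kv block to the next rank
            dkv = ring_send_recv(dkv)     # its accumulated grads travel with it
    return dq, dkv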
@ -925,7 +951,7 @@ class RingAttention(torch.autograd.Function):
local_sp_rank = dist.get_rank(sp_group)
sp_size = dist.get_world_size(sp_group)
# Using separate streams (pg) for concurrent kv and dkv comm may
# NOTE: Using separate streams (PG) for concurrent kv and dkv comm may
# cause NCCL "software caused connection abort" here...
local_kv_comm = RingComm(local_kv_group)
local_dkv_comm = RingComm(local_kv_group)
@ -957,6 +983,7 @@ class RingAttention(torch.autograd.Function):
dkv_buffers = [torch.empty_like(kv, dtype=torch.float32) for kv in kv_buffers] # (T, H, D)
del k, v
# Helper to pass args to FA
def _backward(dout, q, k, v, out, softmax_lse, dq, dk, dv, rng_state, causal):
_flash_attn_backward(
dout,
@ -977,8 +1004,7 @@ class RingAttention(torch.autograd.Function):
**misc_kwargs,
)
# NOTE: We avoid using two streams due to doubled buffers
# and that backward is more communication intensive.
# Backward within a node
def _local_ring_backward():
for i in range(local_sp_size):
if i > 0:
@ -1041,6 +1067,7 @@ class RingAttention(torch.autograd.Function):
dkv_send = dkv_buffers[(local_sp_size - 1) % 2]
return dq, dkv_recv, dkv_send
# Backward for inter-node (the outer ring in the 2D ring)
def _other_ring_backward(ring_num_idx, dq):
if ring_num_idx > inter_ring_rank:
# Indexing is expensive
@ -1125,34 +1152,34 @@ class RingAttention(torch.autograd.Function):
@staticmethod
def prepare_varlen_batch(
attention_mask: torch.Tensor,
padding_mask: torch.Tensor,
sp_group: dist.ProcessGroup,
inputs_embeds: torch.Tensor = None,
position_ids: Optional[torch.Tensor] = None,
is_label: bool = False,
is_2d: bool = True,
is_batched_seq: bool = True,
):
# TODO: support setting a batch dim (fixed packing length) for packed mode, so that
# DP can be used (this requires modifying the dataloader too)
"""
Preprocess a batch of padded sequences by splitting each input sequence by sp_size
sequence-wise and packing them into one sequence. Updates the mask info accordingly.
seq-wise and packing them into one sequence. Updates the mask info accordingly.
Args:
attention_mask (torch.Tensor): Contains the mask [B, Sq], where True means the token is NOT masked.
padding_mask (torch.Tensor): Contains the mask [B, Sq], where True means the token is NOT masked.
sp_group (dist.ProcessGroup): Process group for sequence parallelism
inputs_embeds (torch.Tensor): Input embeddings. Shape should be [B, Sq, ...]
position_ids (Optional[torch.Tensor], optional): Position ids of shape [Sq] or [1, Sq]. Defaults to None.
is_label (bool, optional): Whether inputs_embeds is instead a label tensor. If True, mask out the first
token of each sequence.
is_2d (bool, optional): Whether to return 2D outputs padded to max_seqlen // sp_size or flatten
the batch dim to a packed 1d sequence. Contingent on model forward shape definitions.
is_batched_seq (bool, optional): If True, then the input is a batch of (potentially padded) sequences
of shape [B, Sq, ...]; else a packed sequence of shape [T, ...].
Returns:
torch.Tensor:
Packed input embeddings of shape [B, Sq // sp_size, ...].
Dict[str, Any]:
inputs_embeds (torch.Tensor):
Packed input embeddings of shape [B, Sq // sp_size, ...] if is_batched_seq, else [T, ...].
mask_info (Dict[str, Any]):
A dictionary containing mask info.
torch.Tensor:
position_ids (torch.Tensor):
Packed position ids of shape [..., Sq // sp_size].
"""
@ -1160,12 +1187,11 @@ class RingAttention(torch.autograd.Function):
sp_size = dist.get_world_size(group=sp_group)
sp_rank = dist.get_rank(group=sp_group)
mask_info = {}
mask_info["max_seqlen"], mask_info["cu_seqlens"] = get_pad_info(attention_mask, return_indices=False)
mask_info["max_seqlen"], mask_info["cu_seqlens"] = get_pad_info(padding_mask, return_indices=False)
# Unpad, split seq-wise, then pad back to (B, max_seqlen // sp_size)
# Split mask to compute local nonzero position indices
# Unpad, split seq-wise, then pad to (B, max_seqlen // sp_size)
# (B, Sq) -> (B, max_seqlen // sp_size)
attention_mask = attention_mask[:, : mask_info["max_seqlen"]]
padding_mask = padding_mask[:, : mask_info["max_seqlen"]]
if inputs_embeds is not None:
inputs_embeds = inputs_embeds[:, : mask_info["max_seqlen"]]
inputs_embeds = split_varlen_zigzag(
@ -1173,11 +1199,12 @@ class RingAttention(torch.autograd.Function):
mask_info["cu_seqlens"],
sp_group,
mask_info["max_seqlen"],
is_2d=is_2d,
is_batched_seq=is_batched_seq,
is_label=is_label,
)
attention_mask = split_varlen_zigzag(
attention_mask, mask_info["cu_seqlens"], sp_group, mask_info["max_seqlen"], is_2d=is_2d
# Split mask to get local nonzero seq positions
padding_mask = split_varlen_zigzag(
padding_mask, mask_info["cu_seqlens"], sp_group, mask_info["max_seqlen"], is_batched_seq=is_batched_seq
)
if position_ids is not None:
@ -1190,7 +1217,7 @@ class RingAttention(torch.autograd.Function):
)
mask_info["max_seqlen"] //= sp_size
mask_info["valid_indices"] = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
mask_info["valid_indices"] = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten()
mask_info["cu_seqlens"] //= sp_size
mask_info["attention_mask_type"] = AttnMaskType.PADDED_CAUSAL
return inputs_embeds, mask_info, position_ids
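As a hedged sketch of the mask bookkeeping above (assuming get_pad_info behaves roughly like this; this is not its actual implementation), max_seqlen, cu_seqlens and the valid-token indices can all be derived from the padding mask:

import torch
import torch.nn.functional as F

def pad_info_from_mask(padding_mask: torch.Tensor):
    """padding_mask: (B, Sq) bool, True = real token.
    Returns (max_seqlen, cu_seqlens, valid_indices), mirroring what mask_info stores."""
    seqlens = padding_mask.sum(dim=-1, dtype=torch.int32)                         # (B,)
    max_seqlen = int(seqlens.max())
    cu_seqlens = F.pad(torch.cumsum(seqlens, 0, dtype=torch.int32), (1, 0))       # (B + 1,)
    valid_indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten()
    return max_seqlen, cu_seqlens, valid_indices

mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]], dtype=torch.bool)
print(pad_info_from_mask(mask))   # max_seqlen=3, cu_seqlens=[0, 3, 5]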

View File

@ -295,8 +295,8 @@ def split_batch_zigzag(
batch: Union[torch.Tensor, List[torch.Tensor]], sp_group: ProcessGroup, seq_dim: int = 1, is_label: bool = False
) -> Union[torch.Tensor, List[torch.Tensor]]:
"""
Split the input along the sequence dimension for Ring Attention. Naively spliting the attention mask
in the causal setting will result in the preceding ranks having much less workload.
Split the input sequence batch. Naively splitting the attention mask in the causal setting
will result in the preceding ranks having much less workload.
We split after "folding" the 2D attention mask in half (https://github.com/zhuzilin/ring-flash-attention/issues/2).
For example, for sp_size = 4 and seq_len = 8, we get | s0, s7 | s1, s6 | s2, s5 | s3, s4 |.
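A minimal sketch of the zigzag split described above (same idea, not the library function itself): each rank keeps chunks sp_rank and 2 * sp_size - 1 - sp_rank of the sequence, which balances causal-attention work across ranks.

import torch

def zigzag_split(seq: torch.Tensor, sp_size: int, sp_rank: int, seq_dim: int = 1):
    """Keep chunks [sp_rank, 2 * sp_size - 1 - sp_rank] along the sequence dimension."""
    chunks = seq.chunk(2 * sp_size, dim=seq_dim)
    return torch.cat([chunks[sp_rank], chunks[2 * sp_size - 1 - sp_rank]], dim=seq_dim)

seq = torch.arange(8).unsqueeze(0)        # one sequence: s0 .. s7
for rank in range(4):
    print(rank, zigzag_split(seq, sp_size=4, sp_rank=rank).tolist())
# 0 [[0, 7]]   1 [[1, 6]]   2 [[2, 5]]   3 [[3, 4]]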
@ -346,40 +346,42 @@ def split_varlen_zigzag(
cu_seqlens: torch.Tensor,
sp_group: ProcessGroup,
max_seqlen: int = 0,
is_2d: bool = False,
is_batched_seq: bool = False,
is_label: bool = False,
) -> Union[List[torch.Tensor], torch.Tensor]:
"""Split each sequence in a batch of packed sequences in a zigzag fashion.
For each tensor in batch, return packed sequences if is_2d is False;
else return a padded batch of sequences.
"""Split a packed seq/batch of padded sequences in a Zigzag fashion.
Different from split_batch_zigzag, inputs here have variable sequence lengths.
Args:
batch (List[torch.Tensor]): Packed sequences of shape (B * Sq, ...), or (B, Sq, ...) if is_2d.
batch (List[torch.Tensor]): Packed sequences of shape (T, ...), or (B, Sq, ...) if is_batched_seq,
where T is the total number of tokens.
cu_seqlens (torch.Tensor): Cumulative sequence lengths of shape (B + 1) before splitting.
sp_group (ProcessGroup): The process group for sequence parallelism.
max_seqlen (int): The maximum sequence length in the batch before splitting.
is_2d (bool): If True, then input has batch size and sequence length split into two dimensions.
is_batched_seq (bool): If True, then the input is a batch of sequences padded to the same length.
is_label (bool): If True, mask out the first token in each sequence (<Start of Sentence>).
Returns:
batch (List[torch.Tensor]): Packed sequences of shape (B * max_seqlen // sp_size)
or (B, max_seqlen // sp_size, ...) if is_2d
batch (List[torch.Tensor]): Packed sequences of shape (T, ...)
or (B, max_seqlen // sp_size, ...) if is_batched_seq
"""
sp_size = dist.get_world_size(sp_group)
sp_rank = dist.get_rank(sp_group)
if sp_size == 1:
return batch
if is_2d:
if is_batched_seq:
assert max_seqlen > 0, "max_seqlen must be provided for 2D input"
if isinstance(batch, torch.Tensor):
batch = [batch]
# seq: (B, Sq, h, n)
# seq = seq[:, :rank * (seqlen // sp_size), ...]
for i, packed_seq in enumerate(batch):
device = packed_seq.device
dtype = packed_seq.dtype
if is_2d:
if is_batched_seq:
assert max_seqlen % (sp_size * 2) == 0
# Recreate a padded tensor with the new max seqlen
shape = (packed_seq.shape[0], max_seqlen // sp_size, *packed_seq.shape[2:])
@ -398,7 +400,7 @@ def split_varlen_zigzag(
seqlen % (2 * sp_size) == 0
), f"batch {i} seq {j}'s length ({seqlen}) must be divisible by 2 * sp_size = {2 * sp_size} for splitting"
if is_2d:
if is_batched_seq:
seq = packed_seq[j][:seqlen]
if is_label:
# Shift one position to the right for next token prediction
@ -415,7 +417,7 @@ def split_varlen_zigzag(
seq = seq.chunk(sp_size * 2)
local_seq.extend([seq[sp_rank], seq[2 * sp_size - 1 - sp_rank]])
if is_2d:
if is_batched_seq:
batch[i] = local_seq.contiguous()
else:
batch[i] = torch.cat(local_seq, dim=0)
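To round out the variable-length case, a tiny worked example of the packed (non-batched) path under the same zigzag rule (a sketch with made-up values, not the function above): with cu_seqlens = [0, 4, 12] and sp_size = 2, each sequence is cut into 2 * sp_size = 4 chunks and rank r keeps chunks r and 3 - r before re-packing.

import torch

def varlen_zigzag_sketch(packed: torch.Tensor, cu_seqlens: torch.Tensor, sp_size: int, sp_rank: int):
    """Sketch of the zigzag split for a packed (non-batched) sequence."""
    local = []
    for j in range(len(cu_seqlens) - 1):
        seq = packed[cu_seqlens[j]:cu_seqlens[j + 1]]   # one sequence of the packed batch
        chunks = seq.chunk(2 * sp_size)                  # length must divide 2 * sp_size
        local.extend([chunks[sp_rank], chunks[2 * sp_size - 1 - sp_rank]])
    return torch.cat(local, dim=0)

packed = torch.arange(12)                 # two sequences: [0..3] and [4..11]
cu_seqlens = torch.tensor([0, 4, 12])
print(varlen_zigzag_sketch(packed, cu_seqlens, sp_size=2, sp_rank=0).tolist())
# rank 0 keeps [0, 3] from seq 0 and [4, 5, 10, 11] from seq 1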