mirror of https://github.com/hpcaitech/ColossalAI
[Ring Attention] Improve comments (#6085)
* improve comments

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
parent dcd41d0973
commit 62c13e7969
@@ -422,13 +422,18 @@ class RingAttention(torch.autograd.Function):
 ATTN_DONE: torch.cuda.Event = None
 SP_STREAM: torch.cuda.Stream = None
 SP_GROUP: dist.ProcessGroup = None
-# duplicate process group for concurrent NCCL streams
-# while both PyTorch and NCCL warns(https://github.com/pytorch/pytorch/commit/2dbe5cb979f674f0052a8eea1f7b6c3c0ba441d7)
-# against this, in practice it seems to work fine.
+# NOTE: Duplicating PGs for concurrent NCCL streams is a risky hack -- while it may increase throughput,
+# both PyTorch and NCCL warn against this. (https://github.com/pytorch/pytorch/commit/2dbe5cb979f674f0052a8eea1f7b6c3c0ba441d7)
+# LoongTrain's original double ring impl. uses concurrent PGs
+# (https://github.com/InternLM/InternEvo/blob/e52f2ffc9acf818e8f2b1f97dfc69ceb2f06e154/internlm/model/ops/ring_flash_attn/zigzag_ring_flash_attn_with_sliding_window.py#L192)
+# but I confirmed with PyTorch developers that this can cause obscure "Software caused connection abort" errors.
+# (https://github.com/pytorch/pytorch/issues/132852)
+# NOTE: In general, a smarter idea is to put as many P2P calls as possible into one `batch_isend_irecv`.
 INNER_RING_GROUP: dist.ProcessGroup = None
-INNER_RING_GROUP_COPY: dist.ProcessGroup = None
+# INNER_RING_GROUP_COPY: dist.ProcessGroup = None
 INTER_RING_GROUP: dist.ProcessGroup = None
-INTER_RING_GROUP_COPY: dist.ProcessGroup = None
+# INTER_RING_GROUP_COPY: dist.ProcessGroup = None

 @staticmethod
 def get_double_ring_groups(sp_axis, pg_mesh, inner_ring_size=None):
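For reference, the `batch_isend_irecv` approach suggested in the NOTE above can be sketched roughly as follows. This is an illustrative snippet, not part of this commit, and it assumes an already-initialized default NCCL process group (so peer ranks are global ranks):

import torch
import torch.distributed as dist

def ring_step(kv_send: torch.Tensor, kv_recv: torch.Tensor):
    # Batch one ring step's send and recv into a single NCCL group call,
    # instead of overlapping them via duplicated process groups.
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    ops = [
        dist.P2POp(dist.isend, kv_send, (rank + 1) % world_size),
        dist.P2POp(dist.irecv, kv_recv, (rank - 1) % world_size),
    ]
    reqs = dist.batch_isend_irecv(ops)
    for req in reqs:
        req.wait()  # with NCCL this orders the current CUDA stream after the comm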
@@ -626,7 +631,13 @@ class RingAttention(torch.autograd.Function):
     inner_ring_group: Optional[dist.ProcessGroup] = None,
     inter_ring_group: Optional[dist.ProcessGroup] = None,
 ):
+    """
+    Forward supporting both packed (varlen) and batched (fixed-length, no padding) sequences.
+    No separate version for batched seq (hard to maintain), which incurs
+    some overhead in sequence splitting due to Python for loops.
+    Uses two CUDA streams to overlap softmax denominator correction with the next flash attn call
+    (see comments below).
+    """
     cu_seqlens_q = cu_seqlens_kv = cu_seqlens
     max_seqlen_q = max_seqlen_kv = max_seqlen
     cu_seqlens_half = cu_seqlens // 2
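To make the varlen bookkeeping above concrete, here is a small standalone illustration (values are made up, not from the commit) of how cu_seqlens describes a packed batch and what the halved offsets used for the zigzag halves look like:

import torch

# Two packed sequences of lengths 8 and 12 -> T = 20 tokens total.
cu_seqlens = torch.tensor([0, 8, 20], dtype=torch.int32)
max_seqlen = 12

# The same tensor serves both q and kv, as in the forward above.
cu_seqlens_q = cu_seqlens_kv = cu_seqlens
max_seqlen_q = max_seqlen_kv = max_seqlen

# Offsets when only half of each (zigzag-split) sequence takes part in a ring step.
cu_seqlens_half = cu_seqlens // 2   # tensor([ 0,  4, 10], dtype=torch.int32)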
@@ -668,7 +679,8 @@ class RingAttention(torch.autograd.Function):

 sp_size = dist.get_world_size(sp_group)
 sp_rank = dist.get_rank(sp_group)
-# Attempt to achieve concurrent comm in the two-stream forward
+
+# Create communicators corresponding to two CUDA streams
 local_kv_comms = [RingComm(inner_ring_group) for _ in range(2)]
 inter_ring_comm = RingComm(inter_ring_group)
 local_sp_size = dist.get_world_size(inner_ring_group)
@@ -676,7 +688,7 @@ class RingAttention(torch.autograd.Function):
 inter_ring_rank = dist.get_rank(inter_ring_group) if inter_ring_group is not sp_group else 0
 num_rings = dist.get_world_size(inter_ring_group) if inter_ring_group is not sp_group else 1

-# Non-contiguous indexing copies to a new contiguous tensor,
+# Any type of indexing (but not slicing) copies to a new contiguous tensor,
 # so only do it once
 if sp_rank != sp_size - 1:
     q1 = q[half_idx_back]
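The comment above is about PyTorch indexing semantics; a quick standalone check (not from this commit) shows why the indexed copy is made only once:

import torch

x = torch.arange(12).reshape(6, 2)
idx = torch.tensor([1, 3, 5])

view = x[2:5]       # slicing returns a view: no copy, shares storage with x
gathered = x[idx]   # advanced (index-tensor) indexing allocates a new contiguous tensor

assert view.data_ptr() == x[2:].data_ptr()   # same underlying memory
assert gathered.data_ptr() != x.data_ptr()   # freshly allocated copy
assert gathered.is_contiguous()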
@@ -693,6 +705,7 @@ class RingAttention(torch.autograd.Function):
 rng_states = [None for _ in range(sp_size)]
 sp_streams = [torch.cuda.current_stream(), sp_stream]

+# Helper to pass args to FA
 def _forward(q, k, v, causal):
     (
         _,
@@ -723,6 +736,7 @@ class RingAttention(torch.autograd.Function):
 if i < local_sp_size - 1:
     local_kv_comms[i % 2].send_recv(kv_buffers[i % 2], kv_buffers[(i + 1) % 2])

+# Forward within a node
 def _local_ring_forward():
     # (Hopefully) overlap output correction with next flash attn
     for i in range(local_sp_size):
@@ -731,6 +745,8 @@ class RingAttention(torch.autograd.Function):
 # NOTE: waiting outside the current stream will NOT correctly synchronize.
 if i > 0:
     local_kv_comms[(i + 1) % 2].wait()

+# Prefetch
 if i == 0:
     _kv_comm(i)
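The NOTE above ("waiting outside the current stream will NOT correctly synchronize") follows from how async NCCL work objects behave: work.wait() only makes the stream that is current at the call site wait for the communication. An illustrative sketch, not from this commit, assuming an initialized NCCL process group and a CUDA device (names and sizes are made up):

import torch
import torch.distributed as dist

side_stream = torch.cuda.Stream()
recv_buf = torch.empty(1024, device="cuda")
work = dist.irecv(recv_buf, src=(dist.get_rank() - 1) % dist.get_world_size())

# Wrong: calling work.wait() here would only order the *default* stream after
# the recv, so a kernel launched later on side_stream could read recv_buf too early.
# work.wait()

# Right: issue the wait from the stream that will consume the data.
with torch.cuda.stream(side_stream):
    work.wait()             # side_stream now waits for the NCCL recv to finish
    out = recv_buf * 2.0    # safe: ordered after the communication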
@@ -764,15 +780,22 @@ class RingAttention(torch.autograd.Function):
 ) = _forward(q_block, kv_block[0], kv_block[1], causal=False)
 RingAttention.ATTN_DONE.record()
 # Pipeline the next KV comm with output correction instead of the next flash attn
-# to minimize idle time when comm takes longer than attn.
+# kernel, to minimize bubble when comm takes longer than attn.
 _kv_comm(i + 1)

 block_softmax_lse[i % 2] = (
     block_softmax_lse[i % 2].transpose(0, 1).unsqueeze(-1).contiguous().float()
 )  # (H, T) -> (T, H, 1)
 assert block_out[i % 2].shape[:-1] == block_softmax_lse[i % 2].shape[:-1]
-# Output and log sum exp correction. Ideally overlap this with the next flash attn kernel.
-# In reality this always finishes before next flash attn; no need for extra sync.
+
+# Output and log sum exp correction.
+# Ideally overlap this with the next flash attn kernel,
+# since attn uses Tensor Cores while the rescale is element-wise, memory-bound and uses CUDA cores.
+# (NOTE that this is the same as the ping-pong scheduling idea in FA3.)
+# TODO: However, sometimes even though the GPU has scheduled the next kernel,
+# it's reluctant to launch it in overlap. Some potential causes:
+# 1. need lower-level CUDA scheduling; 2. further benchmark against Megatron-LM;
+# 3. register spilling by the FA kernel.
 if i == 0:
     out = block_out[0]
     softmax_lse = block_softmax_lse[0]
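The correction referred to above is the standard online-softmax merge of per-block flash-attention outputs. A minimal sketch of the math (a plain PyTorch reference, not the fused helper ColossalAI actually calls), with out/block_out of shape (T, H, D) and lse/block_lse of shape (T, H, 1) in float32:

import torch

def merge_attn_block(out, lse, block_out, block_lse):
    # new_lse = log(exp(lse) + exp(block_lse)), computed stably
    new_lse = torch.logaddexp(lse, block_lse)
    # Rescale both partial outputs into the merged softmax normalization
    out = torch.exp(lse - new_lse) * out + torch.exp(block_lse - new_lse) * block_out
    return out, new_lse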
@@ -788,15 +811,17 @@ class RingAttention(torch.autograd.Function):
 torch.cuda.current_stream().wait_stream(sp_stream)
 return out, softmax_lse

+# Forward for inter-node (the outer ring in 2D ring)
 def _other_ring_forward(ring_num_idx, out, softmax_lse):
     # Loop through the inner ring after receiving
-    # all new KVs from the previous inner ring
+    # all new KVs from another ring
     for i in range(local_sp_size):
         with torch.cuda.stream(sp_streams[i % 2]):
             # Send & recv KV
             if i > 0:
                 local_kv_comms[(i + 1) % 2].wait()

+            # Prefetch
             if i == 0:
                 _kv_comm(i)
@@ -893,7 +918,8 @@ class RingAttention(torch.autograd.Function):
 def backward(ctx, dout, _):
     """
     During backward, we accumulate q grads on each rank locally, but iterate kv and their grads
-    over all ranks for accumulation.
+    over all ranks for accumulation. We avoid using two streams since backward uses doubled
+    buffers and has a higher comm cost.
     """
     (q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_kv, half_idx_front, half_idx_back) = ctx.saved_tensors[:9]
     rng_states = ctx.saved_tensors[9:]
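To illustrate the accumulation pattern the docstring describes, here is a single-process emulation (purely illustrative: random tensors stand in for the real _backward results, and a Python list stands in for the ring communication):

import torch

sp_size, T, H, D = 4, 8, 2, 16
dq = [torch.zeros(T, H, D) for _ in range(sp_size)]       # one per rank, never moves
dkv = [torch.zeros(2, T, H, D) for _ in range(sp_size)]   # one per rank, travels the ring

for step in range(sp_size):
    for rank in range(sp_size):
        # Stand-in for the per-block flash-attn backward on this rank's current kv block.
        dq_partial = torch.randn(T, H, D)
        dkv_partial = torch.randn(2, T, H, D)
        dq[rank] += dq_partial      # q grads accumulate locally on their own rank
        dkv[rank] += dkv_partial    # kv grads accumulate into the traveling buffer
    # Emulated send_recv: each rank hands its dkv buffer to the next rank, so after
    # sp_size steps every dkv buffer has collected a contribution from every rank.
    dkv = [dkv[(r - 1) % sp_size] for r in range(sp_size)]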
@@ -925,7 +951,7 @@ class RingAttention(torch.autograd.Function):
 local_sp_rank = dist.get_rank(sp_group)
 sp_size = dist.get_world_size(sp_group)

-# Using separate streams (pg) for concurrent kv and dkv comm may
+# NOTE: Using separate streams (PG) for concurrent kv and dkv comm may
 # cause NCCL "software caused connection abort" here...
 local_kv_comm = RingComm(local_kv_group)
 local_dkv_comm = RingComm(local_kv_group)
@@ -957,6 +983,7 @@ class RingAttention(torch.autograd.Function):
 dkv_buffers = [torch.empty_like(kv, dtype=torch.float32) for kv in kv_buffers]  # (T, H, D)
 del k, v

+# Helper to pass args to FA
 def _backward(dout, q, k, v, out, softmax_lse, dq, dk, dv, rng_state, causal):
     _flash_attn_backward(
         dout,
@@ -977,8 +1004,7 @@ class RingAttention(torch.autograd.Function):
     **misc_kwargs,
 )

-# NOTE: We avoid using two streams due to doubled buffers
-# and that backward is more communication intensive.
+# Backward within a node
 def _local_ring_backward():
     for i in range(local_sp_size):
         if i > 0:
@@ -1041,6 +1067,7 @@ class RingAttention(torch.autograd.Function):
 dkv_send = dkv_buffers[(local_sp_size - 1) % 2]
 return dq, dkv_recv, dkv_send

+# Backward for inter-node (the outer ring in 2D ring)
 def _other_ring_backward(ring_num_idx, dq):
     if ring_num_idx > inter_ring_rank:
         # Indexing is expensive
@@ -1125,34 +1152,34 @@ class RingAttention(torch.autograd.Function):

 @staticmethod
 def prepare_varlen_batch(
-    attention_mask: torch.Tensor,
+    padding_mask: torch.Tensor,
     sp_group: dist.ProcessGroup,
     inputs_embeds: torch.Tensor = None,
     position_ids: Optional[torch.Tensor] = None,
     is_label: bool = False,
-    is_2d: bool = True,
+    is_batched_seq: bool = True,
 ):
     # TODO: support setting a batch dim (fixed packing length) for packed mode, so that
     # DP can be used (needs to modify the dataloader too)
     """
     Preprocess a batch of padded sequences by splitting each input sequence by sp_size
-    sequence-wise and packing them into one sequence. Updates the mask info accordingly.
+    seq-wise and packing them into one sequence. Updates the mask info accordingly.
     Args:
-        attention_mask (torch.Tensor): Contains the mask [B, Sq], where True means the token is NOT masked.
+        padding_mask (torch.Tensor): Contains the mask [B, Sq], where True means the token is NOT masked.
         sp_group (dist.ProcessGroup): Process group for sequence parallelism
         inputs_embeds (torch.Tensor): Input embeddings. Shape should be [B, Sq, ...]
         position_ids (Optional[torch.Tensor], optional): Position ids of shape [Sq] or [1, Sq]. Defaults to None.
         is_label (bool, optional): Whether inputs_embeds is instead a label tensor. If True, mask out the first
             token of each sequence.
-        is_2d (bool, optional): Whether to return 2D outputs padded to max_seqlen // sp_size or flatten
-            the batch dim to a packed 1d sequence. Contingent on model forward shape definitions.
+        is_batched_seq (bool, optional): If True, then the input is a batch of (potentially padded) sequences
+            of shape [B, Sq, ...]; else a packed sequence of shape [T, ...].

     Returns:
-        torch.Tensor:
-            Packed input embeddings of shape [B, Sq // sp_size, ...].
-        Dict[str, Any]:
+        inputs_embeds (torch.Tensor):
+            Packed input embeddings of shape [B, Sq // sp_size, ...] if is_batched_seq, else [T, ...].
+        mask_info (Dict[str, Any]):
             A dictionary containing mask info.
-        torch.Tensor:
+        position_ids (torch.Tensor):
             Packed position ids of shape [..., Sq // sp_size].
     """
@@ -1160,12 +1187,11 @@ class RingAttention(torch.autograd.Function):
 sp_size = dist.get_world_size(group=sp_group)
 sp_rank = dist.get_rank(group=sp_group)
 mask_info = {}
-mask_info["max_seqlen"], mask_info["cu_seqlens"] = get_pad_info(attention_mask, return_indices=False)
+mask_info["max_seqlen"], mask_info["cu_seqlens"] = get_pad_info(padding_mask, return_indices=False)

-# Unpad, split seq-wise, then pad back to (B, max_seqlen // sp_size)
-# Split mask to compute local nonzero position indices
+# Unpad, split seq-wise, then pad to (B, max_seqlen // sp_size)
 # (B, Sq) -> (B, max_seqlen // sp_size)
-attention_mask = attention_mask[:, : mask_info["max_seqlen"]]
+padding_mask = padding_mask[:, : mask_info["max_seqlen"]]
 if inputs_embeds is not None:
     inputs_embeds = inputs_embeds[:, : mask_info["max_seqlen"]]
 inputs_embeds = split_varlen_zigzag(
@@ -1173,11 +1199,12 @@ class RingAttention(torch.autograd.Function):
     mask_info["cu_seqlens"],
     sp_group,
     mask_info["max_seqlen"],
-    is_2d=is_2d,
+    is_batched_seq=is_batched_seq,
     is_label=is_label,
 )
-attention_mask = split_varlen_zigzag(
-    attention_mask, mask_info["cu_seqlens"], sp_group, mask_info["max_seqlen"], is_2d=is_2d
+# Split mask to get local nonzero seq positions
+padding_mask = split_varlen_zigzag(
+    padding_mask, mask_info["cu_seqlens"], sp_group, mask_info["max_seqlen"], is_batched_seq=is_batched_seq
 )

 if position_ids is not None:
@@ -1190,7 +1217,7 @@ class RingAttention(torch.autograd.Function):
 )

 mask_info["max_seqlen"] //= sp_size
-mask_info["valid_indices"] = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
+mask_info["valid_indices"] = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten()
 mask_info["cu_seqlens"] //= sp_size
 mask_info["attention_mask_type"] = AttnMaskType.PADDED_CAUSAL
 return inputs_embeds, mask_info, position_ids
@@ -295,8 +295,8 @@ def split_batch_zigzag(
     batch: Union[torch.Tensor, List[torch.Tensor]], sp_group: ProcessGroup, seq_dim: int = 1, is_label: bool = False
 ) -> Union[torch.Tensor, List[torch.Tensor]]:
     """
-    Split the input along the sequence dimension for Ring Attention. Naively spliting the attention mask
-    in the causal setting will result in the preceding ranks having much less workload.
+    Split the input sequence batch. Naively splitting the attention mask in the causal setting
+    will result in the preceding ranks having much less workload.
     We split after "folding" the 2D attention mask in half (https://github.com/zhuzilin/ring-flash-attention/issues/2).
     For example, for sp_size = 4 and seq_len = 8, we get | s0, s7 | s1, s6 | s2, s5 | s3, s4 |.
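The folding described above can be reproduced with plain tensor ops; a small standalone illustration (not the library function itself) for sp_size = 4 and seq_len = 8:

import torch

sp_size, seq_len = 4, 8
seq = torch.arange(seq_len)        # tokens s0..s7
chunks = seq.chunk(2 * sp_size)    # two chunks per rank

for sp_rank in range(sp_size):
    # Each rank takes one chunk from the front and its mirror from the back,
    # which balances causal-attention work across ranks.
    local = torch.cat([chunks[sp_rank], chunks[2 * sp_size - 1 - sp_rank]])
    print(sp_rank, local.tolist())
# Prints: 0 [0, 7] | 1 [1, 6] | 2 [2, 5] | 3 [3, 4]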
@@ -346,40 +346,42 @@ def split_varlen_zigzag(
     cu_seqlens: torch.Tensor,
     sp_group: ProcessGroup,
     max_seqlen: int = 0,
-    is_2d: bool = False,
+    is_batched_seq: bool = False,
     is_label: bool = False,
 ) -> Union[List[torch.Tensor], torch.Tensor]:
-    """Split each sequence in a batch of packed sequences in a zigzag fashion.
-    For each tensor in batch, return packed sequences if is_2d is False;
-    else return a padded batch of sequences.
-
+    """Split a packed seq/batch of padded sequences in a zigzag fashion.
+    Different from split_batch_zigzag, inputs here have variable sequence lengths.

     Args:
-        batch (List[torch.Tensor]): Packed sequences of shape (B * Sq, ...), or (B, Sq, ...) if is_2d.
+        batch (List[torch.Tensor]): Packed sequences of shape (T, ...), or (B, Sq, ...) if is_batched_seq,
+            where T is the total number of tokens.
         cu_seqlens (torch.Tensor): Cumulative sequence lengths of shape (B + 1) before splitting.
         sp_group (ProcessGroup): The process group for sequence parallelism.
         max_seqlen (int): The maximum sequence length in the batch before splitting.
-        is_2d (bool): If True, then input has batch size and sequence length split into two dimensions.
+        is_batched_seq (bool): If True, then the input is a batch of sequences padded to the same length.
         is_label (bool): If True, mask out the first token in each sequence (<Start of Sentence>).

     Returns:
-        batch (List[torch.Tensor]): Packed sequences of shape (B * max_seqlen // sp_size)
-            or (B, max_seqlen // sp_size, ...) if is_2d
+        batch (List[torch.Tensor]): Packed sequences of shape (T, ...)
+            or (B, max_seqlen // sp_size, ...) if is_batched_seq
     """
     sp_size = dist.get_world_size(sp_group)
     sp_rank = dist.get_rank(sp_group)
     if sp_size == 1:
         return batch

-    if is_2d:
+    if is_batched_seq:
         assert max_seqlen > 0, "max_seqlen must be provided for 2D input"

     if isinstance(batch, torch.Tensor):
         batch = [batch]
     # seq: (B, Sq, h, n)
     # seq = seq[:, :rank * (seqlen // sp_size), ...]

     for i, packed_seq in enumerate(batch):
         device = packed_seq.device
         dtype = packed_seq.dtype

-        if is_2d:
+        if is_batched_seq:
             assert max_seqlen % (sp_size * 2) == 0
             # Recreate a padded tensor with the new max seqlen
             shape = (packed_seq.shape[0], max_seqlen // sp_size, *packed_seq.shape[2:])
@@ -398,7 +400,7 @@ def split_varlen_zigzag(
         seqlen % (2 * sp_size) == 0
     ), f"batch {i} seq {j}'s length ({seqlen}) must be divisible by 2 * sp_size = {2 * sp_size} for splitting"

-    if is_2d:
+    if is_batched_seq:
         seq = packed_seq[j][:seqlen]
         if is_label:
             # Shift one position to the right for next token prediction
@@ -415,7 +417,7 @@ def split_varlen_zigzag(
     seq = seq.chunk(sp_size * 2)
     local_seq.extend([seq[sp_rank], seq[2 * sp_size - 1 - sp_rank]])

-    if is_2d:
+    if is_batched_seq:
         batch[i] = local_seq.contiguous()
     else:
         batch[i] = torch.cat(local_seq, dim=0)