mirror of https://github.com/hpcaitech/ColossalAI
[Ring Attention] Improve comments (#6085)
* improve comments

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
parent dcd41d0973
commit 62c13e7969
@@ -422,13 +422,18 @@ class RingAttention(torch.autograd.Function):
     ATTN_DONE: torch.cuda.Event = None
     SP_STREAM: torch.cuda.Stream = None
     SP_GROUP: dist.ProcessGroup = None
-    # duplicate process group for concurrent NCCL streams
-    # while both PyTorch and NCCL warns(https://github.com/pytorch/pytorch/commit/2dbe5cb979f674f0052a8eea1f7b6c3c0ba441d7)
-    # against this, in practice it seems to work fine.
+    # NOTE: Duplicating PGs for concurrent NCCL streams is a risky hack -- while it may increase throughput,
+    # both PyTorch and NCCL warn against this (https://github.com/pytorch/pytorch/commit/2dbe5cb979f674f0052a8eea1f7b6c3c0ba441d7).
+    # LoongTrain's original double ring impl. uses concurrent PGs
+    # (https://github.com/InternLM/InternEvo/blob/e52f2ffc9acf818e8f2b1f97dfc69ceb2f06e154/internlm/model/ops/ring_flash_attn/zigzag_ring_flash_attn_with_sliding_window.py#L192),
+    # but I confirmed with PyTorch developers that this can cause obscure "Software caused connection abort" errors
+    # (https://github.com/pytorch/pytorch/issues/132852).
+    # NOTE: In general, a smarter idea is to put as many P2P calls as possible into one `batch_isend_irecv`.
     INNER_RING_GROUP: dist.ProcessGroup = None
-    INNER_RING_GROUP_COPY: dist.ProcessGroup = None
+    # INNER_RING_GROUP_COPY: dist.ProcessGroup = None
     INTER_RING_GROUP: dist.ProcessGroup = None
-    INTER_RING_GROUP_COPY: dist.ProcessGroup = None
+    # INTER_RING_GROUP_COPY: dist.ProcessGroup = None

     @staticmethod
     def get_double_ring_groups(sp_axis, pg_mesh, inner_ring_size=None):
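The NOTE above refers to PyTorch's batched P2P API. Below is a minimal sketch (an illustration, not part of this diff and not ColossalAI's RingComm) of issuing a ring KV exchange as a single batch_isend_irecv call on the default process group; the function and buffer names are made up.

    import torch
    import torch.distributed as dist


    def ring_exchange(send_buf: torch.Tensor) -> torch.Tensor:
        """Send `send_buf` to the next rank and receive the previous rank's buffer,
        batching both P2P ops into one call (assumes dist.init_process_group was run)."""
        rank = dist.get_rank()
        world_size = dist.get_world_size()
        recv_buf = torch.empty_like(send_buf)
        ops = [
            dist.P2POp(dist.isend, send_buf, (rank + 1) % world_size),
            dist.P2POp(dist.irecv, recv_buf, (rank - 1) % world_size),
        ]
        # One grouped NCCL call instead of duplicated process groups / extra streams.
        for req in dist.batch_isend_irecv(ops):
            req.wait()
        return recv_buf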
@@ -626,7 +631,13 @@ class RingAttention(torch.autograd.Function):
         inner_ring_group: Optional[dist.ProcessGroup] = None,
         inter_ring_group: Optional[dist.ProcessGroup] = None,
     ):
+        """
+        Forward supporting both packed (varlen) and batched (fixed-length, no padding) sequences.
+        There is no separate version for batched sequences (hard to maintain), which incurs
+        some overhead in sequence splitting due to Python for loops.
+        Uses two CUDA streams to overlap softmax denominator correction with the next flash attn
+        call (see comments below).
+        """
         cu_seqlens_q = cu_seqlens_kv = cu_seqlens
         max_seqlen_q = max_seqlen_kv = max_seqlen
         cu_seqlens_half = cu_seqlens // 2
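As a rough standalone illustration of the two-stream overlap the docstring describes (alternating work between the default stream and a side stream, with an event in the role of ATTN_DONE), here is a hedged sketch; the matmuls are stand-ins for a flash attn block and the correction work, and it needs a CUDA device.

    import torch

    main = torch.cuda.current_stream()
    side = torch.cuda.Stream()          # plays the role of SP_STREAM
    done = torch.cuda.Event()           # plays the role of ATTN_DONE

    x = torch.randn(4096, 4096, device="cuda")

    side.wait_stream(main)              # `x` was produced on the main stream
    with torch.cuda.stream(side):
        y = x @ x                       # stand-in for one flash attn block
        done.record()                   # like RingAttention.ATTN_DONE.record()

    # Independent work (e.g. the next KV send/recv or output correction) can run
    # here on the main stream concurrently, as long as it does not touch `y`.
    z = x * 2.0

    # Anything that consumes `y` must first wait on the event (or the stream).
    main.wait_event(done)
    out = y + z
    torch.cuda.synchronize()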
@@ -668,7 +679,8 @@ class RingAttention(torch.autograd.Function):

         sp_size = dist.get_world_size(sp_group)
         sp_rank = dist.get_rank(sp_group)
-        # Attempt to achieve concurrent comm in the two-stream forward
+        # Create communicators corresponding to two CUDA streams
         local_kv_comms = [RingComm(inner_ring_group) for _ in range(2)]
         inter_ring_comm = RingComm(inter_ring_group)
         local_sp_size = dist.get_world_size(inner_ring_group)
@@ -676,7 +688,7 @@ class RingAttention(torch.autograd.Function):
         inter_ring_rank = dist.get_rank(inter_ring_group) if inter_ring_group is not sp_group else 0
         num_rings = dist.get_world_size(inter_ring_group) if inter_ring_group is not sp_group else 1

-        # Non-contiguous indexing copies to a new contiguous tensor,
+        # Any type of indexing (but not slicing) copies to a new contiguous tensor,
         # so only do it once
         if sp_rank != sp_size - 1:
             q1 = q[half_idx_back]
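The comment above relies on a general PyTorch behavior worth spelling out: slicing returns a view, while indexing with an index tensor (like half_idx_back here) materializes a new contiguous copy, hence doing it only once outside the loop. A quick standalone check on a recent PyTorch (the index values are made up):

    import torch

    x = torch.arange(12).reshape(6, 2)
    half_idx_back = torch.tensor([3, 4, 5])   # illustrative index tensor

    sliced = x[3:]                # slicing -> a view sharing x's storage
    indexed = x[half_idx_back]    # index tensor -> a new contiguous copy

    print(sliced.untyped_storage().data_ptr() == x.untyped_storage().data_ptr())   # True
    print(indexed.untyped_storage().data_ptr() == x.untyped_storage().data_ptr())  # False
    print(indexed.is_contiguous())                                                 # True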
@@ -693,6 +705,7 @@ class RingAttention(torch.autograd.Function):
         rng_states = [None for _ in range(sp_size)]
         sp_streams = [torch.cuda.current_stream(), sp_stream]

+        # Helper to pass args to FA
         def _forward(q, k, v, causal):
             (
                 _,
@@ -723,6 +736,7 @@ class RingAttention(torch.autograd.Function):
             if i < local_sp_size - 1:
                 local_kv_comms[i % 2].send_recv(kv_buffers[i % 2], kv_buffers[(i + 1) % 2])

+        # Forward within a node
         def _local_ring_forward():
             # (Hopefully) overlap output correction with next flash attn
             for i in range(local_sp_size):
@@ -731,6 +745,8 @@ class RingAttention(torch.autograd.Function):
                     # NOTE: waiting outside the current stream will NOT correctly synchronize.
                     if i > 0:
                         local_kv_comms[(i + 1) % 2].wait()
+
+                    # Prefetch
                     if i == 0:
                         _kv_comm(i)
@@ -764,15 +780,22 @@ class RingAttention(torch.autograd.Function):
                     ) = _forward(q_block, kv_block[0], kv_block[1], causal=False)
                     RingAttention.ATTN_DONE.record()

                     # Pipeline the next KV comm with output correction instead of the next flash attn
-                    # to minimize idle time when comm takes longer than attn.
+                    # kernel, to minimize bubble when comm takes longer than attn.
                     _kv_comm(i + 1)

                     block_softmax_lse[i % 2] = (
                         block_softmax_lse[i % 2].transpose(0, 1).unsqueeze(-1).contiguous().float()
                     )  # (H, T) -> (T, H, 1)
                     assert block_out[i % 2].shape[:-1] == block_softmax_lse[i % 2].shape[:-1]

-                    # Output and log sum exp correction. Ideally overlap this with the next flash attn kernel.
-                    # In reality this always finishes before next flash attn; no need for extra sync.
+                    # Output and log sum exp correction.
+                    # Ideally overlap this with the next flash attn kernel, since attn uses Tensor Cores
+                    # while rescaling is element-wise, memory-bound and uses CUDA cores.
+                    # (NOTE that this is the same as the ping-pong scheduling idea in FA3.)
+                    # TODO: However, sometimes even though the GPU has scheduled the next kernel,
+                    # it's reluctant to launch it in overlap. Some potential causes:
+                    # 1. need lower-level CUDA scheduling 2. further benchmark against Megatron-LM
+                    # 3. register spilling by the FA kernel.
                     if i == 0:
                         out = block_out[0]
                         softmax_lse = block_softmax_lse[0]
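For reference, the output / log-sum-exp correction being scheduled here is commonly written as the merge below (this is the formulation used in zhuzilin/ring-flash-attention; the helper name is ours and it is a sketch of the math, not this file's exact kernel). Shapes follow the (T, H, D) / (T, H, 1) layout noted above.

    import torch
    import torch.nn.functional as F


    def merge_block(out, lse, block_out, block_lse):
        """Fold one attention block into the running output and log-sum-exp.
        out, block_out: (T, H, D); lse, block_lse: (T, H, 1), float32."""
        # new_lse = log(exp(lse) + exp(block_lse)), computed stably
        new_lse = lse - F.logsigmoid(lse - block_lse)
        # new_out = exp(lse - new_lse) * out + exp(block_lse - new_lse) * block_out
        new_out = out - torch.sigmoid(block_lse - lse) * (out - block_out)
        return new_out, new_lse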
@@ -788,15 +811,17 @@ class RingAttention(torch.autograd.Function):
             torch.cuda.current_stream().wait_stream(sp_stream)
             return out, softmax_lse

+        # Forward for inter-node (the outer ring in 2D ring)
         def _other_ring_forward(ring_num_idx, out, softmax_lse):
             # Loop through the inner ring after receiving
-            # all new KVs from the previous inner ring
+            # all new KVs from another ring
             for i in range(local_sp_size):
                 with torch.cuda.stream(sp_streams[i % 2]):
                     # Send & recv KV
                     if i > 0:
                         local_kv_comms[(i + 1) % 2].wait()

+                    # Prefetch
                     if i == 0:
                         _kv_comm(i)
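To make the inner vs. outer (inter-node) ring concrete, here is a hedged sketch of how sp_size ranks could be laid out as a 2D ring. The real mapping comes from get_double_ring_groups and the device mesh; this standalone helper only shows the rank arithmetic and an example layout, and may not match the actual grouping exactly.

    def double_ring_layout(sp_size: int, inner_ring_size: int):
        """Return (inner_rings, inter_rings) as lists of rank lists.
        Inner rings are consecutive ranks (e.g. GPUs of one node); inter rings connect
        ranks holding the same position across inner rings."""
        assert sp_size % inner_ring_size == 0
        num_rings = sp_size // inner_ring_size
        inner = [list(range(r * inner_ring_size, (r + 1) * inner_ring_size)) for r in range(num_rings)]
        inter = [list(range(i, sp_size, inner_ring_size)) for i in range(inner_ring_size)]
        return inner, inter


    # e.g. 8 GPUs with 4 per node:
    #   inner -> [[0, 1, 2, 3], [4, 5, 6, 7]]
    #   inter -> [[0, 4], [1, 5], [2, 6], [3, 7]]
    # each rank list would then back a dist.new_group(ranks=...) call
    print(double_ring_layout(8, 4))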
@@ -893,7 +918,8 @@ class RingAttention(torch.autograd.Function):
     def backward(ctx, dout, _):
         """
         During backward, we accumulate q grads on each rank locally, but iterate kv and their grads
-        over all ranks for accumulation.
+        over all ranks for accumulation. We avoid using two streams here, since backward doubles
+        the buffers and is more communication-intensive.
         """
         (q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_kv, half_idx_front, half_idx_back) = ctx.saved_tensors[:9]
         rng_states = ctx.saved_tensors[9:]
@@ -925,7 +951,7 @@ class RingAttention(torch.autograd.Function):
         local_sp_rank = dist.get_rank(sp_group)
         sp_size = dist.get_world_size(sp_group)

-        # Using separate streams (pg) for concurrent kv and dkv comm may
+        # NOTE: Using separate streams (PG) for concurrent kv and dkv comm may
         # cause NCCL "software caused connection abort" here...
         local_kv_comm = RingComm(local_kv_group)
         local_dkv_comm = RingComm(local_kv_group)
@@ -957,6 +983,7 @@ class RingAttention(torch.autograd.Function):
         dkv_buffers = [torch.empty_like(kv, dtype=torch.float32) for kv in kv_buffers]  # (T, H, D)
         del k, v

+        # Helper to pass args to FA
         def _backward(dout, q, k, v, out, softmax_lse, dq, dk, dv, rng_state, causal):
             _flash_attn_backward(
                 dout,
@@ -977,8 +1004,7 @@ class RingAttention(torch.autograd.Function):
                 **misc_kwargs,
             )

-        # NOTE: We avoid using two streams due to doubled buffers
-        # and that backward is more communication intensive.
+        # Backward within a node
         def _local_ring_backward():
             for i in range(local_sp_size):
                 if i > 0:
@@ -1041,6 +1067,7 @@ class RingAttention(torch.autograd.Function):
             dkv_send = dkv_buffers[(local_sp_size - 1) % 2]
             return dq, dkv_recv, dkv_send

+        # Backward for inter-node (the outer ring in 2D ring)
         def _other_ring_backward(ring_num_idx, dq):
             if ring_num_idx > inter_ring_rank:
                 # Indexing is expensive
@@ -1125,34 +1152,34 @@ class RingAttention(torch.autograd.Function):

     @staticmethod
     def prepare_varlen_batch(
-        attention_mask: torch.Tensor,
+        padding_mask: torch.Tensor,
         sp_group: dist.ProcessGroup,
         inputs_embeds: torch.Tensor = None,
         position_ids: Optional[torch.Tensor] = None,
         is_label: bool = False,
-        is_2d: bool = True,
+        is_batched_seq: bool = True,
     ):
+        # TODO: support setting a batch dim (fixed packing length) for packed mode, so that
+        # DP can be used (needs to modify the dataloader too)
         """
         Preprocess a batch of padded sequence by splitting input sequence by sp_size
-        sequence-wise and packing them into one sequence. Updates the mask info accordingly.
+        seq-wise and packing them into one sequence. Updates the mask info accordingly.
         Args:
-            attention_mask (torch.Tensor): Contains the mask [B, Sq], where True means the token is NOT masked.
+            padding_mask (torch.Tensor): Contains the mask [B, Sq], where True means the token is NOT masked.
             sp_group (dist.ProcessGroup): Process group for sequence parallelism
             inputs_embeds (torch.Tensor): Input embeddings. Shape should be [B, Sq, ...]
             position_ids (Optional[torch.Tensor], optional): Position ids of shape [Sq] or [1, Sq]. Defaults to None.
             is_label (bool, optional): Whether inputs_embeds is instead a label tensor. If True, mask out the first
                 token of each sequence.
-            is_2d (bool, optional): Whether to return 2D outputs padded to max_seqlen // sp_size or flatten
-                the batch dim to a packed 1d sequence. Contingent on model forward shape definitions.
+            is_batched_seq (bool, optional): If True, then the input is a batch of (potentially padded) sequences
+                of shape [B, Sq, ...]; else a packed sequence of shape [T, ...].

         Returns:
-            torch.Tensor:
-                Packed input embeddings of shape [B, Sq // sp_size, ...].
+            inputs_embeds (torch.Tensor):
+                Packed input embeddings of shape [B, Sq // sp_size, ...] if is_batched_seq, else [T, ...].
-            Dict[str, Any]:
+            mask_info (Dict[str, Any]):
                 A dictionary containing mask info.
-            torch.Tensor:
+            position_ids (torch.Tensor):
                 Packed position ids of shape [..., Sq // sp_size].

         """
@@ -1160,12 +1187,11 @@ class RingAttention(torch.autograd.Function):
         sp_size = dist.get_world_size(group=sp_group)
         sp_rank = dist.get_rank(group=sp_group)
         mask_info = {}
-        mask_info["max_seqlen"], mask_info["cu_seqlens"] = get_pad_info(attention_mask, return_indices=False)
+        mask_info["max_seqlen"], mask_info["cu_seqlens"] = get_pad_info(padding_mask, return_indices=False)

-        # Unpad, split seq-wise, then pad back to (B, max_seqlen // sp_size)
-        # Split mask to compute local nonzero position indices
+        # Unpad, split seq-wise, then pad to (B, max_seqlen // sp_size)
         # (B, Sq) -> (B, max_seqlen // sp_size)
-        attention_mask = attention_mask[:, : mask_info["max_seqlen"]]
+        padding_mask = padding_mask[:, : mask_info["max_seqlen"]]
         if inputs_embeds is not None:
             inputs_embeds = inputs_embeds[:, : mask_info["max_seqlen"]]
             inputs_embeds = split_varlen_zigzag(
@@ -1173,11 +1199,12 @@ class RingAttention(torch.autograd.Function):
                 mask_info["cu_seqlens"],
                 sp_group,
                 mask_info["max_seqlen"],
-                is_2d=is_2d,
+                is_batched_seq=is_batched_seq,
                 is_label=is_label,
             )
-        attention_mask = split_varlen_zigzag(
-            attention_mask, mask_info["cu_seqlens"], sp_group, mask_info["max_seqlen"], is_2d=is_2d
+        # Split mask to get local nonzero seq positions
+        padding_mask = split_varlen_zigzag(
+            padding_mask, mask_info["cu_seqlens"], sp_group, mask_info["max_seqlen"], is_batched_seq=is_batched_seq
         )

         if position_ids is not None:
@@ -1190,7 +1217,7 @@ class RingAttention(torch.autograd.Function):
             )

         mask_info["max_seqlen"] //= sp_size
-        mask_info["valid_indices"] = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
+        mask_info["valid_indices"] = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten()
         mask_info["cu_seqlens"] //= sp_size
         mask_info["attention_mask_type"] = AttnMaskType.PADDED_CAUSAL
         return inputs_embeds, mask_info, position_ids
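A small numeric illustration of the mask-info updates above (the `//= sp_size` lines): cu_seqlens and max_seqlen shrink by sp_size after the zigzag split. This recomputes the pad info with plain tensor ops rather than calling get_pad_info, so it is only an approximation of that helper.

    import torch

    padding_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1],
                                 [1, 1, 1, 1, 0, 0, 0, 0]], dtype=torch.bool)       # B=2, Sq=8
    seqlens = padding_mask.sum(dim=1)                                               # tensor([8, 4])
    cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.long), seqlens.cumsum(0)])   # [0, 8, 12]
    max_seqlen = int(seqlens.max())                                                 # 8

    sp_size = 2
    print(cu_seqlens // sp_size, max_seqlen // sp_size)  # tensor([0, 4, 6]) 4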
@@ -295,8 +295,8 @@ def split_batch_zigzag(
     batch: Union[torch.Tensor, List[torch.Tensor]], sp_group: ProcessGroup, seq_dim: int = 1, is_label: bool = False
 ) -> Union[torch.Tensor, List[torch.Tensor]]:
     """
-    Split the input along the sequence dimension for Ring Attention. Naively spliting the attention mask
-    in the causal setting will result in the preceding ranks having much less workload.
+    Split the input sequence batch. Naively splitting the attention mask in the causal setting
+    will result in the preceding ranks having much less workload.
     We split after "folding" the 2D attention mask in half (https://github.com/zhuzilin/ring-flash-attention/issues/2).
     For example, for sp_size = 4 and seq_len = 8, we get | s0, s7 | s1, s6 | s2, s5 | s3, s4 |.
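The `| s0, s7 | s1, s6 | s2, s5 | s3, s4 |` example can be reproduced with plain tensor ops (a single-process toy that loops over ranks instead of using a real process group):

    import torch

    sp_size = 4
    seq = torch.arange(8)                  # tokens s0 .. s7
    chunks = seq.chunk(2 * sp_size)        # "fold": two chunks end up on each rank
    for sp_rank in range(sp_size):
        local = torch.cat([chunks[sp_rank], chunks[2 * sp_size - 1 - sp_rank]])
        print(sp_rank, local.tolist())
    # 0 [0, 7]
    # 1 [1, 6]
    # 2 [2, 5]
    # 3 [3, 4]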
@@ -346,40 +346,42 @@ def split_varlen_zigzag(
     cu_seqlens: torch.Tensor,
     sp_group: ProcessGroup,
     max_seqlen: int = 0,
-    is_2d: bool = False,
+    is_batched_seq: bool = False,
     is_label: bool = False,
 ) -> Union[List[torch.Tensor], torch.Tensor]:
-    """Split each sequence in a batch of packed sequences in a zigzag fashion.
-    For each tensor in batch, return packed sequences if is_2d is False;
-    else return a padded batch of sequences.
+    """Split a packed seq / batch of padded sequences in a zigzag fashion.
+    Different from split_batch_zigzag, inputs here have variable sequence lengths.

     Args:
-        batch (List[torch.Tensor]): Packed sequences of shape (B * Sq, ...), or (B, Sq, ...) if is_2d.
+        batch (List[torch.Tensor]): Packed sequences of shape (T, ...), or (B, Sq, ...) if is_batched_seq,
+            where T is the total number of tokens.
         cu_seqlens (torch.Tensor): Cumulative sequence lengths of shape (B + 1) before splitting.
         sp_group (ProcessGroup): The process group for sequence parallelism.
         max_seqlen (int): The maximum sequence length in the batch before splitting.
-        is_2d (bool): If True, then input has batch size and sequence length split into two dimensions.
+        is_batched_seq (bool): If True, then the input is a batch of sequences padded to the same length.
         is_label (bool): If True, mask out the first token in each sequence (<Start of Sentence>).

     Returns:
-        batch (List[torch.Tensor]): Packed sequences of shape (B * max_seqlen // sp_size)
-            or (B, max_seqlen // sp_size, ...) if is_2d
+        batch (List[torch.Tensor]): Packed sequences of shape (T, ...)
+            or (B, max_seqlen // sp_size, ...) if is_batched_seq
     """
     sp_size = dist.get_world_size(sp_group)
     sp_rank = dist.get_rank(sp_group)
     if sp_size == 1:
         return batch

-    if is_2d:
+    if is_batched_seq:
         assert max_seqlen > 0, "max_seqlen must be provided for 2D input"

     if isinstance(batch, torch.Tensor):
         batch = [batch]
+    # seq: (B, Sq, h, n)
+    # seq = seq[:, :rank * (seqlen // sp_size), ...]

     for i, packed_seq in enumerate(batch):
         device = packed_seq.device
         dtype = packed_seq.dtype

-        if is_2d:
+        if is_batched_seq:
             assert max_seqlen % (sp_size * 2) == 0
             # Recreate a padded tensor with the new max seqlen
             shape = (packed_seq.shape[0], max_seqlen // sp_size, *packed_seq.shape[2:])
@@ -398,7 +400,7 @@ def split_varlen_zigzag(
                 seqlen % (2 * sp_size) == 0
             ), f"batch {i} seq {j}'s length ({seqlen}) must be divisible by 2 * sp_size = {2 * sp_size} for splitting"

-            if is_2d:
+            if is_batched_seq:
                 seq = packed_seq[j][:seqlen]
                 if is_label:
                     # Shift one position to the right for next token prediction
@@ -415,7 +417,7 @@ def split_varlen_zigzag(
             seq = seq.chunk(sp_size * 2)
             local_seq.extend([seq[sp_rank], seq[2 * sp_size - 1 - sp_rank]])

-        if is_2d:
+        if is_batched_seq:
             batch[i] = local_seq.contiguous()
         else:
             batch[i] = torch.cat(local_seq, dim=0)
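The same zigzag selection applied per sequence of a packed varlen batch, mirroring the seq.chunk / local_seq.extend lines above (a single-process toy with made-up lengths, no process group):

    import torch

    sp_size, sp_rank = 2, 0
    packed = torch.arange(12)                        # two packed sequences: 8 + 4 tokens
    cu_seqlens = torch.tensor([0, 8, 12])

    local_seq = []
    for j in range(len(cu_seqlens) - 1):
        seq = packed[cu_seqlens[j] : cu_seqlens[j + 1]]
        assert len(seq) % (2 * sp_size) == 0         # same divisibility requirement as above
        chunks = seq.chunk(sp_size * 2)
        local_seq.extend([chunks[sp_rank], chunks[2 * sp_size - 1 - sp_rank]])

    print(torch.cat(local_seq).tolist())             # rank 0 keeps [0, 1, 6, 7, 8, 11]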