# mirror of https://github.com/hpcaitech/ColossalAI
from enum import Enum
from typing import Callable, Dict, Optional, Tuple

import torch
import torch.distributed
import torch.distributed as dist
import torch.nn.functional as F
from einops import rearrange
from packaging import version

from colossalai.kernel.kernel_loader import (
    FlashAttentionDaoLoader,
    FlashAttentionForFloatAndCustomMaskLoader,
    FlashAttentionLoader,
    FlashAttentionWithCustomMaskLoader,
    KernelLoader,
)
from colossalai.logging import get_dist_logger

from .utils import RingComm, get_half_index, split_varlen_zigzag

# Threshold (in bytes, ~10 GB) on the full (Sq, Skv) mask size: above it, no dense mask is
# materialized and the Dao flash-attn kernel is dispatched for (padded-)causal attention.
MEMORY_BOUND = 10 * 1e9

__all__ = [
    "AttnMaskType",
    "ColoAttention",
]

_flash_attn_forward = _flash_attn_backward = None
_unpad_input = _pad_input = None

class AttnMaskType(Enum):
    CUSTOM = 0
    PADDED = 1
    CAUSAL = 2
    PADDED_CAUSAL = 3


def invert_mask(mask: torch.Tensor) -> torch.Tensor:
    """Invert the mask tensor.

    Args:
        mask (torch.Tensor): Mask tensor. Shape should be [B, 1, Sq, Skv]

    Returns:
        torch.Tensor: Inverted mask tensor.
    """
    inverted_mask = 1.0 - mask
    return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(mask.dtype).min)

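
# Illustrative sketch (not part of the upstream API; the name is hypothetical): invert_mask()
# turns a 0/1 keep-mask into an additive bias, 0.0 where attention is allowed and the dtype's
# most negative value where it is not, which is what SDPA-style kernels expect.
def _example_invert_mask():
    keep = torch.ones(1, 1, 3, 3).tril()  # causal keep-mask for a length-3 sequence
    bias = invert_mask(keep)
    # bias[0, 0, 0] == tensor([0., min, min]) where min == torch.finfo(torch.float32).min
    return bias
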
# adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py
def get_pad_info(
    padding_mask: torch.Tensor, invert: Optional[bool] = False, return_indices: Optional[bool] = True
) -> Tuple[int, torch.Tensor, torch.Tensor]:
    """Get padding information from padding mask.

    Args:
        padding_mask (torch.Tensor): Padding mask tensor. Shape should be [B, Skv]
        invert (Optional[bool], optional): Whether to reverse the padding mask.
        return_indices (Optional[bool], optional): Whether to return the indices of non-masked tokens.

    Returns:
        max_seqlen_in_batch (int): Maximum sequence length in the batch.
        cu_seqlens (torch.Tensor): Shape [B+1]. Cumulative sequence lengths of the sequences in the batch.
        indices (torch.Tensor): Shape [total_nonzero]. The indices of non-masked tokens from the flattened input sequence.
    """
    if invert:
        padding_mask = padding_mask.logical_not()
    seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32)
    if return_indices:
        indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten()

    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    if return_indices:
        return max_seqlen_in_batch, cu_seqlens, indices
    return max_seqlen_in_batch, cu_seqlens

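
# Illustrative sketch (not part of the upstream API; the name is hypothetical): what
# get_pad_info() returns for a small padded batch.
def _example_get_pad_info():
    # Two sequences of lengths 3 and 2, right-padded to length 4.
    mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
    max_seqlen, cu_seqlens, indices = get_pad_info(mask)
    # max_seqlen == 3
    # cu_seqlens == tensor([0, 3, 5], dtype=torch.int32)
    # indices == tensor([0, 1, 2, 4, 5])  -- flat positions of the 5 valid tokens
    return max_seqlen, cu_seqlens, indices
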
class ColoAttention:
    _kernel_dispatch_map: Optional[Dict[torch.dtype, Dict[Optional[AttnMaskType], Callable]]] = None
    _flash_kernel_dispatch: Optional[Dict[torch.dtype, Dict[Optional[AttnMaskType], Callable]]] = None

    @staticmethod
    def _init_kernels_dispatch():
        if ColoAttention._kernel_dispatch_map is None:
            # fp16/bf16
            half_dispatch_map = {
                None: FlashAttentionLoader(),
                AttnMaskType.CUSTOM: FlashAttentionWithCustomMaskLoader(),
                AttnMaskType.PADDED: FlashAttentionLoader(),
                AttnMaskType.CAUSAL: FlashAttentionLoader(),
                AttnMaskType.PADDED_CAUSAL: FlashAttentionLoader(),
            }
            # fp32
            float_dispatch_map = {
                None: FlashAttentionForFloatAndCustomMaskLoader(),
                AttnMaskType.CUSTOM: FlashAttentionForFloatAndCustomMaskLoader(),
                AttnMaskType.PADDED: FlashAttentionForFloatAndCustomMaskLoader(),
                AttnMaskType.CAUSAL: FlashAttentionForFloatAndCustomMaskLoader(),
                AttnMaskType.PADDED_CAUSAL: FlashAttentionForFloatAndCustomMaskLoader(),
            }
            ColoAttention._kernel_dispatch_map = {
                torch.float16: half_dispatch_map,
                torch.bfloat16: half_dispatch_map,
                torch.float32: float_dispatch_map,
            }
        if ColoAttention._flash_kernel_dispatch is None:
            ColoAttention._flash_kernel_dispatch = FlashAttentionDaoLoader()

    @staticmethod
    def _dispatch_kernel(dtype: torch.dtype, mask_type: Optional[AttnMaskType], size) -> Callable:
        ColoAttention._init_kernels_dispatch()
        if (
            dtype not in ColoAttention._kernel_dispatch_map
            or mask_type not in ColoAttention._kernel_dispatch_map[dtype]
        ):
            raise ValueError(
                "FlashAttention kernel is not available for dtype {} and mask_type {}".format(dtype, mask_type)
            )

        if size >= MEMORY_BOUND:
            if isinstance(ColoAttention._flash_kernel_dispatch, KernelLoader):
                ColoAttention._flash_kernel_dispatch = ColoAttention._flash_kernel_dispatch.load()
        # lazy load
        if isinstance(ColoAttention._kernel_dispatch_map[dtype][mask_type], KernelLoader):
            ColoAttention._kernel_dispatch_map[dtype][mask_type] = ColoAttention._kernel_dispatch_map[dtype][
                mask_type
            ].load()

        if size >= MEMORY_BOUND and mask_type in (AttnMaskType.PADDED_CAUSAL, AttnMaskType.CAUSAL):
            return ColoAttention._flash_kernel_dispatch
        else:
            return ColoAttention._kernel_dispatch_map[dtype][mask_type]

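    # Dispatch example (illustrative): for torch.float16 with AttnMaskType.CAUSAL and a mask
    # whose full (Sq, Skv) materialization would be >= MEMORY_BOUND bytes, _dispatch_kernel()
    # returns the loaded FlashAttentionDaoLoader kernel; for smaller sizes it returns the
    # dtype/mask-specific kernel from _kernel_dispatch_map.
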
    @staticmethod
    def prepare_attn_kwargs(
        shape_4d: Tuple[int],
        dtype: torch.dtype,
        device: torch.device,
        q_padding_mask: Optional[torch.Tensor] = None,
        kv_padding_mask: Optional[torch.Tensor] = None,
        is_causal: bool = False,
        invert: bool = True,
    ) -> Dict[str, torch.Tensor]:
        """Return a dictionary of keyword arguments for the attention function. It supports 4 mask types.
        1. custom mask: no padding mask and is_causal=False; returns {}, and users should handle the attention mask themselves.
        2. padded mask: receives a padding mask and is_causal=False; returns {attention_mask, attention_mask_type, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, q_indices, kv_indices}.
        3. causal mask: no padding mask and is_causal=True; returns {attention_mask, attention_mask_type}.
        4. padded causal mask: receives a padding mask and is_causal=True; returns {attention_mask, attention_mask_type, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, q_indices, kv_indices}.

        Args:
            shape_4d (Tuple[int]): Should be (B, 1, Sq, Skv)
            dtype (torch.dtype): Dtype of the attention mask, generally should be ``hidden_states.dtype``
            device (torch.device): Device of the attention mask, generally should be ``hidden_states.device``
            q_padding_mask (Optional[torch.Tensor], optional): Padding mask of the query. It should be a long or int tensor.
                The shape should be [B, Sq]. ``1`` means a valid token, and ``0`` means a padding token. Defaults to None.
            kv_padding_mask (Optional[torch.Tensor], optional): Padding mask of the key and value. It should be a long or int tensor.
                The shape should be [B, Skv]. ``1`` means a valid token, and ``0`` means a padding token.
                If it's None and ``q_padding_mask`` is not None, it will be set to ``q_padding_mask``. Defaults to None.
            is_causal (bool, optional): Whether to use a causal attention mask. Defaults to False.
            invert (bool, optional): Whether to invert the mask. Defaults to True.

        Returns:
            Dict[str, torch.Tensor]: Dictionary of keyword arguments for the attention function.
        """
        if q_padding_mask is None and not is_causal:
            return {}
        assert len(shape_4d) == 4 and shape_4d[1] == 1
        b, _, s_q, s_kv = shape_4d
        element_size = torch.tensor([], dtype=dtype).element_size()
        memory_size = s_q * s_kv * element_size
        outputs = {}
        if (q_padding_mask is None or q_padding_mask.bool().all()) and (
            kv_padding_mask is None or kv_padding_mask.bool().all()
        ):
            # no padding
            assert is_causal
            outputs["attention_mask_type"] = AttnMaskType.CAUSAL
            if memory_size < MEMORY_BOUND:
                attention_mask = torch.ones(s_q, s_kv, dtype=dtype, device=device)
                if s_q != 1:
                    attention_mask.tril_(diagonal=0)
                attention_mask = attention_mask.expand(b, s_q, s_kv)
            else:
                attention_mask = torch.empty((0,), dtype=dtype, device=device)
        else:
            max_seqlen_q, cu_seqlens_q, q_indices = get_pad_info(q_padding_mask)
            if kv_padding_mask is None:
                # self attention
                kv_padding_mask = q_padding_mask
                max_seqlen_kv, cu_seqlens_kv, kv_indices = max_seqlen_q, cu_seqlens_q, q_indices
            else:
                max_seqlen_kv, cu_seqlens_kv, kv_indices = get_pad_info(kv_padding_mask)
            assert kv_padding_mask.shape == (
                b,
                s_kv,
            ), f"Padding mask shape {kv_padding_mask.shape} should align with shape 4d ({b}, {s_kv})"
            outputs.update(
                {
                    "cu_seqlens_q": cu_seqlens_q,
                    "cu_seqlens_kv": cu_seqlens_kv,
                    "max_seqlen_q": max_seqlen_q,
                    "max_seqlen_kv": max_seqlen_kv,
                    "q_indices": q_indices,
                    "kv_indices": kv_indices,
                }
            )
            if is_causal:
                outputs["attention_mask_type"] = AttnMaskType.PADDED_CAUSAL
                if memory_size < MEMORY_BOUND:
                    if s_q != 1:
                        attention_mask = kv_padding_mask[:, None, :].expand(b, s_q, s_kv).to(dtype=dtype, device=device)
                        attention_mask = attention_mask * attention_mask.new_ones(s_q, s_kv).tril(diagonal=0)
                else:
                    attention_mask = torch.empty((0,), dtype=dtype, device=device)
            else:
                outputs["attention_mask_type"] = AttnMaskType.PADDED
                if memory_size < MEMORY_BOUND:
                    attention_mask = kv_padding_mask[:, None, :].expand(b, s_q, s_kv).to(dtype=dtype, device=device)

        if invert:
            attention_mask = invert_mask(attention_mask).unsqueeze(1)
        outputs["attention_mask"] = attention_mask
        return outputs

    @staticmethod
    def attention(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        attention_mask_type: AttnMaskType = AttnMaskType.CUSTOM,
        cu_seqlens_q: Optional[torch.Tensor] = None,
        cu_seqlens_kv: Optional[torch.Tensor] = None,
        max_seqlen_q: Optional[int] = None,
        max_seqlen_kv: Optional[int] = None,
        q_indices: Optional[torch.Tensor] = None,
        kv_indices: Optional[torch.Tensor] = None,
        dropout_p: float = 0.0,
        scale: Optional[float] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Flash Attention function. It supports 4 mask types.
        1. custom mask: recv attention_mask
        2. padded mask: recv attention_mask, attention_mask_type, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, q_indices, kv_indices
        3. causal mask: recv attention_mask, attention_mask_type
        4. padded causal mask: recv attention_mask, attention_mask_type, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, q_indices, kv_indices

        Args:
            q (torch.Tensor): Query tensor. Shape should be [B, nHeads, Sq, D]
            k (torch.Tensor): Key tensor. Shape should be [B, nHeads, Skv, D]
            v (torch.Tensor): Value tensor. Shape should be [B, nHeads, Skv, D]
            attention_mask (Optional[torch.Tensor], optional): Attention mask tensor. Shape should be [B, 1, Sq, Skv]. Defaults to None.
            attention_mask_type (AttnMaskType, optional): Attention mask type. Defaults to AttnMaskType.CUSTOM.
            cu_seqlens_q (Optional[torch.Tensor], optional): The cumulative sequence lengths
                of the sequences in the batch, used to index into q.
                Shape should be [B+1]. Defaults to None.
            cu_seqlens_kv (Optional[torch.Tensor], optional): The cumulative sequence lengths
                of the sequences in the batch, used to index into kv.
                Shape should be [B+1]. Defaults to None.
            max_seqlen_q (Optional[int], optional): Maximum query sequence length in the batch. Defaults to None.
            max_seqlen_kv (Optional[int], optional): Maximum key/value sequence length in the batch. Defaults to None.
            q_indices (Optional[torch.Tensor], optional): The indices of non-masked query tokens from the flattened input sequence.
                Shape should be [NUM_TOKENS]. Defaults to None.
            kv_indices (Optional[torch.Tensor], optional): The indices of non-masked key/value tokens from the flattened input sequence.
                Shape should be [NUM_TOKENS]. Defaults to None.
            dropout_p (float, optional): Dropout probability. Defaults to 0.0.
            scale (Optional[float], optional): Scaling factor applied prior to softmax. Defaults to None.

        Returns:
            torch.Tensor: Output tensor. Shape should be [B, nHeads, Sq, D]
        """
        # known issue: sdpa does not support an attention mask that contains a whole row of masked tokens, which leads to nan.
        # This case is usual when a padding mask is used and self attention is performed,
        # thus we don't use sdpa when a padding mask is used.
        # sanity check
        if attention_mask is not None:
            assert torch.is_floating_point(attention_mask), "attention_mask should be a floating point tensor."
            if attention_mask_type in (AttnMaskType.CUSTOM, AttnMaskType.CAUSAL):
                assert (
                    cu_seqlens_q is None
                    and cu_seqlens_kv is None
                    and max_seqlen_q is None
                    and max_seqlen_kv is None
                    and q_indices is None
                    and kv_indices is None
                )
                if attention_mask_type == AttnMaskType.CUSTOM:
                    assert not torch.all(attention_mask != 0, dim=-1).any()
            elif attention_mask_type in (
                AttnMaskType.PADDED,
                AttnMaskType.PADDED_CAUSAL,
            ):
                assert (
                    cu_seqlens_q is not None
                    and cu_seqlens_kv is not None
                    and max_seqlen_q is not None
                    and max_seqlen_kv is not None
                    and q_indices is not None
                    and kv_indices is not None
                )
        else:
            # if attention_mask is None, attention_mask_type should be the default value
            assert attention_mask_type == AttnMaskType.CUSTOM

        # kernel dispatch
        b, _, s_q, _ = q.shape
        b, _, s_kv, _ = v.shape
        element_size = torch.tensor([], dtype=q.dtype).element_size()
        memory_size = s_q * s_kv * element_size
        mask_type = attention_mask_type if attention_mask is not None else None
        attn_func = ColoAttention._dispatch_kernel(q.dtype, mask_type, memory_size)
        is_causal = attention_mask is not None and attention_mask_type in (
            AttnMaskType.CAUSAL,
            AttnMaskType.PADDED_CAUSAL,
        )
        return attn_func(
            q,
            k,
            v,
            dropout_p=dropout_p,
            scale=scale,
            attention_mask=attention_mask,
            is_causal=is_causal,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_kv=cu_seqlens_kv,
            max_seqlen_q=max_seqlen_q,
            max_seqlen_kv=max_seqlen_kv,
            q_indices=q_indices,
            kv_indices=kv_indices,
        )

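
# Hypothetical usage sketch (not part of the upstream API; the name is illustrative): building
# mask kwargs with prepare_attn_kwargs() and running ColoAttention.attention() for plain causal
# self-attention. Assumes a CUDA device and that one of the registered flash-attention kernels
# can be loaded for fp16.
def _example_colo_attention():
    b, h, s, d = 2, 8, 128, 64
    q, k, v = [torch.randn(b, h, s, d, dtype=torch.float16, device="cuda") for _ in range(3)]
    attn_kwargs = ColoAttention.prepare_attn_kwargs(
        (b, 1, s, s), dtype=q.dtype, device=q.device, is_causal=True
    )
    out = ColoAttention.attention(q, k, v, **attn_kwargs)  # (B, nHeads, Sq, D)
    return out
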
def _load_varlen_helpers():
    """Helper to load functions for padding and unpadding packed sequences.
    Use only when flash attn is installed.
    """
    global _pad_input, _unpad_input
    # Flash attn claims this is more efficient than torch's bool indexing due to avoiding
    # broadcast
    if _pad_input is None or _unpad_input is None:
        try:
            from flash_attn.bert_padding import index_first_axis, pad_input

            def unpad_input(hidden_states: torch.Tensor, indices: torch.Tensor):
                return index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices)

            _pad_input = pad_input
            _unpad_input = unpad_input
        except ImportError as e:
            raise RuntimeError(
                "Flash Attention is not installed. You can install it via 'pip install flash-attn --no-build-isolation'"
            ) from e


def _load_flash_attn():
    """A light-weight loader to check whether flash-attn is installed.
    Can't use ColoAttention._dispatch_kernel because we mutate the backward pass.
    """
    global _flash_attn_forward, _flash_attn_backward
    if _flash_attn_forward is None or _flash_attn_backward is None:
        try:
            from flash_attn.flash_attn_interface import _flash_attn_varlen_backward as _flash_attn_backward
            from flash_attn.flash_attn_interface import _flash_attn_varlen_forward as _flash_attn_forward
        except ImportError as e:
            raise RuntimeError(
                "Flash Attention is not installed. You can install it via 'pip install flash-attn --no-build-isolation'"
            ) from e

    _load_varlen_helpers()

# NOTE: This can cause spawned processes to hang on exit
# with python 3.9
@torch.compile()
def _rescale_out_lse(out, block_out, lse, block_lse):
    """
    Compute the new attention denominator:
        exp(lse) + exp(block_lse) = exp(max_scale) * (exp(min_scale - max_scale) + 1)
    Args:
        out: (T, H, D)
        block_out: (T, H, D)
        lse: (T, H, 1)
        block_lse: (T, H, 1)
    """

    # min_scale = torch.min(lse, block_lse)
    # max_scale = torch.max(lse, block_lse)
    # new_lse = max_scale + torch.log(1 + torch.exp(min_scale - max_scale))

    # NOTE: directly assigning to .data here is buggy
    # probably due to casting dtypes/strides
    new_lse = lse + torch.log(1 + torch.exp(block_lse - lse))

    new_block_lse = torch.exp(block_lse - new_lse)
    out = (torch.exp(lse - new_lse) * out + new_block_lse * block_out).to(out)
    lse = new_lse

    # Equivalent to the above
    # See https://github.com/zhuzilin/ring-flash-attention/pull/34#issuecomment-2076126795
    # out = (out - F.sigmoid(block_lse - lse) * (out - block_out))
    # lse = (lse - F.logsigmoid(lse - block_lse))
    return out, lse

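
# Illustrative sanity check (not part of the upstream API; names are hypothetical): merging two
# attention blocks with _rescale_out_lse() should reproduce attention computed over the full KV
# at once, since the log-sum-exp terms are the per-block softmax normalizers. The first call
# triggers torch.compile compilation of _rescale_out_lse.
def _example_rescale_out_lse():
    torch.manual_seed(0)
    t, h, d = 2, 1, 4  # tokens, heads, head dim
    q = torch.randn(t, h, d)
    k1, v1 = torch.randn(t, h, d), torch.randn(t, h, d)
    k2, v2 = torch.randn(t, h, d), torch.randn(t, h, d)

    def block_attn(k, v):
        # Per-block output and log-sum-exp, shapes (T, H, D) and (T, H, 1); every query
        # attends to all keys of the block (no causal masking, no scaling, for simplicity).
        scores = torch.einsum("thd,shd->ths", q, k)
        lse = scores.logsumexp(dim=-1, keepdim=True)
        out = torch.einsum("ths,shd->thd", scores.softmax(dim=-1), v)
        return out, lse

    out1, lse1 = block_attn(k1, v1)
    out2, lse2 = block_attn(k2, v2)
    merged_out, _ = _rescale_out_lse(out1, out2, lse1, lse2)

    full_out, _ = block_attn(torch.cat([k1, k2]), torch.cat([v1, v2]))
    assert torch.allclose(merged_out, full_out, atol=1e-5)
    return merged_out
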
class RingAttention(torch.autograd.Function):
    """Implements the Ring Attention from `Ring Attention with Blockwise Transformers for Near-Infinite Context`
    (https://arxiv.org/abs/2310.01889).
    For load-balancing we adopted the "zigzag" attention scheme from https://github.com/zhuzilin/ring-flash-attention/tree/main
    For portable integration with more models, we don't follow the spirit of "block-wise FFN" in the original paper,
    which requires fusing FFN with the Flash Attention kernel/function (see https://arxiv.org/pdf/2305.19370;
    implemented in Jax and not optimized).
    We adopt the double ring topology from LoongTrain (https://arxiv.org/pdf/2406.18485) to fully utilize available
    NICs on each node, by computing attention within an inner ring first and then sending all KVs to the next
    ring at once.
    """

    # Global cache to avoid recomputation for same-lengthed sequences
    CU_SEQLENS: torch.Tensor = None  # [B+1]
    TOTAL_SEQLEN: int = None
    HALF_INDICES: Tuple = None
    SUPPORTED_MASK_TYPES = (AttnMaskType.CAUSAL, AttnMaskType.PADDED_CAUSAL)
    ATTN_DONE: torch.cuda.Event = None
    SP_STREAM: torch.cuda.Stream = None
    SP_GROUP: dist.ProcessGroup = None

    # NOTE: Duplicating PGs for concurrent NCCL streams is a risky hack -- while it may increase throughput,
    # both PyTorch and NCCL warn against this. (https://github.com/pytorch/pytorch/commit/2dbe5cb979f674f0052a8eea1f7b6c3c0ba441d7)
    # LoongTrain's original double ring impl. uses concurrent PGs
    # (https://github.com/InternLM/InternEvo/blob/e52f2ffc9acf818e8f2b1f97dfc69ceb2f06e154/internlm/model/ops/ring_flash_attn/zigzag_ring_flash_attn_with_sliding_window.py#L192)
    # but I confirmed with Pytorch developers this can cause obscure "Software caused connection abort" errors.
    # (https://github.com/pytorch/pytorch/issues/132852)
    # NOTE: In general, a smarter idea is to put as many P2P calls as possible into one `batch_isend_irecv`.
    INNER_RING_GROUP: dist.ProcessGroup = None
    # INNER_RING_GROUP_COPY: dist.ProcessGroup = None
    INTER_RING_GROUP: dist.ProcessGroup = None
    # INTER_RING_GROUP_COPY: dist.ProcessGroup = None

    @staticmethod
    def get_double_ring_groups(sp_axis, pg_mesh, inner_ring_size=None):
        """
        Get 2D ring groups for the given process group. Generally, to avoid congestion, the inner ring size
        shouldn't be larger than the number of NICs on each node.
        Args:
            sp_axis (int): Sequence parallel axis of the global pg mesh.
            pg_mesh: The global process group mesh, providing ``get_group_along_axis``.
            inner_ring_size (Optional[int], optional): Inner ring size. Defaults to None.
        Returns:
            Tuple[dist.ProcessGroup, dist.ProcessGroup]: Inner-ring process group and inter-ring process group.
        """
        assert pg_mesh is not None, "Error: The pg mesh is None! Please check the process group initialization."

        sp_group = pg_mesh.get_group_along_axis(sp_axis)
        sp_size = dist.get_world_size(sp_group)
        sp_rank = dist.get_rank(sp_group)

        assert inner_ring_size is not None

        assert (
            inner_ring_size <= sp_size and sp_size % inner_ring_size == 0
        ), f"Error: sp_size {sp_size} should be divisible by inner_ring_size {inner_ring_size}"

        if inner_ring_size == sp_size:
            return sp_group, sp_group
        assert (
            sp_size % inner_ring_size == 0
        ), f"sp_size {sp_size} should be divisible by inner_ring_size {inner_ring_size}"
        logger = get_dist_logger()
        logger.info(
            f"Using 2D Ring Attention with inner ring size {inner_ring_size} to maximize NIC util for inter-node comm. Cross your fingers for speed-ups!",
            ranks=[0],
        )
        num_rings = sp_size // inner_ring_size
        inner_ring_group = None
        inter_ring_group = None

        # Create inner ring groups
        for i in range(inner_ring_size):
            ranks = list(range(i * inner_ring_size, (i + 1) * inner_ring_size))
            group = pg_mesh.get_group_along_axis(sp_axis, ranks)
            if sp_rank in ranks:
                inner_ring_group = group

        # Create inter ring groups
        for i in range(num_rings):
            ranks = list(range(i, sp_size, num_rings))
            group = pg_mesh.get_group_along_axis(sp_axis, ranks)
            if sp_rank in ranks:
                inter_ring_group = group

        return inner_ring_group, inter_ring_group

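    # Worked example of the grouping above (illustrative): with sp_size = 16 and
    # inner_ring_size = 4 (so num_rings = 4), the loops build inner rings {0..3}, {4..7},
    # {8..11}, {12..15} and inter rings {0,4,8,12}, {1,5,9,13}, {2,6,10,14}, {3,7,11,15}:
    # each rank first circulates KV inside its node-local inner ring, then exchanges KV with
    # the ranks holding the same position in the other inner rings.
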
    @staticmethod
    def attention(
        q,  # (B, H, Sq, D)
        k,
        v,
        sp_axis,
        attention_mask_type,
        cu_seqlens=None,
        max_seqlen=None,
        valid_indices=None,
        dropout_p=0.0,
        softmax_scale=None,
        deterministic=False,
        return_softmax=False,
        inner_ring_size=None,
        pg_mesh=None,
        **kwargs,
    ):
        """
        Ring Attention forward pass supporting variable-length sequences. When using varlen mode,
        each sequence in the batch should have a length divisible by sp_size * 2.

        Args:
            q (torch.Tensor): Query tensor. Shape should be [B, nHeads, Sq, D]
            k (torch.Tensor): Key tensor. Shape should be [B, nHeads, Sq, D]
            v (torch.Tensor): Value tensor. Shape should be [B, nHeads, Sq, D]
            sp_axis (Optional[int]): Sequence parallel axis of the global pg mesh.
            attention_mask_type (AttnMaskType): Only AttnMaskType.CAUSAL and AttnMaskType.PADDED_CAUSAL are supported.
            cu_seqlens (Optional[torch.Tensor], optional): The cumulative sequence lengths
                of the sequences in the batch, used to index into q.
                Shape should be [B+1].
            max_seqlen (Optional[int], optional): Maximum query sequence length in the batch.
            valid_indices (Optional[torch.Tensor], optional): The indices of non-masked tokens from get_pad_info.
                Shape should be [t].
            dropout_p (float, optional): Dropout probability. Defaults to 0.0.
            softmax_scale (Optional[float], optional): Scaling factor applied prior to softmax.
            deterministic (bool, optional): Whether to force deterministic backward pass. See https://github.com/Dao-AILab/flash-attention/issues/349
            return_softmax (bool, optional): Whether to return the softmax denominator (logsumexp).
            inner_ring_size (Optional[int], optional): Inner ring size of the 2D ring. By default use a heuristic to decide.
            pg_mesh: The global process group mesh, from which the SP group along ``sp_axis`` is taken.

        Returns:
            out: Output tensor of shape [B, nHeads, Sq, D] or [T, nHeads, D] if pad_output is False.
            softmax_lse: (if return_softmax is True) Softmax denominator (logsumexp).
                Shape should be [total_q_seqlen, nHeads]
        """
        # Check input args
        _load_flash_attn()
        if RingAttention.ATTN_DONE is None:
            RingAttention.ATTN_DONE = torch.cuda.Event()
        if RingAttention.SP_STREAM is None:
            RingAttention.SP_STREAM = torch.cuda.Stream()
        assert (
            q.shape[2] == k.shape[2]
        ), "Q, K and V having different sequence lengths (inference or cross-attn)\
            is not supported yet in training."
        assert (
            attention_mask_type in RingAttention.SUPPORTED_MASK_TYPES
        ), f"Mask type {attention_mask_type} is not supported yet."

        assert pg_mesh is not None, "Error: The pg mesh is None! Please check the process group initialization."

        clone_pg = lambda pg: dist.new_group(dist.get_process_group_ranks(pg))
        sp_group = pg_mesh.get_group_along_axis(sp_axis)
        if inner_ring_size is not None:
            RingAttention.SP_GROUP = sp_group
            inner_ring_group, inter_ring_group = RingAttention.get_double_ring_groups(sp_axis, pg_mesh, inner_ring_size)
            RingAttention.INNER_RING_GROUP = inner_ring_group
            RingAttention.INTER_RING_GROUP = inter_ring_group
        else:
            inner_ring_group = RingAttention.INNER_RING_GROUP
            inter_ring_group = RingAttention.INTER_RING_GROUP

        # (B, H, Sq, D) -> (B, Sq, H, D)
        q, k, v = [x.transpose(1, 2).contiguous() for x in (q, k, v)]
        pad_output = q.dim() == 4

        # Get sequence length info for varlen forward
        if attention_mask_type == AttnMaskType.CAUSAL:
            # All sequences share the same length
            b, sq, h, d = q.shape
            max_seqlen = sq
            # Cache to avoid recreation for a single sequence
            if sq * b == RingAttention.TOTAL_SEQLEN:
                cu_seqlens = RingAttention.CU_SEQLENS
            else:
                cu_seqlens = torch.arange(0, b * sq + 1, sq, device=q.device, dtype=torch.int32)
                RingAttention.TOTAL_SEQLEN = b * sq

        # "Packed" mode where sequences of different lengths are packed into [total_q_seqlen, H, D]
        elif attention_mask_type == AttnMaskType.PADDED_CAUSAL:
            assert (
                cu_seqlens is not None and max_seqlen is not None and valid_indices is not None
            ), "Packed mode requires pre-computed cu_seqlens and max_seq_len."
            if pad_output:
                b, sq, h, d = q.shape
                q, k, v = [_unpad_input(x, valid_indices) for x in (q, k, v)]

        out, softmax_lse = RingAttention.apply(
            q,
            k,
            v,
            sp_group,
            RingAttention.SP_STREAM,
            cu_seqlens,
            max_seqlen,
            dropout_p,
            softmax_scale,
            deterministic,
            return_softmax,
            attention_mask_type == AttnMaskType.PADDED_CAUSAL,
            inner_ring_group,
            inter_ring_group,
        )

        if attention_mask_type == AttnMaskType.PADDED_CAUSAL:
            if pad_output:
                out = _pad_input(out, valid_indices, b, sq)  # (T, ...) -> (B, Sq, ...)
                out = out.transpose(1, 2)  # (B, Sq, H, D) -> (B, H, Sq, D)
        else:
            out = out.transpose(1, 2)

        if return_softmax:
            return out, softmax_lse
        return out

    @staticmethod
    def forward(
        ctx,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        sp_group: dist.ProcessGroup,
        sp_stream: torch.cuda.Stream,
        cu_seqlens: torch.Tensor,
        max_seqlen: int,
        dropout_p: float = 0.0,
        softmax_scale: Optional[float] = None,
        deterministic: Optional[bool] = False,
        return_softmax: Optional[bool] = False,
        is_packed: Optional[bool] = False,
        inner_ring_group: Optional[dist.ProcessGroup] = None,
        inter_ring_group: Optional[dist.ProcessGroup] = None,
    ):
        """
        Forward supporting both packed (varlen) and batched (fixed length, no padding) sequences.
        No separate version for batched seq (hard to maintain), which incurs
        some overhead in sequence splitting due to python for loops.
        Uses two CUDA streams to overlap softmax denominator correction with next flash attn
        (see comments below).
        """
        cu_seqlens_q = cu_seqlens_kv = cu_seqlens
        max_seqlen_q = max_seqlen_kv = max_seqlen
        cu_seqlens_half = cu_seqlens // 2
        max_seqlen_half = max_seqlen // 2
        misc_kwargs = {
            "alibi_slopes": None,
            "softmax_scale": q.shape[-1] ** -0.5 if softmax_scale is None else softmax_scale,
            "dropout_p": dropout_p,
            "block_table": None,
            "softcap": 0.0,
            "return_softmax": False,
        }
        import flash_attn

        if version.parse(flash_attn.__version__) > version.parse("2.6.3"):
            misc_kwargs["window_size_left"] = -1
            misc_kwargs["window_size_right"] = -1
        else:
            misc_kwargs["window_size"] = (-1, -1)

        if (
            RingAttention.HALF_INDICES is not None
            and cu_seqlens.shape == RingAttention.CU_SEQLENS.shape
            and (cu_seqlens == RingAttention.CU_SEQLENS).all()
        ):
            half_idx_front, half_idx_back = RingAttention.HALF_INDICES
        else:
            half_idx_front = get_half_index(cu_seqlens, front=True)
            half_idx_back = get_half_index(cu_seqlens, front=False)
            RingAttention.HALF_INDICES = (half_idx_front, half_idx_back)
            RingAttention.CU_SEQLENS = cu_seqlens

        if is_packed:
            t, h, d = q.shape
        else:
            b, sq, h, d = q.shape
            t = b * sq
            # Be careful about GQA/MQA in reshape
            q, k, v = [x.view(t, *x.shape[-2:]) for x in (q, k, v)]

        if inner_ring_group is None or inter_ring_group is None:
            # Use one ring if not specified
            inner_ring_group = inter_ring_group = sp_group

        sp_size = dist.get_world_size(sp_group)
        sp_rank = dist.get_rank(sp_group)

        # Create communicators corresponding to two CUDA streams
        local_kv_comms = [RingComm(inner_ring_group) for _ in range(2)]
        inter_ring_comm = RingComm(inter_ring_group)
        local_sp_size = dist.get_world_size(inner_ring_group)
        local_sp_rank = dist.get_rank(inner_ring_group)
        inter_ring_rank = dist.get_rank(inter_ring_group) if inter_ring_group is not sp_group else 0
        num_rings = dist.get_world_size(inter_ring_group) if inter_ring_group is not sp_group else 1

        # Any type of indexing (but not slicing) copies to a new contiguous tensor,
        # so only do it once
        if sp_rank != sp_size - 1:
            q1 = q[half_idx_back]

        # Pre-allocate double buffer for overlapping and receiving next step's inputs
        kv_buffers = [torch.stack((k, v))]  # (2, T, H, D)
        kv_buffers.append(torch.empty_like(kv_buffers[0]))

        # outputs
        out = None
        block_out = [None, None]
        softmax_lse = [None, None]
        block_softmax_lse = [None, None]  # log sum exp, the denominator of softmax in attention
        rng_states = [None for _ in range(sp_size)]
        sp_streams = [torch.cuda.current_stream(), sp_stream]

        # Helper to pass args to FA
        def _forward(q, k, v, causal):
            if version.parse(flash_attn.__version__) > version.parse("2.6.3"):
                (out, softmax_lse, S_dmask, rng_state) = _flash_attn_forward(
                    q,
                    k,
                    v,
                    cu_seqlens_q if q.shape[0] == t else cu_seqlens_half,
                    cu_seqlens_kv if k.shape[0] == t else cu_seqlens_half,
                    max_seqlen_q if q.shape[0] == t else max_seqlen_half,
                    max_seqlen_kv if k.shape[0] == t else max_seqlen_half,
                    causal=causal,
                    **misc_kwargs,
                )
            else:
                (
                    _,
                    _,
                    _,
                    _,
                    out,
                    softmax_lse,
                    _,
                    rng_state,
                ) = _flash_attn_forward(
                    q,
                    k,
                    v,
                    cu_seqlens_q if q.shape[0] == t else cu_seqlens_half,
                    cu_seqlens_kv if k.shape[0] == t else cu_seqlens_half,
                    max_seqlen_q if q.shape[0] == t else max_seqlen_half,
                    max_seqlen_kv if k.shape[0] == t else max_seqlen_half,
                    causal=causal,
                    **misc_kwargs,
                )
            return out, softmax_lse, rng_state

        def _kv_comm(i):
            # Avoid overwriting attn input when it shares mem with buffer
            if not RingAttention.ATTN_DONE.query():
                kv_buffers[(i + 1) % 2] = torch.empty_like(kv_buffers[i % 2])
            if i < local_sp_size - 1:
                local_kv_comms[i % 2].send_recv(kv_buffers[i % 2], kv_buffers[(i + 1) % 2])

        # Forward within a node
        def _local_ring_forward():
            # (Hopefully) overlap output correction with next flash attn
            for i in range(local_sp_size):
                with torch.cuda.stream(sp_streams[i % 2]):
                    # Wait for current kv from prev rank
                    # NOTE: waiting outside the current stream will NOT correctly synchronize.
                    if i > 0:
                        local_kv_comms[(i + 1) % 2].wait()

                    # Prefetch
                    if i == 0:
                        _kv_comm(i)

                    if i == 0:
                        # Compute with local KV; no mask
                        kv_block = kv_buffers[0]
                        q_block = q
                        (block_out[i % 2], block_softmax_lse[i % 2], rng_states[i]) = _forward(  # (T, H, D)  # (H, T)
                            q_block, kv_block[0], kv_block[1], causal=True
                        )
                    elif i <= local_sp_rank:
                        # Received the "surrounding" kv chunks
                        # Drop the second half of received kv
                        # (2, t // 2, H, D)
                        kv_block = kv_buffers[i % 2][:, half_idx_front]
                        q_block = q
                        (
                            block_out[i % 2],  # (T, H, D)
                            block_softmax_lse[i % 2],  # (H, T)
                            rng_states[i],
                        ) = _forward(q_block, kv_block[0], kv_block[1], causal=False)
                    else:
                        # Received the inner kv chunks
                        # Drop the first half of q
                        kv_block = kv_buffers[i % 2]
                        q_block = q1
                        (
                            block_out[i % 2],  # (T, H, D)
                            block_softmax_lse[i % 2],  # (H, T)
                            rng_states[i],
                        ) = _forward(q_block, kv_block[0], kv_block[1], causal=False)
                    RingAttention.ATTN_DONE.record()
                    # Pipeline the next KV comm with output correction instead of the next flash attn
                    # kernel, to minimize bubble when comm takes longer than attn.
                    _kv_comm(i + 1)

                    block_softmax_lse[i % 2] = (
                        block_softmax_lse[i % 2].transpose(0, 1).unsqueeze(-1).contiguous().float()
                    )  # (H, T) -> (T, H, 1)
                    assert block_out[i % 2].shape[:-1] == block_softmax_lse[i % 2].shape[:-1]

                    # Output and log sum exp correction.
                    # Ideally overlap this with the next flash attn kernel,
                    # since attn uses Tensor Core and rescale is element-wise, memory-bound and uses CUDA cores.
                    # (NOTE that this is the same as ping-pong scheduling idea in FA3)
                    # TODO However sometimes while the GPU has scheduled the next kernel,
                    # it's reluctant to launch it in overlap. Some potential causes:
                    # 1. need lower-level CUDA scheduling 2. further benchmark against Megatron-LM
                    # 3. register spilling by FA kernel.
                    if i == 0:
                        out = block_out[0]
                        softmax_lse = block_softmax_lse[0]
                    elif i <= local_sp_rank:
                        out, softmax_lse = _rescale_out_lse(
                            out, block_out[i % 2], softmax_lse, block_softmax_lse[i % 2]
                        )
                    else:
                        out[half_idx_back], softmax_lse[half_idx_back] = _rescale_out_lse(
                            out[half_idx_back], block_out[i % 2], softmax_lse[half_idx_back], block_softmax_lse[i % 2]
                        )

            torch.cuda.current_stream().wait_stream(sp_stream)
            return out, softmax_lse

        # Forward for inter-node (the outer ring in 2D ring)
        def _other_ring_forward(ring_num_idx, out, softmax_lse):
            # Loop through the inner ring after receiving
            # all new KVs from another ring
            for i in range(local_sp_size):
                with torch.cuda.stream(sp_streams[i % 2]):
                    # Send & recv KV
                    if i > 0:
                        local_kv_comms[(i + 1) % 2].wait()

                    # Prefetch
                    if i == 0:
                        _kv_comm(i)

                    if ring_num_idx > inter_ring_rank:
                        kv_block = kv_buffers[i % 2]
                        (
                            block_out[i % 2],
                            block_softmax_lse[i % 2],
                            rng_states[i + local_sp_size * ring_num_idx],
                        ) = _forward(q1, kv_block[0], kv_block[1], causal=False)
                        RingAttention.ATTN_DONE.record()

                        _kv_comm(i + 1)
                        block_softmax_lse[i % 2] = (
                            block_softmax_lse[i % 2].transpose(0, 1).unsqueeze(-1).contiguous().float()
                        )
                        out[half_idx_back], softmax_lse[half_idx_back] = _rescale_out_lse(
                            out[half_idx_back], block_out[i % 2], softmax_lse[half_idx_back], block_softmax_lse[i % 2]
                        )
                    else:
                        kv_block = kv_buffers[i % 2][:, half_idx_front]
                        (
                            block_out[i % 2],
                            block_softmax_lse[i % 2],
                            rng_states[i + local_sp_size * ring_num_idx],
                        ) = _forward(q, kv_block[0], kv_block[1], causal=False)
                        RingAttention.ATTN_DONE.record()

                        _kv_comm(i + 1)
                        block_softmax_lse[i % 2] = (
                            block_softmax_lse[i % 2].transpose(0, 1).unsqueeze(-1).contiguous().float()
                        )
                        out, softmax_lse = _rescale_out_lse(
                            out, block_out[i % 2], softmax_lse, block_softmax_lse[i % 2]
                        )

            torch.cuda.current_stream().wait_stream(sp_stream)
            return out, softmax_lse

        # Send and recv KV between rings at once to maximize NIC util.
        inter_ring_kv = None
        for ring_num_idx in range(num_rings):
            if ring_num_idx > 0:
                inter_ring_comm.wait()
                # Reset indices
                kv_buffers[0] = inter_ring_kv

            if ring_num_idx < num_rings - 1:
                if ring_num_idx == 0:
                    to_send = kv_buffers[0]
                else:
                    # The last received KV
                    to_send = kv_buffers[(local_sp_size - 1) % 2]
                inter_ring_kv = inter_ring_comm.send_recv(to_send)

            if ring_num_idx == 0:
                out, softmax_lse = _local_ring_forward()
            else:
                out, softmax_lse = _other_ring_forward(ring_num_idx, out, softmax_lse)

        out = out.to(q.dtype)
        if not is_packed:
            out = out.view(b, sq, h, d)
            q, k, v = [x.view(b, sq, *x.shape[-2:]) for x in (q, k, v)]  # (T, H, D) -> (B, Sq, H, D)
        softmax_lse = softmax_lse.squeeze(-1)

        ctx.sp_group = sp_group
        ctx.max_seqlen_q = ctx.max_seqlen_kv = max_seqlen
        misc_kwargs["deterministic"] = deterministic
        del misc_kwargs["return_softmax"]
        ctx.misc_kwargs = misc_kwargs
        ctx.is_packed = is_packed

        ctx.kv_group = inner_ring_group
        ctx.inter_kv_group = inter_ring_group

        ctx.save_for_backward(
            q,
            k,
            v,
            out,
            softmax_lse.transpose(0, 1).contiguous(),  # (T, H) -> (H, T)
            cu_seqlens_q,
            cu_seqlens_kv,
            half_idx_front,
            half_idx_back,
            *rng_states,
        )

        if return_softmax:
            return out, softmax_lse
        return out, None

    @staticmethod
    def backward(ctx, dout, _):
        """
        During backward, we accumulate q grads on each rank locally, but iterate kv and their grads
        over all ranks for accumulation. We avoid using two streams due to backward using doubled
        buffers and more comm cost.
        """
        (q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_kv, half_idx_front, half_idx_back) = ctx.saved_tensors[:9]
        rng_states = ctx.saved_tensors[9:]

        is_packed = ctx.is_packed
        max_seqlen_q = ctx.max_seqlen_q
        max_seqlen_kv = ctx.max_seqlen_kv
        cu_seqlens_half = cu_seqlens_q // 2
        max_seqlen_half = max_seqlen_q // 2
        misc_kwargs = ctx.misc_kwargs
        del misc_kwargs["block_table"]

        assert (
            out.shape == dout.shape == q.shape
        ), f"out {out.shape} and dout {dout.shape} should have the same shape ({q.shape})."

        if is_packed:
            t, h, d = q.shape
        else:
            b, sq, h, d = q.shape
            t = b * sq
        q, k, v, out, dout = [x.view(t, *x.shape[-2:]) for x in (q, k, v, out, dout)]

        # Sequence parallel args
        sp_group = ctx.sp_group
        local_kv_group = ctx.kv_group
        inter_kv_group = ctx.inter_kv_group

        local_sp_rank = dist.get_rank(sp_group)
        sp_size = dist.get_world_size(sp_group)

        # NOTE: Using separate streams (PG) for concurrent kv and dkv comm may
        # cause NCCL "software caused connection abort" here...
        local_kv_comm = RingComm(local_kv_group)
        local_dkv_comm = RingComm(local_kv_group)
        inter_kv_comm = RingComm(inter_kv_group)
        inter_dkv_comm = RingComm(inter_kv_group)
        local_sp_size = dist.get_world_size(local_kv_group)
        local_sp_rank = dist.get_rank(local_kv_group)

        if dist.get_world_size(inter_kv_group) != sp_size:
            num_rings = dist.get_world_size(inter_kv_group)
            inter_ring_rank = dist.get_rank(inter_kv_group)
        else:
            num_rings = 1
            inter_ring_rank = 0

        if local_sp_rank != sp_size - 1:
            softmax_lse1 = softmax_lse[:, half_idx_back]
        dout = dout.contiguous()

        # Double comm buffers for sending and receiving kv
        kv_buffers = [torch.stack((k, v))]  # (2, T, H, D)
        kv_buffers.append(torch.empty_like(kv_buffers[0]))

        dq = None  # (T, H, D)
        # Intermediate outputs
        dq_block = torch.empty_like(q)  # (T, H, D)
        dk_block = torch.empty_like(k)  # (T, H, D)
        dv_block = torch.empty_like(v)  # (T, H, D)
        dkv_buffers = [torch.empty_like(kv, dtype=torch.float32) for kv in kv_buffers]  # (T, H, D)
        del k, v

        # Helper to pass args to FA
        def _backward(dout, q, k, v, out, softmax_lse, dq, dk, dv, rng_state, causal):
            _flash_attn_backward(
                dout,
                q,
                k,
                v,
                out,
                softmax_lse,
                dq,
                dk,
                dv,
                cu_seqlens_q if dq.shape[0] == t else cu_seqlens_half,
                cu_seqlens_kv if dk.shape[0] == t else cu_seqlens_half,
                max_seqlen_q if dq.shape[0] == t else max_seqlen_half,
                max_seqlen_kv if dk.shape[0] == t else max_seqlen_half,
                causal=causal,
                rng_state=rng_state,
                **misc_kwargs,
            )

        # Backward within a node
        def _local_ring_backward():
            for i in range(local_sp_size):
                if i > 0:
                    local_kv_comm.wait()

                if i < local_sp_size - 1:
                    # Send kv to next rank for backward
                    local_kv_comm.send_recv(kv_buffers[i % 2], kv_buffers[(i + 1) % 2])

                if i == 0:
                    # Backward with local kv
                    k_, v_ = kv_buffers[i % 2]
                    q_, dout_, out_ = q, dout, out
                    dq_, dk_, dv_ = dq_block, dk_block, dv_block
                    _backward(dout_, q_, k_, v_, out_, softmax_lse, dq_, dk_, dv_, rng_states[i], causal=True)

                elif i <= local_sp_rank:
                    # Drop the second half of kv
                    # (T, H, D) -> (T // 2, H, D)
                    k_, v_ = [x[half_idx_front] for x in kv_buffers[i % 2]]
                    dk_, dv_ = [x[: t // 2] for x in (dk_block, dv_block)]
                    dq_, q_, out_, dout_ = (dq_block, q, out, dout)
                    _backward(dout_, q_, k_, v_, out_, softmax_lse, dq_, dk_, dv_, rng_states[i], causal=False)

                else:
                    # Drop the first half of q
                    k_, v_ = kv_buffers[i % 2]
                    dk_, dv_ = dk_block, dv_block
                    q_, out_, dout_ = [x[half_idx_back] for x in (q, out, dout)]
                    dq_ = dq_block[: t // 2]
                    _backward(dout_, q_, k_, v_, out_, softmax_lse1, dq_, dk_, dv_, rng_states[i], causal=False)

                # Accumulate grads
                if i == 0:
                    dq = dq_block.float()
                    dkv_buffers[i % 2][0] = dk_block.float()
                    dkv_buffers[i % 2][1] = dv_block.float()
                else:
                    # Accumulate local dq
                    if i <= local_sp_rank:
                        dq += dq_  # (T, H, D)
                    else:
                        dq[half_idx_back] += dq_

                    # Wait for mobile kv grad accumulators
                    local_dkv_comm.wait()

                    if i <= local_sp_rank:
                        # q blocks "surrounded" by kv blocks
                        dkv_buffers[i % 2][0][half_idx_front] += dk_
                        dkv_buffers[i % 2][1][half_idx_front] += dv_
                    else:
                        # q blocks "surrounding" kv blocks
                        dkv_buffers[i % 2][0] += dk_
                        dkv_buffers[i % 2][1] += dv_
                local_dkv_comm.send_recv(send_tensor=dkv_buffers[i % 2], recv_tensor=dkv_buffers[(i + 1) % 2])

            local_dkv_comm.wait()
            dkv_recv = dkv_buffers[local_sp_size % 2]
            dkv_send = dkv_buffers[(local_sp_size - 1) % 2]
            return dq, dkv_recv, dkv_send

        # Backward for inter-node (the outer ring in 2D ring)
        def _other_ring_backward(ring_num_idx, dq):
            if ring_num_idx > inter_ring_rank:
                # Indexing is expensive
                q_, out_, dout_ = [x[half_idx_back] for x in (q, out, dout)]
            else:
                q_, out_, dout_ = (q, out, dout)

            for i in range(local_sp_size):
                if i > 0:
                    local_kv_comm.wait()

                if i < local_sp_size - 1:
                    local_kv_comm.send_recv(kv_buffers[i % 2], kv_buffers[(i + 1) % 2])

                rng_state = rng_states[i + local_sp_size * ring_num_idx]
                if ring_num_idx > inter_ring_rank:
                    k_, v_ = kv_buffers[i % 2]
                    dk_, dv_ = dk_block, dv_block
                    dq_ = dq_block[: t // 2]
                    _backward(dout_, q_, k_, v_, out_, softmax_lse1, dq_, dk_, dv_, rng_state, causal=False)

                    dq[half_idx_back] += dq_
                    if i > 0:
                        local_dkv_comm.wait()
                    else:
                        inter_dkv_comm.wait()

                    dkv_buffers[i % 2][0] += dk_
                    dkv_buffers[i % 2][1] += dv_
                else:
                    k_, v_ = [x[half_idx_front] for x in kv_buffers[i % 2]]
                    dk_, dv_ = [x[: t // 2] for x in (dk_block, dv_block)]
                    dq_ = dq_block
                    _backward(dout_, q_, k_, v_, out_, softmax_lse, dq_, dk_, dv_, rng_state, causal=False)

                    dq += dq_
                    if i > 0:
                        local_dkv_comm.wait()
                    else:
                        inter_dkv_comm.wait()

                    dkv_buffers[i % 2][0][half_idx_front] += dk_
                    dkv_buffers[i % 2][1][half_idx_front] += dv_

                local_dkv_comm.send_recv(send_tensor=dkv_buffers[i % 2], recv_tensor=dkv_buffers[(i + 1) % 2])

            local_dkv_comm.wait()
            dkv_recv = dkv_buffers[local_sp_size % 2]
            dkv_send = dkv_buffers[(local_sp_size - 1) % 2]
            return dq, dkv_recv, dkv_send

        inter_ring_kv = None
        for ring_num_idx in range(num_rings):
            if ring_num_idx > 0:
                inter_kv_comm.wait()
                kv_buffers[0] = inter_ring_kv

            if ring_num_idx < num_rings - 1:
                # Re-allocate a buffer in each inter-ring step
                inter_ring_kv = inter_kv_comm.send_recv(kv_buffers[0])

            if ring_num_idx == 0:
                dq, dkv_recv, dkv_send = _local_ring_backward()
            else:
                dq, dkv_recv, dkv_send = _other_ring_backward(ring_num_idx, dq)

            if num_rings > 1:
                # Reuse the local buffers
                inter_dkv_comm.send_recv(send_tensor=dkv_recv, recv_tensor=dkv_send)
                # Reset indices
                dkv_buffers[0] = dkv_send
                dkv_buffers[1] = dkv_recv
                if ring_num_idx == num_rings - 1:
                    inter_dkv_comm.wait()
                    dkv_recv = dkv_buffers[0]

        dq, dk, dv = [x.to(q.dtype) for x in (dq, *dkv_recv)]
        if not is_packed:
            dq, dk, dv = [x.view(b, sq, *x.shape[-2:]) for x in (dq, dk, dv)]

        return (dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None)

    @staticmethod
    def prepare_varlen_batch(
        padding_mask: torch.Tensor,
        sp_group: dist.ProcessGroup,
        inputs_embeds: torch.Tensor = None,
        position_ids: Optional[torch.Tensor] = None,
        is_label: bool = False,
        is_batched_seq: bool = True,
    ):
        # TODO: support setting a batch dim (fix packing length) for packed mode, so that
        # DP can be used (needs to modify dataloader too)
        """
        Preprocess a batch of padded sequences by splitting each input sequence sp_size-ways
        along the sequence dimension and packing them into one sequence. Updates the mask info accordingly.
        Args:
            padding_mask (torch.Tensor): Contains the mask [B, Sq], where True means the token is NOT masked.
            sp_group (dist.ProcessGroup): Process group for sequence parallelism
            inputs_embeds (torch.Tensor): Input embeddings. Shape should be [B, Sq, ...]
            position_ids (Optional[torch.Tensor], optional): Position ids of shape [Sq] or [1, Sq]. Defaults to None.
            is_label (bool, optional): Whether inputs_embeds is instead a label tensor. If True, mask out the first
                token of each sequence.
            is_batched_seq (bool, optional): If True, then the input is a batch of (potentially padded) sequences
                of shape [B, Sq, ...]; else a packed sequence of shape [T, ...].

        Returns:
            inputs_embeds (torch.Tensor):
                Packed input embeddings of shape [B, Sq // sp_size, ...] if is_batched_seq, else [T, ...].
            mask_info (Dict[str, Any]):
                A dictionary containing mask info.
            position_ids (torch.Tensor):
                Packed position ids of shape [..., Sq // sp_size].
        """
        _load_varlen_helpers()
        sp_size = dist.get_world_size(group=sp_group)
        sp_rank = dist.get_rank(group=sp_group)
        mask_info = {}
        mask_info["max_seqlen"], mask_info["cu_seqlens"] = get_pad_info(padding_mask, return_indices=False)

        # Unpad, split seq-wise, then pad to (B, max_seqlen // sp_size)
        # (B, Sq) -> (B, max_seqlen // sp_size)
        padding_mask = padding_mask[:, : mask_info["max_seqlen"]]
        if inputs_embeds is not None:
            inputs_embeds = inputs_embeds[:, : mask_info["max_seqlen"]]
            inputs_embeds = split_varlen_zigzag(
                inputs_embeds,
                mask_info["cu_seqlens"],
                sp_group,
                mask_info["max_seqlen"],
                is_batched_seq=is_batched_seq,
                is_label=is_label,
            )
        # Split mask to get local nonzero seq positions
        padding_mask = split_varlen_zigzag(
            padding_mask, mask_info["cu_seqlens"], sp_group, mask_info["max_seqlen"], is_batched_seq=is_batched_seq
        )

        if position_ids is not None:
            indices = torch.tensor([sp_rank, 2 * sp_size - sp_rank - 1], device=inputs_embeds.device)
            position_ids = (
                position_ids[..., : mask_info["max_seqlen"]]  # unpad
                .view(-1, sp_size * 2, mask_info["max_seqlen"] // (sp_size * 2))
                .index_select(-2, indices)
                .view(-1, mask_info["max_seqlen"] // sp_size)
            )

        mask_info["max_seqlen"] //= sp_size
        mask_info["valid_indices"] = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten()
        mask_info["cu_seqlens"] //= sp_size
        mask_info["attention_mask_type"] = AttnMaskType.PADDED_CAUSAL
        return inputs_embeds, mask_info, position_ids
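

# Hypothetical usage sketch (not part of the upstream API; names here are illustrative).
# Shows the uniform-length causal path of RingAttention on one SP rank. It assumes
# torch.distributed is initialized, `pg_mesh` is the global process group mesh with a
# sequence-parallel axis `sp_axis`, each rank has a CUDA device, flash-attn is installed,
# and q/k/v hold this rank's zigzag-split shard of the full sequence.
def _example_ring_attention(pg_mesh, sp_axis: int):
    b, h, local_sq, d = 1, 8, 1024, 64
    q, k, v = [
        torch.randn(b, h, local_sq, d, dtype=torch.bfloat16, device="cuda", requires_grad=True)
        for _ in range(3)
    ]
    out = RingAttention.attention(q, k, v, sp_axis, AttnMaskType.CAUSAL, pg_mesh=pg_mesh)
    return out  # (B, H, local_sq, D)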