from enum import Enum
from typing import Callable, Dict, Optional, Tuple

import torch
import torch.distributed
import torch.distributed as dist
import torch.nn.functional as F
from einops import rearrange

from colossalai.kernel.kernel_loader import (
    FlashAttentionDaoLoader,
    FlashAttentionForFloatAndCustomMaskLoader,
    FlashAttentionLoader,
    FlashAttentionWithCustomMaskLoader,
    KernelLoader,
)
from colossalai.logging import get_dist_logger

from .utils import RingComm, get_half_index, split_varlen_zigzag

MEMORY_BOUND = 10 * 1e9

__all__ = [
    "AttnMaskType",
    "ColoAttention",
]

_flash_attn_forward = _flash_attn_backward = None
_unpad_input = _pad_input = None


class AttnMaskType(Enum):
    CUSTOM = 0
    PADDED = 1
    CAUSAL = 2
    PADDED_CAUSAL = 3


def invert_mask(mask: torch.Tensor) -> torch.Tensor:
    """Invert the mask tensor.

    Args:
        mask (torch.Tensor): Mask tensor. Shape should be [B, 1, Sq, Skv]

    Returns:
        torch.Tensor: Inverted mask tensor.
    """
    inverted_mask = 1.0 - mask
    return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(mask.dtype).min)


# adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py
def get_pad_info(
    padding_mask: torch.Tensor, invert: Optional[bool] = False, return_indices: Optional[bool] = True
) -> Tuple[int, torch.Tensor, torch.Tensor]:
    """Get padding information from padding mask.

    Args:
        padding_mask (torch.Tensor): Padding mask tensor. Shape should be [B, Skv]
        invert (Optional[bool], optional): Whether to invert the padding mask.
        return_indices (Optional[bool], optional): Whether to return the indices of non-masked tokens.

    Returns:
        max_seqlen_in_batch (int): Maximum sequence length in the batch.
        cu_seqlens (torch.Tensor): Shape [B+1]. Cumulative sequence lengths of the sequences in the batch.
        indices (torch.Tensor): Shape [total_nonzero]. The indices of non-masked tokens from the flattened input sequence.
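
    Example (illustrative; the values follow directly from the definition above):
        For ``padding_mask = [[1, 1, 1, 0], [1, 1, 0, 0]]`` the per-sample lengths are [3, 2],
        so the function returns ``max_seqlen_in_batch = 3``, ``cu_seqlens = [0, 3, 5]`` and,
        when ``return_indices`` is True, ``indices = [0, 1, 2, 4, 5]`` (positions of the valid
        tokens in the flattened mask).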
""" if invert: padding_mask = padding_mask.logical_not() seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32) if return_indices: indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten() max_seqlen_in_batch = seqlens_in_batch.max().item() cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) if return_indices: return max_seqlen_in_batch, cu_seqlens, indices return max_seqlen_in_batch, cu_seqlens class ColoAttention: _kernel_dispatch_map: Optional[Dict[torch.dtype, Dict[Optional[AttnMaskType], Callable]]] = None _flash_kernel_dispatch: Optional[Dict[torch.dtype, Dict[Optional[AttnMaskType], Callable]]] = None @staticmethod def _init_kernels_dispatch(): if ColoAttention._kernel_dispatch_map is None: # fp16/bf16 half_dispatch_map = { None: FlashAttentionLoader(), AttnMaskType.CUSTOM: FlashAttentionWithCustomMaskLoader(), AttnMaskType.PADDED: FlashAttentionLoader(), AttnMaskType.CAUSAL: FlashAttentionLoader(), AttnMaskType.PADDED_CAUSAL: FlashAttentionLoader(), } # fp32 float_dispatch_map = { None: FlashAttentionForFloatAndCustomMaskLoader(), AttnMaskType.CUSTOM: FlashAttentionForFloatAndCustomMaskLoader(), AttnMaskType.PADDED: FlashAttentionForFloatAndCustomMaskLoader(), AttnMaskType.CAUSAL: FlashAttentionForFloatAndCustomMaskLoader(), AttnMaskType.PADDED_CAUSAL: FlashAttentionForFloatAndCustomMaskLoader(), } ColoAttention._kernel_dispatch_map = { torch.float16: half_dispatch_map, torch.bfloat16: half_dispatch_map, torch.float32: float_dispatch_map, } if ColoAttention._flash_kernel_dispatch is None: ColoAttention._flash_kernel_dispatch = FlashAttentionDaoLoader() @staticmethod def _dispatch_kernel(dtype: torch.dtype, mask_type: Optional[AttnMaskType], size) -> Callable: ColoAttention._init_kernels_dispatch() if ( dtype not in ColoAttention._kernel_dispatch_map or mask_type not in ColoAttention._kernel_dispatch_map[dtype] ): raise ValueError( "FlashAttention kernel is not available for dtype {} and mask_type {}".format(dtype, mask_type) ) if size >= MEMORY_BOUND: if isinstance(ColoAttention._flash_kernel_dispatch, KernelLoader): ColoAttention._flash_kernel_dispatch = ColoAttention._flash_kernel_dispatch.load() # lazy load if isinstance(ColoAttention._kernel_dispatch_map[dtype][mask_type], KernelLoader): ColoAttention._kernel_dispatch_map[dtype][mask_type] = ColoAttention._kernel_dispatch_map[dtype][ mask_type ].load() if size >= MEMORY_BOUND and mask_type in (AttnMaskType.PADDED_CAUSAL, AttnMaskType.CAUSAL): return ColoAttention._flash_kernel_dispatch else: return ColoAttention._kernel_dispatch_map[dtype][mask_type] @staticmethod def prepare_attn_kwargs( shape_4d: Tuple[int], dtype: torch.dtype, device: torch.device, q_padding_mask: Optional[torch.Tensor] = None, kv_padding_mask: Optional[torch.Tensor] = None, is_causal: bool = False, invert: bool = True, ) -> Dict[str, torch.Tensor]: """Return a dictionary of keyword arguments for attention function. It supports 4 mask type. 1. custom mask: no padding mask and is_causal=False, return {}, users should handle attention mask by themselves. 2. padded mask: recv padding mask and is_causal=False, return {attention_mask, attention_mask_type, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, q_indices, kv_indices}. 3. causal mask: no padding mask and is_causal=True, return {attention_mask, attention_mask_type}. 4. 
        4. padded causal mask: receives padding mask and is_causal=True, return {attention_mask, attention_mask_type, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, q_indices, kv_indices}.

        Args:
            shape_4d (Tuple[int]): Should be (B, 1, Sq, Skv)
            dtype (torch.dtype): Dtype of attention mask, generally should be ``hidden_states.dtype``
            device (torch.device): Device of attention mask, generally should be ``hidden_states.device``
            q_padding_mask (Optional[torch.Tensor], optional): Padding mask of query. It should be a long tensor or int tensor.
                The shape should be [B, Sq]. ``1`` means valid token, and ``0`` means padding token. Defaults to None.
            kv_padding_mask (Optional[torch.Tensor], optional): Padding mask of key and value. It should be a long tensor or int tensor.
                The shape should be [B, Skv]. ``1`` means valid token, and ``0`` means padding token.
                If it's None and ``q_padding_mask`` is not None, it will be set to ``q_padding_mask``. Defaults to None.
            is_causal (bool, optional): Whether to use causal attention mask. Defaults to False.
            invert (bool, optional): Whether to invert the mask. Defaults to True.

        Returns:
            Dict[str, torch.Tensor]: Dictionary of keyword arguments for attention function.
        """
        if q_padding_mask is None and not is_causal:
            return {}
        assert len(shape_4d) == 4 and shape_4d[1] == 1
        b, _, s_q, s_kv = shape_4d
        element_size = torch.tensor([], dtype=dtype).element_size()
        memory_size = s_q * s_kv * element_size
        outputs = {}
        if (q_padding_mask is None or q_padding_mask.bool().all()) and (
            kv_padding_mask is None or kv_padding_mask.bool().all()
        ):
            # no padding
            assert is_causal
            outputs["attention_mask_type"] = AttnMaskType.CAUSAL
            if memory_size < MEMORY_BOUND:
                attention_mask = torch.ones(s_q, s_kv, dtype=dtype, device=device)
                if s_q != 1:
                    attention_mask.tril_(diagonal=0)
                attention_mask = attention_mask.expand(b, s_q, s_kv)
            else:
                attention_mask = torch.empty((0,), dtype=dtype, device=device)
        else:
            max_seqlen_q, cu_seqlens_q, q_indices = get_pad_info(q_padding_mask)
            if kv_padding_mask is None:
                # self attention
                kv_padding_mask = q_padding_mask
                max_seqlen_kv, cu_seqlens_kv, kv_indices = max_seqlen_q, cu_seqlens_q, q_indices
            else:
                max_seqlen_kv, cu_seqlens_kv, kv_indices = get_pad_info(kv_padding_mask)
            assert kv_padding_mask.shape == (
                b,
                s_kv,
            ), f"Padding mask shape {kv_padding_mask.shape} should align with shape 4d ({b}, {s_kv})"
            outputs.update(
                {
                    "cu_seqlens_q": cu_seqlens_q,
                    "cu_seqlens_kv": cu_seqlens_kv,
                    "max_seqlen_q": max_seqlen_q,
                    "max_seqlen_kv": max_seqlen_kv,
                    "q_indices": q_indices,
                    "kv_indices": kv_indices,
                }
            )
            if is_causal:
                outputs["attention_mask_type"] = AttnMaskType.PADDED_CAUSAL
                if memory_size < MEMORY_BOUND:
                    if s_q != 1:
                        attention_mask = kv_padding_mask[:, None, :].expand(b, s_q, s_kv).to(dtype=dtype, device=device)
                        attention_mask = attention_mask * attention_mask.new_ones(s_q, s_kv).tril(diagonal=0)
                else:
                    attention_mask = torch.empty((0,), dtype=dtype, device=device)
            else:
                outputs["attention_mask_type"] = AttnMaskType.PADDED
                if memory_size < MEMORY_BOUND:
                    attention_mask = kv_padding_mask[:, None, :].expand(b, s_q, s_kv).to(dtype=dtype, device=device)
        if invert:
            attention_mask = invert_mask(attention_mask).unsqueeze(1)
        outputs["attention_mask"] = attention_mask
        return outputs

    @staticmethod
    def attention(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        attention_mask_type: AttnMaskType = AttnMaskType.CUSTOM,
        cu_seqlens_q: Optional[torch.Tensor] = None,
        cu_seqlens_kv: Optional[torch.Tensor] = None,
        max_seqlen_q: Optional[int] = None,
        max_seqlen_kv: Optional[int] = None,
        q_indices: Optional[torch.Tensor] = None,
        kv_indices: Optional[torch.Tensor] = None,
        dropout_p: float = 0.0,
        scale: Optional[float] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Flash Attention function. It supports 4 mask types.
        1. custom mask: receives attention_mask
        2. padded mask: receives attention_mask, attention_mask_type, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, q_indices, kv_indices
        3. causal mask: receives attention_mask, attention_mask_type
        4. padded causal mask: receives attention_mask, attention_mask_type, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, q_indices, kv_indices

        Args:
            q (torch.Tensor): Query tensor. Shape should be [B, nHeads, Sq, D]
            k (torch.Tensor): Key tensor. Shape should be [B, nHeads, Skv, D]
            v (torch.Tensor): Value tensor. Shape should be [B, nHeads, Skv, D]
            attention_mask (Optional[torch.Tensor], optional): Attention mask tensor. Shape should be [B, 1, Sq, Skv]. Defaults to None.
            attention_mask_type (AttnMaskType, optional): Attention mask type. Defaults to AttnMaskType.CUSTOM.
            cu_seqlens_q (Optional[torch.Tensor], optional): The cumulative sequence lengths
                of the sequences in the batch, used to index into q.
                Shape should be [B+1]. Defaults to None.
            cu_seqlens_kv (Optional[torch.Tensor], optional): The cumulative sequence lengths
                of the sequences in the batch, used to index into kv.
                Shape should be [B+1]. Defaults to None.
            max_seqlen_q (Optional[int], optional): Maximum query sequence length in the batch. Defaults to None.
            max_seqlen_kv (Optional[int], optional): Maximum key/value sequence length in the batch. Defaults to None.
            q_indices (Optional[torch.Tensor], optional): The indices of non-masked tokens from the flattened query sequence.
                Shape should be [NUM_TOKENS]. Defaults to None.
            kv_indices (Optional[torch.Tensor], optional): The indices of non-masked tokens from the flattened key/value sequence.
                Shape should be [NUM_TOKENS]. Defaults to None.
            dropout_p (float, optional): Dropout probability. Defaults to 0.0.
            scale (Optional[float], optional): Scaling factor applied prior to softmax. Defaults to None.

        Returns:
            torch.Tensor: Output tensor. Shape should be [B, nHeads, Sq, D]
        """
        # known issue: sdpa does not support an attention mask that contains a whole row of masked tokens,
        # which leads to NaN. This case is usual when a padding mask is used and self attention is performed;
        # thus, we don't use sdpa when a padding mask is used.
        # sanity check
        if attention_mask is not None:
            assert torch.is_floating_point(attention_mask), "attention_mask should be a floating point tensor."
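            # Descriptive note: the checks below only validate that the varlen metadata matches the
            # mask type; they do not modify any inputs. Illustrative call pattern (a sketch; only
            # ``padding_mask``, a [B, Sq] 0/1 tensor, is a caller-provided placeholder):
            #
            #   attn_kwargs = ColoAttention.prepare_attn_kwargs(
            #       (b, 1, s_q, s_kv), dtype=q.dtype, device=q.device,
            #       q_padding_mask=padding_mask, is_causal=True,
            #   )
            #   out = ColoAttention.attention(q, k, v, **attn_kwargs)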
            if attention_mask_type in (AttnMaskType.CUSTOM, AttnMaskType.CAUSAL):
                assert (
                    cu_seqlens_q is None
                    and cu_seqlens_kv is None
                    and max_seqlen_q is None
                    and max_seqlen_kv is None
                    and q_indices is None
                    and kv_indices is None
                )
                if attention_mask_type == AttnMaskType.CUSTOM:
                    assert not torch.all(attention_mask != 0, dim=-1).any()
            elif attention_mask_type in (
                AttnMaskType.PADDED,
                AttnMaskType.PADDED_CAUSAL,
            ):
                assert (
                    cu_seqlens_q is not None
                    and cu_seqlens_kv is not None
                    and max_seqlen_q is not None
                    and max_seqlen_kv is not None
                    and q_indices is not None
                    and kv_indices is not None
                )
        else:
            # if attention_mask is None, attention_mask_type should be the default value
            assert attention_mask_type == AttnMaskType.CUSTOM

        # kernel dispatch
        b, _, s_q, _ = q.shape
        b, _, s_kv, _ = v.shape
        element_size = torch.tensor([], dtype=q.dtype).element_size()
        memory_size = s_q * s_kv * element_size
        mask_type = attention_mask_type if attention_mask is not None else None
        attn_func = ColoAttention._dispatch_kernel(q.dtype, mask_type, memory_size)
        is_causal = attention_mask is not None and attention_mask_type in (
            AttnMaskType.CAUSAL,
            AttnMaskType.PADDED_CAUSAL,
        )
        return attn_func(
            q,
            k,
            v,
            dropout_p=dropout_p,
            scale=scale,
            attention_mask=attention_mask,
            is_causal=is_causal,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_kv=cu_seqlens_kv,
            max_seqlen_q=max_seqlen_q,
            max_seqlen_kv=max_seqlen_kv,
            q_indices=q_indices,
            kv_indices=kv_indices,
        )


def _load_varlen_helpers():
    """Helper to load functions for padding and unpadding packed sequences.
    Use only when flash attn is installed.
    """
    global _pad_input, _unpad_input
    # Flash attn claims this is more efficient than torch's bool indexing due to avoiding
    # broadcast
    if _pad_input is None or _unpad_input is None:
        try:
            from flash_attn.bert_padding import index_first_axis, pad_input

            def unpad_input(hidden_states: torch.Tensor, indices: torch.Tensor):
                return index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices)

            _pad_input = pad_input
            _unpad_input = unpad_input
        except ImportError as e:
            raise RuntimeError(
                f"Flash Attention is not installed. You can install it via 'pip install flash-attn --no-build-isolation'"
            ) from e


def _load_flash_attn():
    """A light-weight loader to check whether flash-attn is installed.
    Can't use ColoAttention._dispatch_kernel because we mutate the backward pass.
    """
    global _flash_attn_forward, _flash_attn_backward
    if _flash_attn_forward is None or _flash_attn_backward is None:
        try:
            from flash_attn.flash_attn_interface import _flash_attn_varlen_backward as _flash_attn_backward
            from flash_attn.flash_attn_interface import _flash_attn_varlen_forward as _flash_attn_forward
        except ImportError as e:
            raise RuntimeError(
                f"Flash Attention is not installed. You can install it via 'pip install flash-attn --no-build-isolation'"
            ) from e

    _load_varlen_helpers()


# NOTE: This can cause spawned processes to hang on exit
# with python 3.9
@torch.compile()
def _rescale_out_lse(out, block_out, lse, block_lse):
    """
    Compute the new attention denominator:
        exp(lse) + exp(block_lse) = exp(max_scale) * (exp(min_scale - max_scale) + 1)
    Args:
        out: (T, H, D)
        block_out: (T, H, D)
        lse: (H, T, 1)
        block_lse: (H, T, 1)
    """

    # min_scale = torch.min(lse, block_lse)
    # max_scale = torch.max(lse, block_lse)
    # new_lse = max_scale + torch.log(1 + torch.exp(min_scale - max_scale))

    # NOTE: directly assigning to .data here is buggy
    # probably due to casting dtypes/strides
    new_lse = lse + torch.log(1 + torch.exp(block_lse - lse))
    new_block_lse = torch.exp(block_lse - new_lse)
    out = (torch.exp(lse - new_lse) * out + new_block_lse * block_out).to(out)
    lse = new_lse

    # Equivalent to the above
    # See https://github.com/zhuzilin/ring-flash-attention/pull/34#issuecomment-2076126795
    # out = (out - F.sigmoid(block_lse - lse) * (out - block_out))
    # lse = (lse - F.logsigmoid(lse - block_lse))
    return out, lse


class RingAttention(torch.autograd.Function):
    """Implements the Ring Attention from `Ring Attention with Blockwise Transformers for Near-Infinite Context`
    (https://arxiv.org/abs/2310.01889).
    For load-balancing we adopted the "zigzag" attention scheme from https://github.com/zhuzilin/ring-flash-attention/tree/main
    For portable integration with more models, we don't follow the spirit of "block-wise FFN" in the original paper,
    which requires fusing FFN with the Flash Attention kernel/function (see https://arxiv.org/pdf/2305.19370;
    implemented in Jax and not optimized).
    We adopt the double ring topology from LoongTrain (https://arxiv.org/pdf/2406.18485) to fully utilize available
    NICs on each node, by computing attention within an inner ring first and then sending all KVs to the next ring at once.
    """

    # Global cache to avoid recomputation for same-lengthed sequences
    CU_SEQLENS: torch.Tensor = None  # [B+1]
    TOTAL_SEQLEN: int = None
    HALF_INDICES: Tuple = None
    SUPPORTED_MASK_TYPES = (AttnMaskType.CAUSAL, AttnMaskType.PADDED_CAUSAL)
    ATTN_DONE: torch.cuda.Event = None
    SP_STREAM: torch.cuda.Stream = None
    SP_GROUP: dist.ProcessGroup = None

    # duplicate process group for concurrent NCCL streams
    # While both PyTorch and NCCL warn (https://github.com/pytorch/pytorch/commit/2dbe5cb979f674f0052a8eea1f7b6c3c0ba441d7)
    # against this, in practice it seems to work fine.
    INNER_RING_GROUP: dist.ProcessGroup = None
    INNER_RING_GROUP_COPY: dist.ProcessGroup = None
    INTER_RING_GROUP: dist.ProcessGroup = None
    INTER_RING_GROUP_COPY: dist.ProcessGroup = None

    @staticmethod
    def get_double_ring_groups(sp_group, tp_group, inner_ring_size=None):
        """
        Get 2D ring groups for the given process group. Generally, to avoid congestion, the inner ring size
        shouldn't be larger than the number of NICs on each node.

        Args:
            sp_group (dist.ProcessGroup): Process group for sequence parallelism
            tp_group (dist.ProcessGroup): Process group for tensor parallelism
            inner_ring_size (Optional[int], optional): Inner ring size. Defaults to None.

        Returns:
            Tuple[dist.ProcessGroup, dist.ProcessGroup]: Inner-ring process group and inter-ring process group.
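
        Example (illustrative; assumes a flat sp_group of 16 ranks, no tensor parallelism, and inner_ring_size=4):
            Ranks are grouped into inner rings [0-3], [4-7], [8-11], [12-15]; KV blocks first circulate
            within each inner ring, and the inter-ring groups [0, 4, 8, 12], [1, 5, 9, 13], [2, 6, 10, 14],
            [3, 7, 11, 15] then exchange all KVs between rings in one shot.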
""" sp_size = dist.get_world_size(sp_group) tp_size = dist.get_world_size(tp_group) sp_rank = dist.get_rank(sp_group) if inner_ring_size is None: if torch.cuda.device_count() >= dist.get_world_size(): # single node, no need to consider NICs return sp_group, sp_group if sp_size <= 4: inner_ring_size = min(2, sp_size) else: inner_ring_size = min(4, sp_size) else: assert ( inner_ring_size <= sp_size and sp_size % inner_ring_size == 0 ), f"Error: sp_size {sp_size} should be divisible by inner_ring_size {inner_ring_size}" if inner_ring_size == sp_size: return sp_group, sp_group assert ( sp_size % inner_ring_size == 0 ), f"sp_size {sp_size} should be divisible by inner_ring_size {inner_ring_size}" logger = get_dist_logger() logger.info( f"Using 2D Ring Attention with inner ring size {inner_ring_size} to maximze NIC util for inter-node comm. Cross your fingers for speed-ups!", ranks=[0], ) inner_ring_group = None inter_ring_group = None world_size = dist.get_world_size() rank = dist.get_rank() inner_rings = world_size // sp_size num_rings = sp_size // inner_ring_size if tp_size > 1: for i in range(inner_rings): for j in range(sp_size // tp_size): # find inner ring group in one sp group ranks = list(range(j + i * sp_size, j + (i + 1) * sp_size, tp_size)) group = dist.new_group(ranks) if rank in ranks: inner_ring_group = group for i in range(inner_rings): for j in range(sp_size // tp_size): ranks = list(range(j + i * (sp_size // tp_size), inner_rings + (i + 1) * sp_size, sp_size)) group = dist.new_group(ranks) if rank in ranks: inter_ring_group = group else: # Create inner ring groups for i in range(inner_ring_size): ranks = list(range(i * inner_ring_size, (i + 1) * inner_ring_size)) group = dist.new_group(ranks) if sp_rank in ranks: inner_ring_group = group # Create inter ring groups for i in range(num_rings): ranks = list(range(i, sp_size, num_rings)) group = dist.new_group(ranks) if sp_rank in ranks: inter_ring_group = group return inner_ring_group, inter_ring_group @staticmethod def attention( q, # (B, H, Sq, D) k, v, sp_group, attention_mask_type, cu_seqlens=None, max_seqlen=None, valid_indices=None, dropout_p=0.0, softmax_scale=None, deterministic=False, return_softmax=False, inner_ring_size=None, tp_group=None, **kwargs, ): """ Ring Attention forward pass supporting variable-length sequences. When using varlen mode, each sequence in the batch should have length divisible by sp_size * 2. Args: q (torch.Tensor): Query tensor. Shape should be [B, nHeads, Sq, D] k (torch.Tensor): Key tensor. Shape should be [B, nHeads, Sq, Sq, D] v (torch.Tensor): Value tensor. Shape should be [B, nHeads, Sq, Sq, D] sp_group (Optional[dist.ProcessGroup]): Process group for sequence parallelism sp_tream (torch.cuda.Stream): An different stream for output correction. cu_seqlens (Optional[torch.Tensor], optional): The cumulative sequence lengths of the sequences in the batch, used to index into q. Shape should be [B+1]. max_seqlen (Optional[int], optional): Maximum query sequence length in the batch. valid_indices (Optional[torch.Tensor], optional): The indices of non-masked tokens from get_pad_info. Shape should be [t]. dropout_p (float, optional): Dropout probability. Defaults to 0.0. softmax_scale (Optional[float], optional): Scaling factor applied prior to softmax. deterministic (bool, optional): Whether to force deterministic backward pass. See https://github.com/Dao-AILab/flash-attention/issues/349 return_softmax (bool, optional): Whether to return the softmax denominator (logsumexp). 
            inner_ring_size (Optional[int], optional): Inner ring size of the 2D ring. By default use a heuristic to decide.

        Returns:
            out: Output tensor of shape [B, nHeads, Sq, D] or [T, nHeads, D] if pad_output is False.
            softmax_lse: (if return_softmax is True) Softmax denominator (logsumexp).
                Shape should be [total_q_seqlen, nHeads]
        """
        # Check input args
        _load_flash_attn()
        if RingAttention.ATTN_DONE is None:
            RingAttention.ATTN_DONE = torch.cuda.Event()
        if RingAttention.SP_STREAM is None:
            RingAttention.SP_STREAM = torch.cuda.Stream()

        assert (
            q.shape[2] == k.shape[2]
        ), "Q, K and V having different sequence lengths (inference or cross-attn)\
            is not supported yet in training."
        assert (
            attention_mask_type in RingAttention.SUPPORTED_MASK_TYPES
        ), f"Mask type {attention_mask_type} is not supported yet."

        clone_pg = lambda pg: dist.new_group(dist.get_process_group_ranks(pg))

        if RingAttention.SP_GROUP is not sp_group:
            RingAttention.SP_GROUP = sp_group
            inner_ring_group, inter_ring_group = RingAttention.get_double_ring_groups(
                sp_group, tp_group, inner_ring_size
            )
            RingAttention.INNER_RING_GROUP = inner_ring_group
            RingAttention.INTER_RING_GROUP = inter_ring_group
        else:
            inner_ring_group = RingAttention.INNER_RING_GROUP
            inter_ring_group = RingAttention.INTER_RING_GROUP

        # (B, H, Sq, D) -> (B, Sq, H, D)
        q, k, v = [x.transpose(1, 2).contiguous() for x in (q, k, v)]
        pad_output = q.dim() == 4

        # Get sequence length info for varlen forward
        if attention_mask_type == AttnMaskType.CAUSAL:
            # All sequences share the same length
            b, sq, h, d = q.shape
            max_seqlen = sq
            # Cache to avoid recreation for a single sequence
            if sq * b == RingAttention.TOTAL_SEQLEN:
                cu_seqlens = RingAttention.CU_SEQLENS
            else:
                cu_seqlens = torch.arange(0, b * sq + 1, sq, device=q.device, dtype=torch.int32)
                RingAttention.TOTAL_SEQLEN = b * sq

        # "Packed" mode where sequences of different lengths are packed into [total_q_seqlen, H, D]
        elif attention_mask_type == AttnMaskType.PADDED_CAUSAL:
            assert (
                cu_seqlens is not None and max_seqlen is not None and valid_indices is not None
            ), "Packed mode requires pre-computed cu_seqlens and max_seq_len."
            if pad_output:
                b, sq, h, d = q.shape
                q, k, v = [_unpad_input(x, valid_indices) for x in (q, k, v)]

        out, softmax_lse = RingAttention.apply(
            q,
            k,
            v,
            sp_group,
            RingAttention.SP_STREAM,
            cu_seqlens,
            max_seqlen,
            dropout_p,
            softmax_scale,
            deterministic,
            return_softmax,
            attention_mask_type == AttnMaskType.PADDED_CAUSAL,
            inner_ring_group,
            inter_ring_group,
            tp_group,
        )

        if attention_mask_type == AttnMaskType.PADDED_CAUSAL:
            if pad_output:
                out = _pad_input(out, valid_indices, b, sq)  # (T, ...) -> (B, Sq, ...)
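                # Shape round trip for padded inputs (descriptive note): the packed [T, H, D] output of
                # the varlen kernels is re-padded to [B, Sq, H, D] above, then transposed back to the
                # [B, H, Sq, D] layout that the caller passed in.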
                out = out.transpose(1, 2)  # (B, Sq, H, D) -> (B, H, Sq, D)
        else:
            out = out.transpose(1, 2)

        if return_softmax:
            return out, softmax_lse
        return out

    @staticmethod
    def forward(
        ctx,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        sp_group: dist.ProcessGroup,
        sp_stream: torch.cuda.Stream,
        cu_seqlens: torch.Tensor,
        max_seqlen: int,
        dropout_p: float = 0.0,
        softmax_scale: Optional[float] = None,
        deterministic: Optional[bool] = False,
        return_softmax: Optional[bool] = False,
        is_packed: Optional[bool] = False,
        inner_ring_group: Optional[dist.ProcessGroup] = None,
        inter_ring_group: Optional[dist.ProcessGroup] = None,
        tp_group: Optional[dist.ProcessGroup] = None,
    ):
        cu_seqlens_q = cu_seqlens_kv = cu_seqlens
        max_seqlen_q = max_seqlen_kv = max_seqlen
        cu_seqlens_half = cu_seqlens // 2
        max_seqlen_half = max_seqlen // 2

        misc_kwargs = {
            "window_size": (-1, -1),
            "alibi_slopes": None,
            "softmax_scale": q.shape[-1] ** -0.5 if softmax_scale is None else softmax_scale,
            "dropout_p": dropout_p,
            "block_table": None,
            "softcap": 0.0,
            "return_softmax": False,
        }

        if (
            RingAttention.HALF_INDICES is not None
            and cu_seqlens.shape == RingAttention.CU_SEQLENS.shape
            and (cu_seqlens == RingAttention.CU_SEQLENS).all()
        ):
            half_idx_front, half_idx_back = RingAttention.HALF_INDICES
        else:
            half_idx_front = get_half_index(cu_seqlens, front=True)
            half_idx_back = get_half_index(cu_seqlens, front=False)
            RingAttention.HALF_INDICES = (half_idx_front, half_idx_back)
            RingAttention.CU_SEQLENS = cu_seqlens

        if is_packed:
            t, h, d = q.shape
        else:
            b, sq, h, d = q.shape
            t = b * sq
            # Be careful about GQA/MQA in reshape
            q, k, v = [x.view(t, *x.shape[-2:]) for x in (q, k, v)]

        if inner_ring_group is None or inter_ring_group is None:
            # Use one ring if not specified
            inner_ring_group = inter_ring_group = sp_group

        sp_size = dist.get_world_size(sp_group)
        sp_rank = dist.get_rank(sp_group)

        # Attempt to achieve concurrent comm in the two-stream forward
        local_kv_comms = [RingComm(inner_ring_group) for _ in range(2)]
        inter_ring_comm = RingComm(inter_ring_group)
        local_sp_size = dist.get_world_size(inner_ring_group)
        local_sp_rank = dist.get_rank(inner_ring_group)
        inter_ring_rank = dist.get_rank(inter_ring_group) if inter_ring_group is not sp_group else 0
        num_rings = dist.get_world_size(inter_ring_group) if inter_ring_group is not sp_group else 1

        # Non-contiguous indexing copies to a new contiguous tensor,
        # so only do it once
        if sp_rank != sp_size - 1:
            q1 = q[half_idx_back]

        # Pre-allocate double buffer for overlapping and receiving next step's inputs
        kv_buffers = [torch.stack((k, v))]  # (2, B, Sq, H, D)
        kv_buffers.append(torch.empty_like(kv_buffers[0]))

        # outputs
        out = None
        block_out = [None, None]
        softmax_lse = [None, None]
        block_softmax_lse = [None, None]  # log sum exp, the denominator of softmax in attention
        rng_states = [None for _ in range(sp_size)]
        sp_streams = [torch.cuda.current_stream(), sp_stream]

        def _forward(q, k, v, causal):
            (
                _,
                _,
                _,
                _,
                out,
                softmax_lse,
                _,
                rng_state,
            ) = _flash_attn_forward(
                q,
                k,
                v,
                cu_seqlens_q if q.shape[0] == t else cu_seqlens_half,
                cu_seqlens_kv if k.shape[0] == t else cu_seqlens_half,
                max_seqlen_q if q.shape[0] == t else max_seqlen_half,
                max_seqlen_kv if k.shape[0] == t else max_seqlen_half,
                causal=causal,
                **misc_kwargs,
            )
            return out, softmax_lse, rng_state

        def _kv_comm(i):
            # Avoid overwriting attn input when it shares mem with buffer
            if not RingAttention.ATTN_DONE.query():
                kv_buffers[(i + 1) % 2] = torch.empty_like(kv_buffers[i % 2])
            if i < local_sp_size - 1:
                local_kv_comms[i % 2].send_recv(kv_buffers[i % 2], kv_buffers[(i + 1) % 2])
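
        # Descriptive note: _local_ring_forward below walks the inner ring for local_sp_size steps on
        # two alternating CUDA streams; each step runs flash attention on the KV block received from
        # the previous rank, issues the next KV exchange, and applies _rescale_out_lse, which the
        # stream ping-pong allows to overlap with the following step's attention kernel.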
        def _local_ring_forward():
            # (Hopefully) overlap output correction with next flash attn
            for i in range(local_sp_size):
                with torch.cuda.stream(sp_streams[i % 2]):
                    # Wait for current kv from prev rank
                    # NOTE: waiting outside the current stream will NOT correctly synchronize.
                    if i > 0:
                        local_kv_comms[(i + 1) % 2].wait()
                    if i == 0:
                        _kv_comm(i)

                    if i == 0:
                        # Compute with local KV; no mask
                        kv_block = kv_buffers[0]
                        q_block = q
                        (block_out[i % 2], block_softmax_lse[i % 2], rng_states[i]) = _forward(  # (T, H, D)  # (H, T)
                            q_block, kv_block[0], kv_block[1], causal=True
                        )
                    elif i <= local_sp_rank:
                        # Received the "surrounding" kv chunks
                        # Drop the second half of received kv
                        # (2, t // 2, H, D)
                        kv_block = kv_buffers[i % 2][:, half_idx_front]
                        q_block = q
                        (
                            block_out[i % 2],  # (T, H, D)
                            block_softmax_lse[i % 2],  # (H, T)
                            rng_states[i],
                        ) = _forward(q_block, kv_block[0], kv_block[1], causal=False)
                    else:
                        # Received the inner kv chunks
                        # Drop the first half of q
                        kv_block = kv_buffers[i % 2]
                        q_block = q1
                        (
                            block_out[i % 2],  # (T, H, D)
                            block_softmax_lse[i % 2],  # (H, T)
                            rng_states[i],
                        ) = _forward(q_block, kv_block[0], kv_block[1], causal=False)
                    RingAttention.ATTN_DONE.record()
                    # Pipeline the next KV comm with output correction instead of the next flash attn
                    # to minimize idle time when comm takes longer than attn.
                    _kv_comm(i + 1)

                    block_softmax_lse[i % 2] = (
                        block_softmax_lse[i % 2].transpose(0, 1).unsqueeze(-1).contiguous().float()
                    )  # (H, T) -> (T, H, 1)
                    assert block_out[i % 2].shape[:-1] == block_softmax_lse[i % 2].shape[:-1]
                    # Output and log sum exp correction. Ideally overlap this with the next flash attn kernel.
                    # In reality this always finishes before next flash attn; no need for extra sync.
                    if i == 0:
                        out = block_out[0]
                        softmax_lse = block_softmax_lse[0]
                    elif i <= local_sp_rank:
                        out, softmax_lse = _rescale_out_lse(
                            out, block_out[i % 2], softmax_lse, block_softmax_lse[i % 2]
                        )
                    else:
                        out[half_idx_back], softmax_lse[half_idx_back] = _rescale_out_lse(
                            out[half_idx_back], block_out[i % 2], softmax_lse[half_idx_back], block_softmax_lse[i % 2]
                        )

            torch.cuda.current_stream().wait_stream(sp_stream)
            return out, softmax_lse

        def _other_ring_forward(ring_num_idx, out, softmax_lse):
            # Loop through the inner ring after receiving
            # all new KVs from the previous inner ring
            for i in range(local_sp_size):
                with torch.cuda.stream(sp_streams[i % 2]):
                    # Send & recv KV
                    if i > 0:
                        local_kv_comms[(i + 1) % 2].wait()
                    if i == 0:
                        _kv_comm(i)

                    if ring_num_idx > inter_ring_rank:
                        kv_block = kv_buffers[i % 2]
                        (
                            block_out[i % 2],
                            block_softmax_lse[i % 2],
                            rng_states[i + local_sp_size * ring_num_idx],
                        ) = _forward(q1, kv_block[0], kv_block[1], causal=False)
                        RingAttention.ATTN_DONE.record()

                        _kv_comm(i + 1)
                        block_softmax_lse[i % 2] = (
                            block_softmax_lse[i % 2].transpose(0, 1).unsqueeze(-1).contiguous().float()
                        )
                        out[half_idx_back], softmax_lse[half_idx_back] = _rescale_out_lse(
                            out[half_idx_back], block_out[i % 2], softmax_lse[half_idx_back], block_softmax_lse[i % 2]
                        )
                    else:
                        kv_block = kv_buffers[i % 2][:, half_idx_front]
                        (
                            block_out[i % 2],
                            block_softmax_lse[i % 2],
                            rng_states[i + local_sp_size * ring_num_idx],
                        ) = _forward(q, kv_block[0], kv_block[1], causal=False)
                        RingAttention.ATTN_DONE.record()

                        _kv_comm(i + 1)
                        block_softmax_lse[i % 2] = (
                            block_softmax_lse[i % 2].transpose(0, 1).unsqueeze(-1).contiguous().float()
                        )
                        out, softmax_lse = _rescale_out_lse(
                            out, block_out[i % 2], softmax_lse, block_softmax_lse[i % 2]
                        )

            torch.cuda.current_stream().wait_stream(sp_stream)
            return out, softmax_lse

        # Send and recv KV between rings at once to maximize NIC util.
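        # Descriptive note: the loop below walks the outer (inter-node) ring; each iteration ships the
        # current KV set to the next ring in a single send_recv and then reuses the inner-ring pipeline
        # above on the received blocks.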
        inter_ring_kv = None
        for ring_num_idx in range(num_rings):
            if ring_num_idx > 0:
                inter_ring_comm.wait()
                # Reset indices
                kv_buffers[0] = inter_ring_kv

            if ring_num_idx < num_rings - 1:
                if ring_num_idx == 0:
                    to_send = kv_buffers[0]
                else:
                    # The last received KV
                    to_send = kv_buffers[(local_sp_size - 1) % 2]
                inter_ring_kv = inter_ring_comm.send_recv(to_send)

            if ring_num_idx == 0:
                out, softmax_lse = _local_ring_forward()
            else:
                out, softmax_lse = _other_ring_forward(ring_num_idx, out, softmax_lse)

        out = out.to(q.dtype)
        if not is_packed:
            out = out.view(b, sq, h, d)
            q, k, v = [x.view(b, sq, *x.shape[-2:]) for x in (q, k, v)]  # (T, H, D) -> (B, Sq, H, D)
        softmax_lse = softmax_lse.squeeze(-1)

        ctx.sp_group = sp_group
        ctx.max_seqlen_q = ctx.max_seqlen_kv = max_seqlen
        misc_kwargs["deterministic"] = deterministic
        del misc_kwargs["return_softmax"]
        ctx.misc_kwargs = misc_kwargs
        ctx.is_packed = is_packed

        ctx.kv_group = inner_ring_group
        ctx.inter_kv_group = inter_ring_group

        ctx.save_for_backward(
            q,
            k,
            v,
            out,
            softmax_lse.transpose(0, 1).contiguous(),  # (T, H) -> (H, T)
            cu_seqlens_q,
            cu_seqlens_kv,
            half_idx_front,
            half_idx_back,
            *rng_states,
        )

        if return_softmax:
            return out, softmax_lse
        return out, None

    @staticmethod
    def backward(ctx, dout, _):
        """
        During backward, we accumulate q grads on each rank locally, but iterate kv and their grads
        over all ranks for accumulation.
        """
        (q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_kv, half_idx_front, half_idx_back) = ctx.saved_tensors[:9]
        rng_states = ctx.saved_tensors[9:]

        is_packed = ctx.is_packed
        max_seqlen_q = ctx.max_seqlen_q
        max_seqlen_kv = ctx.max_seqlen_kv
        cu_seqlens_half = cu_seqlens_q // 2
        max_seqlen_half = max_seqlen_q // 2
        misc_kwargs = ctx.misc_kwargs
        del misc_kwargs["block_table"]

        assert (
            out.shape == dout.shape == q.shape
        ), f"out {out.shape} and dout {dout.shape} should have the same shape ({q.shape})."

        if is_packed:
            t, h, d = q.shape
        else:
            b, sq, h, d = q.shape
            t = b * sq
        q, k, v, out, dout = [x.view(t, *x.shape[-2:]) for x in (q, k, v, out, dout)]

        # Sequence parallel args
        sp_group = ctx.sp_group
        local_kv_group = ctx.kv_group
        inter_kv_group = ctx.inter_kv_group

        local_sp_rank = dist.get_rank(sp_group)
        sp_size = dist.get_world_size(sp_group)

        # Using separate streams (pg) for concurrent kv and dkv comm may
        # cause NCCL "software caused connection abort" here...
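        # Descriptive note: each backward step below recomputes one attention block with flash
        # attention's varlen backward, accumulates dq locally (it never leaves this rank), and
        # circulates the dk/dv accumulators around the ring alongside the KV blocks so that every
        # gradient ends up on the rank owning the corresponding KV chunk.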
        local_kv_comm = RingComm(local_kv_group)
        local_dkv_comm = RingComm(local_kv_group)
        inter_kv_comm = RingComm(inter_kv_group)
        inter_dkv_comm = RingComm(inter_kv_group)
        local_sp_size = dist.get_world_size(local_kv_group)
        local_sp_rank = dist.get_rank(local_kv_group)

        if dist.get_world_size(inter_kv_group) != sp_size:
            num_rings = dist.get_world_size(inter_kv_group)
            inter_ring_rank = dist.get_rank(inter_kv_group)
        else:
            num_rings = 1
            inter_ring_rank = 0

        if local_sp_rank != sp_size - 1:
            softmax_lse1 = softmax_lse[:, half_idx_back]
        dout = dout.contiguous()

        # Double comm buffers for sending and receiving kv
        kv_buffers = [torch.stack((k, v))]  # (2, T, H, D)
        kv_buffers.append(torch.empty_like(kv_buffers[0]))

        dq = None  # (T, H, D)
        # Intermediate outputs
        dq_block = torch.empty_like(q)  # (T, H, D)
        dk_block = torch.empty_like(k)  # (T, H, D)
        dv_block = torch.empty_like(v)  # (T, H, D)
        dkv_buffers = [torch.empty_like(kv, dtype=torch.float32) for kv in kv_buffers]  # (T, H, D)
        del k, v

        def _backward(dout, q, k, v, out, softmax_lse, dq, dk, dv, rng_state, causal):
            _flash_attn_backward(
                dout,
                q,
                k,
                v,
                out,
                softmax_lse,
                dq,
                dk,
                dv,
                cu_seqlens_q if dq.shape[0] == t else cu_seqlens_half,
                cu_seqlens_kv if dk.shape[0] == t else cu_seqlens_half,
                max_seqlen_q if dq.shape[0] == t else max_seqlen_half,
                max_seqlen_kv if dk.shape[0] == t else max_seqlen_half,
                causal=causal,
                rng_state=rng_state,
                **misc_kwargs,
            )

        # NOTE: We avoid using two streams due to doubled buffers
        # and that backward is more communication intensive.
        def _local_ring_backward():
            for i in range(local_sp_size):
                if i > 0:
                    local_kv_comm.wait()

                if i < local_sp_size - 1:
                    # Send kv to next rank for backward
                    local_kv_comm.send_recv(kv_buffers[i % 2], kv_buffers[(i + 1) % 2])

                if i == 0:
                    # Backward with local kv
                    k_, v_ = kv_buffers[i % 2]
                    q_, dout_, out_ = q, dout, out
                    dq_, dk_, dv_ = dq_block, dk_block, dv_block
                    _backward(dout_, q_, k_, v_, out_, softmax_lse, dq_, dk_, dv_, rng_states[i], causal=True)

                elif i <= local_sp_rank:
                    # Drop the second half of kv
                    # (T, H, D) -> (T // 2, H, D)
                    k_, v_ = [x[half_idx_front] for x in kv_buffers[i % 2]]
                    dk_, dv_ = [x[: t // 2] for x in (dk_block, dv_block)]
                    dq_, q_, out_, dout_ = (dq_block, q, out, dout)
                    _backward(dout_, q_, k_, v_, out_, softmax_lse, dq_, dk_, dv_, rng_states[i], causal=False)

                else:
                    # Drop the first half of q
                    k_, v_ = kv_buffers[i % 2]
                    dk_, dv_ = dk_block, dv_block
                    q_, out_, dout_ = [x[half_idx_back] for x in (q, out, dout)]
                    dq_ = dq_block[: t // 2]
                    _backward(dout_, q_, k_, v_, out_, softmax_lse1, dq_, dk_, dv_, rng_states[i], causal=False)

                # Accumulate grads
                if i == 0:
                    dq = dq_block.float()
                    dkv_buffers[i % 2][0] = dk_block.float()
                    dkv_buffers[i % 2][1] = dv_block.float()
                else:
                    # Accumulate local dq
                    if i <= local_sp_rank:
                        dq += dq_  # (T, H, D)
                    else:
                        dq[half_idx_back] += dq_

                    # Wait for mobile kv grad accumulators
                    local_dkv_comm.wait()

                    if i <= local_sp_rank:
                        # q blocks "surrounded" by kv blocks
                        dkv_buffers[i % 2][0][half_idx_front] += dk_
                        dkv_buffers[i % 2][1][half_idx_front] += dv_
                    else:
                        # q blocks "surrounding" kv blocks
                        dkv_buffers[i % 2][0] += dk_
                        dkv_buffers[i % 2][1] += dv_

                local_dkv_comm.send_recv(send_tensor=dkv_buffers[i % 2], recv_tensor=dkv_buffers[(i + 1) % 2])

            local_dkv_comm.wait()
            dkv_recv = dkv_buffers[local_sp_size % 2]
            dkv_send = dkv_buffers[(local_sp_size - 1) % 2]
            return dq, dkv_recv, dkv_send

        def _other_ring_backward(ring_num_idx, dq):
            if ring_num_idx > inter_ring_rank:
                # Indexing is expensive
                q_, out_, dout_ = [x[half_idx_back] for x in (q, out, dout)]
            else:
                q_, out_, dout_ = (q, out, dout)
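
            # Descriptive sketch (mirrors the forward pass): when the KV received in this outer step
            # comes from a ring "ahead" of ours (ring_num_idx > inter_ring_rank), only the second half
            # of the zigzag-split queries attends to it, so the halved q_/out_/dout_ above are paired
            # with the full KV; otherwise the full q attends to the front half of the received KV,
            # exactly as in _local_ring_backward.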
            for i in range(local_sp_size):
                if i > 0:
                    local_kv_comm.wait()

                if i < local_sp_size - 1:
                    local_kv_comm.send_recv(kv_buffers[i % 2], kv_buffers[(i + 1) % 2])

                rng_state = rng_states[i + local_sp_size * ring_num_idx]
                if ring_num_idx > inter_ring_rank:
                    k_, v_ = kv_buffers[i % 2]
                    dk_, dv_ = dk_block, dv_block
                    dq_ = dq_block[: t // 2]
                    _backward(dout_, q_, k_, v_, out_, softmax_lse1, dq_, dk_, dv_, rng_state, causal=False)

                    dq[half_idx_back] += dq_
                    if i > 0:
                        local_dkv_comm.wait()
                    else:
                        inter_dkv_comm.wait()

                    dkv_buffers[i % 2][0] += dk_
                    dkv_buffers[i % 2][1] += dv_
                else:
                    k_, v_ = [x[half_idx_front] for x in kv_buffers[i % 2]]
                    dk_, dv_ = [x[: t // 2] for x in (dk_block, dv_block)]
                    dq_ = dq_block
                    _backward(dout_, q_, k_, v_, out_, softmax_lse, dq_, dk_, dv_, rng_state, causal=False)

                    dq += dq_
                    if i > 0:
                        local_dkv_comm.wait()
                    else:
                        inter_dkv_comm.wait()

                    dkv_buffers[i % 2][0][half_idx_front] += dk_
                    dkv_buffers[i % 2][1][half_idx_front] += dv_

                local_dkv_comm.send_recv(send_tensor=dkv_buffers[i % 2], recv_tensor=dkv_buffers[(i + 1) % 2])

            local_dkv_comm.wait()
            dkv_recv = dkv_buffers[local_sp_size % 2]
            dkv_send = dkv_buffers[(local_sp_size - 1) % 2]
            return dq, dkv_recv, dkv_send

        inter_ring_kv = None
        for ring_num_idx in range(num_rings):
            if ring_num_idx > 0:
                inter_kv_comm.wait()
                kv_buffers[0] = inter_ring_kv

            if ring_num_idx < num_rings - 1:
                # Re-allocate a buffer in each inter-ring step
                inter_ring_kv = inter_kv_comm.send_recv(kv_buffers[0])

            if ring_num_idx == 0:
                dq, dkv_recv, dkv_send = _local_ring_backward()
            else:
                dq, dkv_recv, dkv_send = _other_ring_backward(ring_num_idx, dq)

            if num_rings > 1:
                # Reuse the local buffers
                inter_dkv_comm.send_recv(send_tensor=dkv_recv, recv_tensor=dkv_send)
                # Reset indices
                dkv_buffers[0] = dkv_send
                dkv_buffers[1] = dkv_recv
                if ring_num_idx == num_rings - 1:
                    inter_dkv_comm.wait()
                    dkv_recv = dkv_buffers[0]

        dq, dk, dv = [x.to(q.dtype) for x in (dq, *dkv_recv)]
        if not is_packed:
            dq, dk, dv = [x.view(b, sq, *x.shape[-2:]) for x in (dq, dk, dv)]

        return (dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None)

    @staticmethod
    def prepare_varlen_batch(
        attention_mask: torch.Tensor,
        sp_group: dist.ProcessGroup,
        inputs_embeds: torch.Tensor = None,
        position_ids: Optional[torch.Tensor] = None,
        is_label: bool = False,
        is_2d: bool = True,
    ):
        """
        Preprocess a batch of padded sequences by splitting the input sequence by sp_size sequence-wise
        and packing them into one sequence. Updates the mask info accordingly.

        Args:
            attention_mask (torch.Tensor): Contains the mask [B, Sq], where True means the token is NOT masked.
            sp_group (dist.ProcessGroup): Process group for sequence parallelism
            inputs_embeds (torch.Tensor): Input embeddings. Shape should be [B, Sq, ...]
            position_ids (Optional[torch.Tensor], optional): Position ids of shape [Sq] or [1, Sq]. Defaults to None.
            is_label (bool, optional): Whether inputs_embeds is instead a label tensor. If True, mask out the first
                token of each sequence.
            is_2d (bool, optional): Whether to return 2D outputs padded to max_seqlen // sp_size or flatten
                the batch dim to a packed 1d sequence. Contingent on model forward shape definitions.

        Returns:
            torch.Tensor: Packed input embeddings of shape [B, Sq // sp_size, ...].
            Dict[str, Any]: A dictionary containing mask info.
            torch.Tensor: Packed position ids of shape [..., Sq // sp_size].
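
        Example (illustrative; assumes sp_size=2 and one sequence of 8 valid tokens):
            The sequence is cut into sp_size * 2 = 4 chunks [0 1], [2 3], [4 5], [6 7]. Following the
            zigzag scheme (chunk indices [sp_rank, 2 * sp_size - 1 - sp_rank]), rank 0 keeps chunks 0 and 3
            (tokens 0 1 6 7) and rank 1 keeps chunks 1 and 2 (tokens 2 3 4 5), so both ranks hold an equal
            share of "early" and "late" positions of the causal mask.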
""" _load_varlen_helpers() sp_size = dist.get_world_size(group=sp_group) sp_rank = dist.get_rank(group=sp_group) mask_info = {} mask_info["max_seqlen"], mask_info["cu_seqlens"] = get_pad_info(attention_mask, return_indices=False) # Unpad, split seq-wise, then pad back to (B, max_seqlen // sp_size) # Split mask to compute local nonzero position indices # (B, Sq) -> (B, max_seqlen // sp_size) attention_mask = attention_mask[:, : mask_info["max_seqlen"]] if inputs_embeds is not None: inputs_embeds = inputs_embeds[:, : mask_info["max_seqlen"]] inputs_embeds = split_varlen_zigzag( inputs_embeds, mask_info["cu_seqlens"], sp_group, mask_info["max_seqlen"], is_2d=is_2d, is_label=is_label, ) attention_mask = split_varlen_zigzag( attention_mask, mask_info["cu_seqlens"], sp_group, mask_info["max_seqlen"], is_2d=is_2d ) if position_ids is not None: indices = torch.tensor([sp_rank, 2 * sp_size - sp_rank - 1], device=inputs_embeds.device) position_ids = ( position_ids[..., : mask_info["max_seqlen"]] # unpad .view(-1, sp_size * 2, mask_info["max_seqlen"] // (sp_size * 2)) .index_select(-2, indices) .view(-1, mask_info["max_seqlen"] // sp_size) ) mask_info["max_seqlen"] //= sp_size mask_info["valid_indices"] = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() mask_info["cu_seqlens"] //= sp_size mask_info["attention_mask_type"] = AttnMaskType.PADDED_CAUSAL return inputs_embeds, mask_info, position_ids