diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py
index bbd99d162..3202ebf25 100644
--- a/colossalai/shardformer/layer/attn.py
+++ b/colossalai/shardformer/layer/attn.py
@@ -422,13 +422,18 @@ class RingAttention(torch.autograd.Function):
     ATTN_DONE: torch.cuda.Event = None
     SP_STREAM: torch.cuda.Stream = None
     SP_GROUP: dist.ProcessGroup = None
-    # duplicate process group for concurrent NCCL streams
-    # while both PyTorch and NCCL warns(https://github.com/pytorch/pytorch/commit/2dbe5cb979f674f0052a8eea1f7b6c3c0ba441d7)
-    # against this, in practice it seems to work fine.
+
+    # NOTE: Duplicating PGs for concurrent NCCL streams is a risky hack -- while it may increase throughput,
+    # both PyTorch and NCCL warn against it (https://github.com/pytorch/pytorch/commit/2dbe5cb979f674f0052a8eea1f7b6c3c0ba441d7).
+    # LoongTrain's original double ring impl. uses concurrent PGs
+    # (https://github.com/InternLM/InternEvo/blob/e52f2ffc9acf818e8f2b1f97dfc69ceb2f06e154/internlm/model/ops/ring_flash_attn/zigzag_ring_flash_attn_with_sliding_window.py#L192),
+    # but I confirmed with PyTorch developers that this can cause obscure "Software caused connection abort" errors
+    # (https://github.com/pytorch/pytorch/issues/132852).
+    # NOTE: In general, a smarter idea is to put as many P2P calls as possible into one `batch_isend_irecv`.
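+    # A rough sketch of that batched-P2P idea (illustrative only; `kv_send`, `kv_recv`, `send_rank`,
+    # `recv_rank` and `pg` are placeholder names, not identifiers used in this file):
+    #   ops = [dist.P2POp(dist.isend, kv_send, send_rank, group=pg),
+    #          dist.P2POp(dist.irecv, kv_recv, recv_rank, group=pg)]
+    #   for req in dist.batch_isend_irecv(ops):
+    #       req.wait()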
     INNER_RING_GROUP: dist.ProcessGroup = None
-    INNER_RING_GROUP_COPY: dist.ProcessGroup = None
+    # INNER_RING_GROUP_COPY: dist.ProcessGroup = None
     INTER_RING_GROUP: dist.ProcessGroup = None
-    INTER_RING_GROUP_COPY: dist.ProcessGroup = None
+    # INTER_RING_GROUP_COPY: dist.ProcessGroup = None
 
     @staticmethod
     def get_double_ring_groups(sp_axis, pg_mesh, inner_ring_size=None):
@@ -626,7 +631,13 @@ class RingAttention(torch.autograd.Function):
         inner_ring_group: Optional[dist.ProcessGroup] = None,
         inter_ring_group: Optional[dist.ProcessGroup] = None,
     ):
-
+        """
+        Forward supporting both packed (varlen) and batched (fixed-length, no padding) sequences.
+        There is no separate version for batched sequences (hard to maintain), which incurs
+        some overhead in sequence splitting due to Python for loops.
+        Uses two CUDA streams to overlap the softmax denominator correction with the next flash attn kernel
+        (see comments below).
+        """
         cu_seqlens_q = cu_seqlens_kv = cu_seqlens
         max_seqlen_q = max_seqlen_kv = max_seqlen
         cu_seqlens_half = cu_seqlens // 2
@@ -668,7 +679,8 @@ class RingAttention(torch.autograd.Function):
         sp_size = dist.get_world_size(sp_group)
         sp_rank = dist.get_rank(sp_group)
-        # Attempt to achieve concurrent comm in the two-stream forward
+
+        # Create communicators corresponding to the two CUDA streams
         local_kv_comms = [RingComm(inner_ring_group) for _ in range(2)]
         inter_ring_comm = RingComm(inter_ring_group)
         local_sp_size = dist.get_world_size(inner_ring_group)
@@ -676,7 +688,7 @@ class RingAttention(torch.autograd.Function):
         inter_ring_rank = dist.get_rank(inter_ring_group) if inter_ring_group is not sp_group else 0
         num_rings = dist.get_world_size(inter_ring_group) if inter_ring_group is not sp_group else 1
 
-        # Non-contiguous indexing copies to a new contiguous tensor,
+        # Any type of indexing (but not slicing) copies to a new contiguous tensor,
         # so only do it once
         if sp_rank != sp_size - 1:
             q1 = q[half_idx_back]
@@ -693,6 +705,7 @@ class RingAttention(torch.autograd.Function):
         rng_states = [None for _ in range(sp_size)]
         sp_streams = [torch.cuda.current_stream(), sp_stream]
 
+        # Helper to pass args to FA
         def _forward(q, k, v, causal):
             (
                 _,
@@ -723,6 +736,7 @@ class RingAttention(torch.autograd.Function):
             if i < local_sp_size - 1:
                 local_kv_comms[i % 2].send_recv(kv_buffers[i % 2], kv_buffers[(i + 1) % 2])
 
+        # Forward within a node
         def _local_ring_forward():
             # (Hopefully) overlap output correction with next flash attn
             for i in range(local_sp_size):
@@ -731,6 +745,8 @@ class RingAttention(torch.autograd.Function):
                     # NOTE: waiting outside the current stream will NOT correctly synchronize.
                     if i > 0:
                         local_kv_comms[(i + 1) % 2].wait()
+
+                    # Prefetch
                     if i == 0:
                         _kv_comm(i)
@@ -764,15 +780,22 @@ class RingAttention(torch.autograd.Function):
                     ) = _forward(q_block, kv_block[0], kv_block[1], causal=False)
                     RingAttention.ATTN_DONE.record()
                     # Pipeline the next KV comm with output correction instead of the next flash attn
-                    # to minimize idle time when comm takes longer than attn.
+                    # kernel, to minimize the bubble when comm takes longer than attn.
                     _kv_comm(i + 1)
                     block_softmax_lse[i % 2] = (
                         block_softmax_lse[i % 2].transpose(0, 1).unsqueeze(-1).contiguous().float()
                     )  # (H, T) -> (T, H, 1)
                     assert block_out[i % 2].shape[:-1] == block_softmax_lse[i % 2].shape[:-1]
-                    # Output and log sum exp correction. Ideally overlap this with the next flash attn kernel.
-                    # In reality this always finishes before next flash attn; no need for extra sync.
+
+                    # Output and log sum exp correction.
+                    # Ideally overlap this with the next flash attn kernel,
+                    # since attn uses Tensor Cores while the rescale is element-wise, memory-bound and runs on CUDA cores.
+                    # (NOTE: this is the same idea as ping-pong scheduling in FA3.)
+                    # TODO: however, sometimes even though the GPU has scheduled the next kernel,
+                    # it is reluctant to launch it in overlap. Some potential causes:
+                    # 1. need lower-level CUDA scheduling; 2. further benchmark against Megatron-LM;
+                    # 3. register spilling by the FA kernel.
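+                    # For reference, the rescale performed here is roughly the standard log-sum-exp merge
+                    # (a sketch with shortened names, not the exact code path):
+                    #   lse_new = log(exp(lse) + exp(block_lse))
+                    #   out_new = exp(lse - lse_new) * out + exp(block_lse - lse_new) * block_out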
                     if i == 0:
                         out = block_out[0]
                         softmax_lse = block_softmax_lse[0]
@@ -788,15 +811,17 @@ class RingAttention(torch.autograd.Function):
             torch.cuda.current_stream().wait_stream(sp_stream)
             return out, softmax_lse
 
+        # Forward for inter-node (the outer ring in 2D ring)
         def _other_ring_forward(ring_num_idx, out, softmax_lse):
             # Loop through the inner ring after receiving
-            # all new KVs from the previous inner ring
+            # all new KVs from another ring
             for i in range(local_sp_size):
                 with torch.cuda.stream(sp_streams[i % 2]):
                     # Send & recv KV
                     if i > 0:
                         local_kv_comms[(i + 1) % 2].wait()
+
+                    # Prefetch
                     if i == 0:
                         _kv_comm(i)
@@ -893,7 +918,8 @@ def backward(ctx, dout, _):
         """
         During backward, we accumulate q grads on each rank locally, but iterate kv and their grads
-        over all ranks for accumulation.
+        over all ranks for accumulation. We avoid using two streams here, since that would double
+        the buffers while backward is already more communication-intensive.
         """
         (q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_kv, half_idx_front, half_idx_back) = ctx.saved_tensors[:9]
         rng_states = ctx.saved_tensors[9:]
@@ -925,7 +951,7 @@ class RingAttention(torch.autograd.Function):
         local_sp_rank = dist.get_rank(sp_group)
         sp_size = dist.get_world_size(sp_group)
 
-        # Using separate streams (pg) for concurrent kv and dkv comm may
+        # NOTE: Using separate streams (PGs) for concurrent kv and dkv comm may
         # cause NCCL "software caused connection abort" here...
         local_kv_comm = RingComm(local_kv_group)
         local_dkv_comm = RingComm(local_kv_group)
@@ -957,6 +983,7 @@ class RingAttention(torch.autograd.Function):
         dkv_buffers = [torch.empty_like(kv, dtype=torch.float32) for kv in kv_buffers]  # (T, H, D)
         del k, v
 
+        # Helper to pass args to FA
        def _backward(dout, q, k, v, out, softmax_lse, dq, dk, dv, rng_state, causal):
             _flash_attn_backward(
                 dout,
@@ -977,8 +1004,7 @@ class RingAttention(torch.autograd.Function):
                 **misc_kwargs,
             )
 
-        # NOTE: We avoid using two streams due to doubled buffers
-        # and that backward is more communication intensive.
+        # Backward within a node
         def _local_ring_backward():
             for i in range(local_sp_size):
                 if i > 0:
@@ -1041,6 +1067,7 @@ class RingAttention(torch.autograd.Function):
             dkv_send = dkv_buffers[(local_sp_size - 1) % 2]
             return dq, dkv_recv, dkv_send
 
+        # Backward for inter-node (the outer ring in 2D ring)
         def _other_ring_backward(ring_num_idx, dq):
             if ring_num_idx > inter_ring_rank:
                 # Indexing is expensive
@@ -1125,34 +1152,34 @@ class RingAttention(torch.autograd.Function):
 
     @staticmethod
     def prepare_varlen_batch(
-        attention_mask: torch.Tensor,
+        padding_mask: torch.Tensor,
         sp_group: dist.ProcessGroup,
         inputs_embeds: torch.Tensor = None,
         position_ids: Optional[torch.Tensor] = None,
         is_label: bool = False,
-        is_2d: bool = True,
+        is_batched_seq: bool = True,
     ):
+        # TODO: support setting a batch dim (i.e. a fixed packing length) for the packed mode, so that
+        # DP can be used (requires modifying the dataloader too)
         """
         Preprocess a batch of padded sequence by splitting input sequence by sp_size
-        sequence-wise and packing them into one sequence. Updates the mask info accordingly.
+        seq-wise and packing them into one sequence. Updates the mask info accordingly.
         Args:
-            attention_mask (torch.Tensor): Contains the mask [B, Sq], where True means the token is NOT masked.
+            padding_mask (torch.Tensor): Contains the mask [B, Sq], where True means the token is NOT masked.
             sp_group (dist.ProcessGroup): Process group for sequence parallelism
             inputs_embeds (torch.Tensor): Input embeddings. Shape should be [B, Sq, ...]
             position_ids (Optional[torch.Tensor], optional): Position ids of shape [Sq] or [1, Sq]. Defaults to None.
             is_label (bool, optional): Whether inputs_embeds is instead a label tensor. If True, mask out the first
                 token of each sequence.
-            is_2d (bool, optional): Whether to return 2D outputs padded to max_seqlen // sp_size or flatten
-                the batch dim to a packed 1d sequence. Contingent on model forward shape definitions.
+            is_batched_seq (bool, optional): If True, then the input is a batch of (potentially padded) sequences
+                of shape [B, Sq, ...]; else a packed sequence of shape [T, ...].
         Returns:
-            torch.Tensor:
-                Packed input embeddings of shape [B, Sq // sp_size, ...].
-
-            Dict[str, Any]:
+            inputs_embeds (torch.Tensor):
+                Packed input embeddings of shape [B, Sq // sp_size, ...] if is_batched_seq, else [T, ...].
+            mask_info (Dict[str, Any]):
                 A dictionary containing mask info.
-
-            torch.Tensor:
+            position_ids (torch.Tensor):
                 Packed position ids of shape [..., Sq // sp_size].
@@ -1160,12 +1187,11 @@ class RingAttention(torch.autograd.Function):
         sp_size = dist.get_world_size(group=sp_group)
         sp_rank = dist.get_rank(group=sp_group)
         mask_info = {}
-        mask_info["max_seqlen"], mask_info["cu_seqlens"] = get_pad_info(attention_mask, return_indices=False)
+        mask_info["max_seqlen"], mask_info["cu_seqlens"] = get_pad_info(padding_mask, return_indices=False)
 
-        # Unpad, split seq-wise, then pad back to (B, max_seqlen // sp_size)
-        # Split mask to compute local nonzero position indices
+        # Unpad, split seq-wise, then pad to (B, max_seqlen // sp_size)
         # (B, Sq) -> (B, max_seqlen // sp_size)
-        attention_mask = attention_mask[:, : mask_info["max_seqlen"]]
+        padding_mask = padding_mask[:, : mask_info["max_seqlen"]]
         if inputs_embeds is not None:
             inputs_embeds = inputs_embeds[:, : mask_info["max_seqlen"]]
             inputs_embeds = split_varlen_zigzag(
@@ -1173,11 +1199,12 @@ class RingAttention(torch.autograd.Function):
                 mask_info["cu_seqlens"],
                 sp_group,
                 mask_info["max_seqlen"],
-                is_2d=is_2d,
+                is_batched_seq=is_batched_seq,
                 is_label=is_label,
             )
-        attention_mask = split_varlen_zigzag(
-            attention_mask, mask_info["cu_seqlens"], sp_group, mask_info["max_seqlen"], is_2d=is_2d
+        # Split mask to get local nonzero seq positions
+        padding_mask = split_varlen_zigzag(
+            padding_mask, mask_info["cu_seqlens"], sp_group, mask_info["max_seqlen"], is_batched_seq=is_batched_seq
         )
 
         if position_ids is not None:
@@ -1190,7 +1217,7 @@ class RingAttention(torch.autograd.Function):
             )
 
         mask_info["max_seqlen"] //= sp_size
-        mask_info["valid_indices"] = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
+        mask_info["valid_indices"] = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten()
         mask_info["cu_seqlens"] //= sp_size
         mask_info["attention_mask_type"] = AttnMaskType.PADDED_CAUSAL
         return inputs_embeds, mask_info, position_ids
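For intuition, the mask bookkeeping above boils down to something like the following standalone sketch (illustrative only, not ColossalAI's exact `get_pad_info`; `pad_info_sketch` and its variable names are made up here):

import torch
import torch.nn.functional as F


def pad_info_sketch(padding_mask: torch.Tensor):
    # padding_mask: (B, Sq) bool, True = real token, False = padding
    seqlens = padding_mask.sum(dim=-1, dtype=torch.int32)  # (B,)
    max_seqlen = int(seqlens.max())
    # cu_seqlens: (B + 1,) prefix sums, the varlen layout flash attn expects
    cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
    # flat indices of non-pad tokens, used to gather the packed sequence
    valid_indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten()
    return max_seqlen, cu_seqlens, valid_indices


# e.g. two sequences of lengths 3 and 5 -> max_seqlen = 5, cu_seqlens = [0, 3, 8]
mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]], dtype=torch.bool)
print(pad_info_sketch(mask))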
diff --git a/colossalai/shardformer/layer/utils.py b/colossalai/shardformer/layer/utils.py
index 4512e0c68..2df68e18c 100644
--- a/colossalai/shardformer/layer/utils.py
+++ b/colossalai/shardformer/layer/utils.py
@@ -295,8 +295,8 @@ def split_batch_zigzag(
     batch: Union[torch.Tensor, List[torch.Tensor]], sp_group: ProcessGroup, seq_dim: int = 1, is_label: bool = False
 ) -> Union[torch.Tensor, List[torch.Tensor]]:
     """
-    Split the input along the sequence dimension for Ring Attention. Naively spliting the attention mask
-    in the causal setting will result in the preceding ranks having much less workload.
+    Split the input batch of sequences along the sequence dimension for Ring Attention. Naively splitting
+    the attention mask in the causal setting will result in the preceding ranks having much less workload.
     We split after "folding" the 2D attention mask in half (https://github.com/zhuzilin/ring-flash-attention/issues/2).
     For example, for sp_size = 4 and seq_len = 8, we get | s0, s7 | s1, s6 | s2, s5 | s3, s4 |.
@@ -346,40 +346,42 @@ def split_varlen_zigzag(
     cu_seqlens: torch.Tensor,
     sp_group: ProcessGroup,
     max_seqlen: int = 0,
-    is_2d: bool = False,
+    is_batched_seq: bool = False,
     is_label: bool = False,
 ) -> Union[List[torch.Tensor], torch.Tensor]:
-    """Split each sequence in a batch of packed sequences in a zigzag fashion.
-    For each tensor in batch, return packed sequences if is_2d is False;
-    else return a padded batch of sequences.
-
+    """Split a packed sequence (or a batch of padded sequences) in a zigzag fashion.
+    Unlike split_batch_zigzag, the inputs here have variable sequence lengths.
     Args:
-        batch (List[torch.Tensor]): Packed sequences of shape (B * Sq, ...), or (B, Sq, ...) if is_2d.
+        batch (List[torch.Tensor]): Packed sequences of shape (T, ...), or (B, Sq, ...) if is_batched_seq,
+            where T is the total number of tokens.
         cu_seqlens (torch.Tensor): Cumulative sequence lengths of shape (B + 1) before splitting.
         sp_group (ProcessGroup): The process group for sequence parallelism.
         max_seqlen (int): The maximum sequence length in the batch before splitting.
-        is_2d (bool): If True, then input has batch size and sequence length split into two dimensions.
+        is_batched_seq (bool): If True, then the input is a batch of sequences padded to the same length.
         is_label (bool): If True, mask out the first token in each sequence ().

     Returns:
-        batch (List[torch.Tensor]): Packed sequences of shape (B * max_seqlen // sp_size)
-            or (B, max_seqlen // sp_size, ...) if is_2d
+        batch (List[torch.Tensor]): Packed sequences of shape (T, ...)
+            or (B, max_seqlen // sp_size, ...) if is_batched_seq
     """
     sp_size = dist.get_world_size(sp_group)
     sp_rank = dist.get_rank(sp_group)
     if sp_size == 1:
         return batch
 
-    if is_2d:
+    if is_batched_seq:
         assert max_seqlen > 0, "max_seqlen must be provided for 2D input"
 
     if isinstance(batch, torch.Tensor):
         batch = [batch]
+
+    # seq: (B, Sq, h, n)
+    # seq = seq[:, :rank * (seqlen // sp_size), ...]
+
     for i, packed_seq in enumerate(batch):
         device = packed_seq.device
         dtype = packed_seq.dtype
 
-        if is_2d:
+        if is_batched_seq:
             assert max_seqlen % (sp_size * 2) == 0
             # Recreate a padded tensor with the new max seqlen
             shape = (packed_seq.shape[0], max_seqlen // sp_size, *packed_seq.shape[2:])
@@ -398,7 +400,7 @@ def split_varlen_zigzag(
             seqlen % (2 * sp_size) == 0
         ), f"batch {i} seq {j}'s length ({seqlen}) must be divisible by 2 * sp_size = {2 * sp_size} for splitting"
 
-        if is_2d:
+        if is_batched_seq:
             seq = packed_seq[j][:seqlen]
             if is_label:
                 # Shift one position to the right for next token prediction
@@ -415,7 +417,7 @@ def split_varlen_zigzag(
             seq = seq.chunk(sp_size * 2)
             local_seq.extend([seq[sp_rank], seq[2 * sp_size - 1 - sp_rank]])
 
-        if is_2d:
+        if is_batched_seq:
             batch[i] = local_seq.contiguous()
         else:
             batch[i] = torch.cat(local_seq, dim=0)
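As a sanity check of the zigzag layout described in the docstrings above (| s0, s7 | s1, s6 | s2, s5 | s3, s4 | for sp_size = 4, seq_len = 8), here is a minimal single-process sketch; unlike the functions in this file it takes plain ints instead of a process group, and `zigzag_split_sketch` is a made-up name:

import torch


def zigzag_split_sketch(seq: torch.Tensor, sp_size: int, sp_rank: int, seq_dim: int = 1) -> torch.Tensor:
    # Fold the sequence into 2 * sp_size chunks and give rank r chunks r and (2 * sp_size - 1 - r),
    # so each rank gets one "early" and one "late" chunk and the causal workload is balanced.
    chunks = seq.chunk(2 * sp_size, dim=seq_dim)
    return torch.cat([chunks[sp_rank], chunks[2 * sp_size - 1 - sp_rank]], dim=seq_dim)


seq = torch.arange(8).reshape(1, 8)  # one sequence of tokens s0..s7
print([zigzag_split_sketch(seq, sp_size=4, sp_rank=r).tolist() for r in range(4)])
# -> [[[0, 7]], [[1, 6]], [[2, 5]], [[3, 4]]]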