diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a4132a507..f7217a8f1 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -12,6 +12,7 @@ repos: hooks: - id: isort name: sort all imports (python) + args: ["--profile", "black"] # avoid conflict with black - repo: https://github.com/psf/black-pre-commit-mirror rev: 24.8.0 diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index e5acdb051..63427192f 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -32,7 +32,7 @@ from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackw from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.quantization import BnbQuantizationConfig, quantize_model from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer -from colossalai.shardformer.layer.utils import SeqParallelUtils +from colossalai.shardformer.layer.utils import SeqParallelUtils, is_share_sp_tp from colossalai.shardformer.policies.base_policy import Policy from colossalai.tensor.colo_parameter import ColoParameter from colossalai.tensor.d_tensor.api import is_distributed_tensor @@ -42,7 +42,7 @@ from colossalai.zero.low_level.zero_hook import ZeroOpHook, wait_all_gather_hand from .pp_plugin_base import PipelinePluginBase -SUPPORT_SP_MODE = ["split_gather", "ring", "all_to_all"] +SUPPORT_SP_MODE = ["split_gather", "ring", "all_to_all", "ring_attn"] PRECISION_TORCH_TYPE = {"fp16": torch.float16, "fp32": torch.float32, "bf16": torch.bfloat16} @@ -72,7 +72,7 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin): self.dp_group = dp_group self.tp_group = tp_group self.sp_group = sp_group - self.use_dpp = use_ddp + self.use_ddp = use_ddp self.require_grad_sync = True self.overlap_allgather = overlap_allgather @@ -139,8 +139,8 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin): # Disable automatic gradient synchronization. self.require_grad_sync = False try: - if self.use_dpp: - # If using data parallel processing (use_dpp), disable synchronization too. + if self.use_ddp: + # If using data parallel processing (use_ddp), disable synchronization too. with self.module.no_sync(): yield else: @@ -188,7 +188,7 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin): """ if self.shard_config.enable_sequence_parallelism: - if self.shard_config.sequence_parallelism_mode == "all_to_all": + if self.shard_config.sequence_parallelism_mode in ["all_to_all", "ring_attn"]: return if self.shard_config.sequence_parallelism_mode in ["split_gather", "ring"]: @@ -970,6 +970,9 @@ class HybridParallelPlugin(PipelinePluginBase): enable_metadata_cache (bool, optional): Whether to enable metadata cache for pipeline parallelism. Defaults to True. make_vocab_size_divisible_by (int, optional): it's used when padding the vocabulary size, to make it choose an faster kenel. Default to 64. overlap_p2p (bool, optional): Whether to overlap the p2p communication in pipeline parallelism + inner_ring_size (int, optional): The inner ring size of 2D Ring Attention when sp mode is "ring_attn". + It's advisable to not tune this (especially in single-node settings) and let it be heuristically set based on topology by default. + """ def __init__( @@ -1017,6 +1020,7 @@ class HybridParallelPlugin(PipelinePluginBase): dp_outside: bool = True, overlap_p2p: bool = True, overlap_allgather: bool = False, + inner_ring_size: int = None, ) -> None: super().__init__() @@ -1041,9 +1045,11 @@ class HybridParallelPlugin(PipelinePluginBase): ) self.sp_size = 1 self.dp_size = dist.get_world_size() // (tp_size * pp_size) - elif self.sequence_parallelism_mode in ["all_to_all"]: + elif self.sequence_parallelism_mode in ["all_to_all", "ring_attn"]: self.sp_size = 1 if sp_size is None else sp_size self.dp_size = dist.get_world_size() // (self.sp_size * pp_size * tp_size) + if self.sequence_parallelism_mode == "ring_attn": + enable_flash_attention = True else: self.dp_size = dist.get_world_size() // (tp_size * pp_size) assert ( @@ -1063,10 +1069,21 @@ class HybridParallelPlugin(PipelinePluginBase): self.enable_sequence_parallelism = enable_sequence_parallelism if dp_outside: self.dp_axis, self.pp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3 - self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size) + if sequence_parallelism_mode == "ring_attn": + # Swap tp and sp since 2D Ring has better inter-node latency + self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.sp_size, self.tp_size) + self.sp_axis = 2 + self.tp_axis = 3 + else: + self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size) else: self.pp_axis, self.dp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3 - self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size, self.sp_size) + if sequence_parallelism_mode == "ring_attn": + self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.sp_size, self.tp_size) + self.sp_axis = 2 + self.tp_axis = 3 + else: + self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size, self.sp_size) self.stage_manager = None self.schedule = None @@ -1108,6 +1125,8 @@ class HybridParallelPlugin(PipelinePluginBase): ) else: raise NotImplementedError() + if sequence_parallelism_mode == "ring_attn": + assert parallel_output, "Ring Attention doesn't support gathering output yet." self.tp_group = self.pg_mesh.get_group_along_axis(self.tp_axis) self.dp_group = self.pg_mesh.get_group_along_axis(self.dp_axis) @@ -1132,6 +1151,7 @@ class HybridParallelPlugin(PipelinePluginBase): parallel_output=parallel_output, make_vocab_size_divisible_by=make_vocab_size_divisible_by, gradient_checkpoint_config=gradient_checkpoint_config, + inner_ring_size=inner_ring_size, ) self.amp_config = dict( initial_scale=initial_scale, @@ -1216,15 +1236,15 @@ class HybridParallelPlugin(PipelinePluginBase): zero_stage = 0 if not isinstance(model, ModelWrapper): + # Shouldn't use pp (frequent grad accumulation) with torch ddp use_ddp = (self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0) or ( - self.dp_size == 1 - and self.pp_size == 1 - and self.enable_sequence_parallelism - and self.sequence_parallelism_mode == "all_to_all" + self.dp_size == 1 and self.pp_size == 1 ) - # sync gradients across DP * SP ranks - if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all": + + # Apply Hybrid ZeRO across DP * SP ranks + if self.enable_sequence_parallelism and not is_share_sp_tp(self.sequence_parallelism_mode): dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis]) + self.dp_size = get_world_size(dp_group) else: dp_group = self.dp_group model = HybridParallelModule( diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py index 0310df548..043e5c2b0 100644 --- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py +++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py @@ -203,7 +203,6 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): return Path(checkpoint).mkdir(parents=True, exist_ok=True) - # Devices along the same dp_group share the same copies of model. # So only let the device with dp_rank == 0 save the model. if self.dp_rank != 0: @@ -643,14 +642,12 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): assert isinstance(model, ModelWrapper), "Please boost the model before saving!" model._force_wait_all_gather() model = model.unwrap() - if self.dp_rank != 0: return # The logic of collecting parameter shards along tp degree # has been implemented by _save_to_state_dict method of ParallelModule in Shardformer. state_dict = model.state_dict() - if self.pp_size == 1: # When pipeline is not used, let master rank directly save the collected state_dict. if self.tp_rank == 0: @@ -660,7 +657,6 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): state_dict_list = [None for _ in range(self.pp_size)] dist.barrier(self.pp_group) dist.all_gather_object(state_dict_list, state_dict, self.pp_group) - # Only the master rank do the saving. if self.coordinator.is_master(): complete_state_dict = dict() diff --git a/colossalai/lazy/pretrained.py b/colossalai/lazy/pretrained.py index 736ffc5e4..226951598 100644 --- a/colossalai/lazy/pretrained.py +++ b/colossalai/lazy/pretrained.py @@ -62,7 +62,6 @@ def new_from_pretrained( config = kwargs.pop("config", None) cache_dir = kwargs.pop("cache_dir", None) force_download = kwargs.pop("force_download", False) - resume_download = kwargs.pop("resume_download", False) proxies = kwargs.pop("proxies", None) local_files_only = kwargs.pop("local_files_only", False) use_auth_token = kwargs.pop("use_auth_token", None) @@ -116,7 +115,6 @@ def new_from_pretrained( cache_dir=cache_dir, return_unused_kwargs=True, force_download=force_download, - resume_download=resume_download, proxies=proxies, local_files_only=local_files_only, use_auth_token=use_auth_token, @@ -195,7 +193,6 @@ def new_from_pretrained( "cache_dir": cache_dir, "force_download": force_download, "proxies": proxies, - "resume_download": resume_download, "local_files_only": local_files_only, "use_auth_token": use_auth_token, "user_agent": user_agent, @@ -312,7 +309,6 @@ def new_from_pretrained( pretrained_model_name_or_path, cache_dir=cache_dir, force_download=force_download, - resume_download=resume_download, proxies=proxies, local_files_only=local_files_only, use_auth_token=use_auth_token, diff --git a/colossalai/legacy/moe/openmoe/model/openmoe_policy.py b/colossalai/legacy/moe/openmoe/model/openmoe_policy.py index ccd566b08..d5824afcb 100644 --- a/colossalai/legacy/moe/openmoe/model/openmoe_policy.py +++ b/colossalai/legacy/moe/openmoe/model/openmoe_policy.py @@ -171,7 +171,7 @@ class OpenMoeForCausalLMPolicy(OpenMoePolicy): policy = super().module_policy() if self.shard_config.enable_tensor_parallelism: - # add a new item for casual lm + # add a new item for causal lm # TODO: recursively assign ep group foe all modules new_item = { OpenMoeForCausalLM: ModulePolicyDescription( diff --git a/colossalai/legacy/nn/layer/parallel_1d/_operation.py b/colossalai/legacy/nn/layer/parallel_1d/_operation.py index 8b8f04ccf..e892336bc 100644 --- a/colossalai/legacy/nn/layer/parallel_1d/_operation.py +++ b/colossalai/legacy/nn/layer/parallel_1d/_operation.py @@ -81,6 +81,9 @@ class LinearWithAsyncCommunication(torch.autograd.Function): handle = dist.all_reduce(grad_input, group=gpc.get_group(ctx.parallel_mode), async_op=True) # Delay the start of weight gradient computation shortly (3us) to have # all-reduce scheduled first and have GPU resources allocated + # TODO: This seems to only work if you add torch.cuda.Event.wait() + + # _ = torch.zeros(1, device=grad_output.device) grad_weight = grad_output.t().matmul(total_input) grad_bias = grad_output.sum(dim=0) if use_bias else None diff --git a/colossalai/logging/logger.py b/colossalai/logging/logger.py index eb5f28e2a..9f4b7a7b0 100644 --- a/colossalai/logging/logger.py +++ b/colossalai/logging/logger.py @@ -64,7 +64,10 @@ class DistributedLogger: self._logger.propagate = False DistributedLogger.__instances[name] = self - self.rank = dist.get_rank() if dist.is_initialized() else 0 + + @property + def rank(self): + return dist.get_rank() if dist.is_initialized() else 0 @staticmethod def __get_call_info(): diff --git a/colossalai/pipeline/schedule/interleaved_pp.py b/colossalai/pipeline/schedule/interleaved_pp.py index a21b45c44..412f3896f 100644 --- a/colossalai/pipeline/schedule/interleaved_pp.py +++ b/colossalai/pipeline/schedule/interleaved_pp.py @@ -286,7 +286,6 @@ class InterleavedSchedule(PipelineSchedule): # for the first stage, input_obj is None # for other stages, input_obj is the output of the previous stage containing hidden_states etc. # Only attention_mask from micro_batch is used - with self.stage_manager.switch_model_chunk_id(model_chunk_id): if isinstance(model_chunk, ModuleList): output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, input_obj) diff --git a/colossalai/pipeline/schedule/one_f_one_b.py b/colossalai/pipeline/schedule/one_f_one_b.py index 7f0d0e349..03df67ae7 100644 --- a/colossalai/pipeline/schedule/one_f_one_b.py +++ b/colossalai/pipeline/schedule/one_f_one_b.py @@ -244,6 +244,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): output_obj = model_forward(model, micro_batch, input_obj) if self.stage_manager.is_last_stage(): loss = criterion(output_obj, micro_batch) / self.num_microbatches + if accum_loss is not None: accum_loss.add_(loss.detach()) if outputs is not None: diff --git a/colossalai/shardformer/layer/__init__.py b/colossalai/shardformer/layer/__init__.py index 331e49729..8882a33c1 100644 --- a/colossalai/shardformer/layer/__init__.py +++ b/colossalai/shardformer/layer/__init__.py @@ -1,5 +1,5 @@ from ._operation import all_to_all_comm -from .attn import AttnMaskType, ColoAttention +from .attn import AttnMaskType, ColoAttention, RingAttention, get_pad_info from .dropout import DropoutForParallelInput, DropoutForReplicatedInput from .embedding import Embedding1D, PaddingEmbedding, VocabParallelEmbedding1D from .linear import Linear1D_Col, Linear1D_Row, PaddingLMHead, VocabParallelLMHead1D @@ -31,5 +31,7 @@ __all__ = [ "VocabParallelLMHead1D", "AttnMaskType", "ColoAttention", + "RingAttention", + "get_pad_info", "all_to_all_comm", ] diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index 19da348e7..25983e0a9 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -2,6 +2,8 @@ import torch import torch.distributed as dist import torch.nn.functional as F +from .utils import is_share_sp_tp + try: import fused_mix_prec_layer_norm_cuda except: @@ -93,7 +95,7 @@ class MatmulWithAsyncCommunication(torch.autograd.Function): if ctx.async_grad_allreduce: # Asynchronous all-reduce handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True) - # Relay on CUDA_DEVICE_MAX_CONNECTIONS=1 to have + # Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have # all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py grad_weight = total_input.t().matmul(grad_output) @@ -143,7 +145,9 @@ class LinearWithAsyncCommunication(torch.autograd.Function): if ctx.async_grad_allreduce: # Asynchronous all-reduce handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True) - # Relay on CUDA_DEVICE_MAX_CONNECTIONS=1 to have + _ = torch.zeros(1, device=grad_input.device) + + # Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have # all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py if _grad_accum_fusion_available and weight.grad is not None: @@ -331,7 +335,7 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function): input_.shape, dtype=input_parallel.dtype, device=input_parallel.device ).contiguous() handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True) - # Relay on CUDA_DEVICE_MAX_CONNECTIONS=1 to have + # Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have # all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py if _grad_accum_fusion_available and weight.grad is not None: @@ -646,8 +650,8 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function): input_.shape, dtype=input_parallel.dtype, device=input_parallel.device ).contiguous() handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True) - # Relay on CUDA_DEVICE_MAX_CONNECTIONS=1 to have - # all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py + # Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have + # all-reduce scheduled first and have GPU resources allocated grad_weight = total_input.t().matmul(grad_output) grad_bias = grad_output.sum(dim=0) if use_bias else None @@ -721,16 +725,20 @@ class _ReduceForward(torch.autograd.Function): Args: input_: input matrix. - parallel_mode: parallel mode. + process_group: communication group. + """ @staticmethod - def forward(ctx, input_, process_group): + def forward(ctx, input_, process_group, grad_scale=None): + ctx.grad_scale = grad_scale return _reduce(input_, process_group) @staticmethod def backward(ctx, grad_output): - return grad_output, None + if ctx.grad_scale is not None: + grad_output = grad_output * ctx.grad_scale + return grad_output, None, None class _ReduceBackward(torch.autograd.Function): @@ -979,8 +987,8 @@ def split_forward_gather_backward(input_, dim, process_group, grad_scale=None): return _SplitForwardGatherBackward.apply(input_, dim, process_group, grad_scale) -def reduce_forward(input_, process_group): - return _ReduceForward.apply(input_, process_group) +def reduce_forward(input_, process_group, grad_scale=None): + return _ReduceForward.apply(input_, process_group, grad_scale) def reduce_backward(input_, process_group): @@ -989,3 +997,13 @@ def reduce_backward(input_, process_group): def all_to_all_comm(input_, process_group=None, scatter_dim=2, gather_dim=1): return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim) + + +def gather_sp_output(hidden_states, sp_group, sp_mode): + """ + Gather the output of the last layer for cross entropy computation + """ + # Rescale grad (HybridParallelPlugin applies ZeRO grad averaging on the DP * SP group) + scale = None if is_share_sp_tp(sp_mode) else dist.get_world_size(sp_group) + hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=scale) + return hidden_states diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py index 5872c6485..6dab17ec0 100644 --- a/colossalai/shardformer/layer/attn.py +++ b/colossalai/shardformer/layer/attn.py @@ -2,7 +2,10 @@ from enum import Enum from typing import Callable, Dict, Optional, Tuple import torch +import torch.distributed +import torch.distributed as dist import torch.nn.functional as F +from einops import rearrange from colossalai.kernel.kernel_loader import ( FlashAttentionForFloatAndCustomMaskLoader, @@ -10,12 +13,18 @@ from colossalai.kernel.kernel_loader import ( FlashAttentionWithCustomMaskLoader, KernelLoader, ) +from colossalai.logging import get_dist_logger + +from .utils import RingComm, get_half_index, split_varlen_zigzag __all__ = [ "AttnMaskType", "ColoAttention", ] +_flash_attn_forward = _flash_attn_backward = None +_unpad_input = _pad_input = None + class AttnMaskType(Enum): CUSTOM = 0 @@ -38,20 +47,32 @@ def invert_mask(mask: torch.Tensor) -> torch.Tensor: # adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py -def get_pad_info(padding_mask: torch.Tensor) -> Tuple[int, torch.Tensor, torch.Tensor]: +def get_pad_info( + padding_mask: torch.Tensor, invert: Optional[bool] = False, return_indices: Optional[bool] = True +) -> Tuple[int, torch.Tensor, torch.Tensor]: """Get padding information from padding mask. Args: - padding_mask (torch.Tensor): Padding mask tensor. Shape should be [B, S] + padding_mask (torch.Tensor): Padding mask tensor. Shape should be [B, Skv] + invert (Optional[bool], optional): Whether to reverse the padding mask. + return_indices (Optional[bool], optional): Whether to return the indices of non-masked tokens. Returns: - Tuple[int, torch.Tensor, torch.Tensor]: Tuple of (max_seq_len, cu_seqlens, indices) + max_seqlen_in_batch (int): Maximum sequence length in the batch. + cu_seqlens (torch.Tensor): Shape [B+1]. Cumulative sequence lengths of the sequences in the batch. + indices (torch.Tensor): Shape [total_nonzero]. The indices of non-masked tokens from the flattened input sequence. """ + if invert: + padding_mask = padding_mask.logical_not() seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32) - indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten() + if return_indices: + indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) - return max_seqlen_in_batch, cu_seqlens, indices + if return_indices: + return max_seqlen_in_batch, cu_seqlens, indices + return max_seqlen_in_batch, cu_seqlens class ColoAttention: @@ -107,6 +128,7 @@ class ColoAttention: q_padding_mask: Optional[torch.Tensor] = None, kv_padding_mask: Optional[torch.Tensor] = None, is_causal: bool = False, + invert: bool = True, ) -> Dict[str, torch.Tensor]: """Return a dictionary of keyword arguments for attention function. It supports 4 mask type. 1. custom mask: no padding mask and is_causal=False, return {}, users should handle attention mask by themselves. @@ -124,7 +146,7 @@ class ColoAttention: The shape should be [B, Skv]. ``1`` means valid token, and ``0`` means padding token. If it's None and ``q_padding_mask`` is not None, it will be set to ``q_padding_mask``. Defaults to None. is_causal (bool, optional): Whether to use causal attention mask. Defaults to False. - + invert_mask (bool, optional): Whether to invert the mask. Defaults to True. Returns: Dict[str, torch.Tensor]: Dictionary of keyword arguments for attention function. """ @@ -154,7 +176,7 @@ class ColoAttention: assert kv_padding_mask.shape == ( b, s_kv, - ), f"q_padding_mask shape {kv_padding_mask.shape} should be the same. ({shape_4d})" + ), f"Padding mask shape {kv_padding_mask.shape} should align with shape 4d ({b}, {s_kv})" attention_mask = kv_padding_mask[:, None, :].expand(b, s_q, s_kv).to(dtype=dtype, device=device) outputs.update( { @@ -172,7 +194,8 @@ class ColoAttention: attention_mask = attention_mask * attention_mask.new_ones(s_q, s_kv).tril(diagonal=0) else: outputs["attention_mask_type"] = AttnMaskType.PADDED - attention_mask = invert_mask(attention_mask).unsqueeze(1) + if invert: + attention_mask = invert_mask(attention_mask).unsqueeze(1) outputs["attention_mask"] = attention_mask return outputs @@ -191,6 +214,7 @@ class ColoAttention: kv_indices: Optional[torch.Tensor] = None, dropout_p: float = 0.0, scale: Optional[float] = None, + **kwargs, ) -> torch.Tensor: """Flash Attention function. It supports 4 mask type. 1. custom mask: recv attention_mask @@ -199,9 +223,9 @@ class ColoAttention: 4. padded causal mask: recv attention_mask, attention_mask_type, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, indices Args: - q (torch.Tensor): Query tensor. Shape should be [B, N, Sq, D] - k (torch.Tensor): Key tensor. Shape should be [B, N, Skv, D] - v (torch.Tensor): Value tensor. Shape should be [B, N, Skv, D] + q (torch.Tensor): Query tensor. Shape should be [B, nHeads, Sq, D] + k (torch.Tensor): Key tensor. Shape should be [B, nHeads, Skv, D] + v (torch.Tensor): Value tensor. Shape should be [B, nHeads, Skv, D] attention_mask (Optional[torch.Tensor], optional): Attention mask tensor. Shape should be [B, 1, Sq, Skv]. Defaults to None. attention_mask_type (AttnMaskType, optional): Attention mask type. Defaults to AttnMaskType.CUSTOM. cu_seqlens_q (Optional[torch.Tensor], optional): The cumulative sequence lengths @@ -218,7 +242,7 @@ class ColoAttention: scale (Optional[float], optional): Scaling factor applied prior to softmax. Defaults to None. Returns: - torch.Tensor: Output tensor. Shape should be [B, N, Sq, D] + torch.Tensor: Output tensor. Shape should be [B, nHeads, Sq, D] """ # known issue: sdpa does not support attention mask which contains whole row of masked tokens, which leads to nan # this case is usaul when padding mask is used and self attention is performed @@ -252,6 +276,7 @@ class ColoAttention: else: # if attention_mask is None, attention_mask_type should be the default value assert attention_mask_type == AttnMaskType.CUSTOM + # kernel dispatch mask_type = attention_mask_type if attention_mask is not None else None attn_func = ColoAttention._dispatch_kernel(q.dtype, mask_type) @@ -274,3 +299,858 @@ class ColoAttention: q_indices=q_indices, kv_indices=kv_indices, ) + + +def _load_varlen_helpers(): + """Helper to load functions for padding and unpadding packed sequences. + Use only when flash attn is installed + """ + global _pad_input, _unpad_input + # Flash attn claims this is more efficient than torch's bool indexing due to avoiding + # broadcast + if _pad_input is None or _unpad_input is None: + try: + from flash_attn.bert_padding import index_first_axis, pad_input + + def unpad_input(hidden_states: torch.Tensor, indices: torch.Tensor): + return index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices) + + _pad_input = pad_input + _unpad_input = unpad_input + except ImportError as e: + raise RuntimeError( + f"Flash Attention is not installed. You can install it via 'pip install flash-attn --no-build-isolation'" + ) from e + + +def _load_flash_attn(): + """A light-weight loader to check whether flash-attn is installed. + Can't use ColoAttention._dispatch_kernel because we mutate the backward pass + """ + global _flash_attn_forward, _flash_attn_backward + if _flash_attn_forward is None or _flash_attn_backward is None: + try: + from flash_attn.flash_attn_interface import _flash_attn_varlen_backward as _flash_attn_backward + from flash_attn.flash_attn_interface import _flash_attn_varlen_forward as _flash_attn_forward + except ImportError as e: + raise RuntimeError( + f"Flash Attention is not installed. You can install it via 'pip install flash-attn --no-build-isolation'" + ) from e + + _load_varlen_helpers() + + +# NOTE: This can cause spawned processes to hang on exit +# with python 3.9 +@torch.compile() +def _rescale_out_lse(out, block_out, lse, block_lse): + """ + Compute the new attention denominator: + exp(lse) + exp(block_lse) = exp(max_scale) * (exp(min_scale - max_scale) + 1) + Args: + out: (T, H, D) + block_out: (T, H, D) + lse: (H, T, 1) + block_lse: (H, T, 1) + """ + + # min_scale = torch.min(lse, block_lse) + # max_scale = torch.max(lse, block_lse) + # new_lse = max_scale + torch.log(1 + torch.exp(min_scale - max_scale)) + + # NOTE: directly assigning to .data here is buggy + # probably due to casting dtypes/strides + new_lse = lse + torch.log(1 + torch.exp(block_lse - lse)) + + new_block_lse = torch.exp(block_lse - new_lse) + out = (torch.exp(lse - new_lse) * out + new_block_lse * block_out).to(out) + lse = new_lse + + # Equivalent to the above + # See https://github.com/zhuzilin/ring-flash-attention/pull/34#issuecomment-2076126795 + # out = (out - F.sigmoid(block_lse - lse) * (out - block_out)) + # lse = (lse - F.logsigmoid(lse - block_lse)) + return out, lse + + +class RingAttention(torch.autograd.Function): + """Implements the Ring Attention from `Ring Attention with Blockwise Transformers for Near-Infinite Context` + (https://arxiv.org/abs/2310.01889). + For load-balancing we adopted the "zigzag" attention scheme from https://github.com/zhuzilin/ring-flash-attention/tree/main + For portable integration with more models, we don't follow the spirit of "block-wise FNN" in the original paper, + which requires fusing FFN with the Flash Attention kernel/function (see https://arxiv.org/pdf/2305.19370; + implemented in Jax and not optimized). + We adopt the double ring topology from LoongTrain (https://arxiv.org/pdf/2406.18485) to fully utilize available + NICs on each node, by computing attention within a inner ring first and then sending all KVs to the next + ring at once. + """ + + # Globle cache to avoid recomputation for same-lengthed sequences + CU_SEQLENS: torch.Tensor = None # [B+1] + TOTAL_SEQLEN: int = None + HALF_INDICES: Tuple = None + SUPPORTED_MASK_TYPES = (AttnMaskType.CAUSAL, AttnMaskType.PADDED_CAUSAL) + ATTN_DONE: torch.cuda.Event = None + SP_STREAM: torch.cuda.Stream = None + SP_GROUP: dist.ProcessGroup = None + # duplicate process group for concurrent NCCL streams + # while both PyTorch and NCCL warns(https://github.com/pytorch/pytorch/commit/2dbe5cb979f674f0052a8eea1f7b6c3c0ba441d7) + # against this, in practice it seems to work fine. + INNER_RING_GROUP: dist.ProcessGroup = None + INNER_RING_GROUP_COPY: dist.ProcessGroup = None + INTER_RING_GROUP: dist.ProcessGroup = None + INTER_RING_GROUP_COPY: dist.ProcessGroup = None + + @staticmethod + def get_double_ring_groups(sp_group, inner_ring_size=None): + """ + Get 2D ring groups for the given process group. Generally, to avoid congestion, the inner ring size + shouldn't be larger than the number of NICs on each node. + Args: + sp_group (dist.ProcessGroup): Process group for sequence parallelism + inner_ring_size (Optional[int], optional): Inner ring size. Defaults to None. + Returns: + Tuple[dist.ProcessGroup, dist.ProcessGroup]: Inner-ring process group and inter-ring process group. + """ + sp_size = dist.get_world_size(sp_group) + sp_rank = dist.get_rank(sp_group) + + if inner_ring_size is None: + if torch.cuda.device_count() >= dist.get_world_size(): + # single node, no need to consider NICs + return sp_group, sp_group + if sp_size <= 4: + inner_ring_size = min(2, sp_size) + else: + inner_ring_size = min(4, sp_size) + else: + assert ( + inner_ring_size <= sp_size and sp_size % inner_ring_size == 0 + ), f"Error: sp_size {sp_size} should be divisible by inner_ring_size {inner_ring_size}" + + if inner_ring_size == sp_size: + return sp_group, sp_group + assert ( + sp_size % inner_ring_size == 0 + ), f"sp_size {sp_size} should be divisible by inner_ring_size {inner_ring_size}" + + logger = get_dist_logger() + logger.info( + f"Using 2D Ring Attention with inner ring size {inner_ring_size} to maximze NIC util for inter-node comm. Cross your fingers for speed-ups!", + ranks=[0], + ) + num_rings = sp_size // inner_ring_size + inner_ring_group = None + inter_ring_group = None + + # Create inner ring groups + for i in range(inner_ring_size): + ranks = list(range(i * inner_ring_size, (i + 1) * inner_ring_size)) + group = dist.new_group(ranks) + if sp_rank in ranks: + inner_ring_group = group + + # Create inter ring groups + for i in range(num_rings): + ranks = list(range(i, sp_size, num_rings)) + group = dist.new_group(ranks) + if sp_rank in ranks: + inter_ring_group = group + + return inner_ring_group, inter_ring_group + + @staticmethod + def attention( + q, # (B, H, Sq, D) + k, + v, + sp_group, + attention_mask_type, + cu_seqlens=None, + max_seqlen=None, + valid_indices=None, + dropout_p=0.0, + softmax_scale=None, + deterministic=False, + return_softmax=False, + inner_ring_size=None, + **kwargs, + ): + """ + Ring Attention forward pass supporting variable-length sequences. When using varlen mode, + each sequence in the batch should have length divisible by sp_size * 2. + + Args: + q (torch.Tensor): Query tensor. Shape should be [B, nHeads, Sq, D] + k (torch.Tensor): Key tensor. Shape should be [B, nHeads, Sq, Sq, D] + v (torch.Tensor): Value tensor. Shape should be [B, nHeads, Sq, Sq, D] + sp_group (Optional[dist.ProcessGroup]): Process group for sequence parallelism + sp_tream (torch.cuda.Stream): An different stream for output correction. + cu_seqlens (Optional[torch.Tensor], optional): The cumulative sequence lengths + of the sequences in the batch, used to index into q. + Shape should be [B+1]. + max_seqlen (Optional[int], optional): Maximum query sequence length in the batch. + valid_indices (Optional[torch.Tensor], optional): The indices of non-masked tokens from get_pad_info. + Shape should be [t]. + dropout_p (float, optional): Dropout probability. Defaults to 0.0. + softmax_scale (Optional[float], optional): Scaling factor applied prior to softmax. + deterministic (bool, optional): Whether to force deterministic backward pass. See https://github.com/Dao-AILab/flash-attention/issues/349 + return_softmax (bool, optional): Whether to return the softmax denominator (logsumexp). + inner_ring_size (Optional[int], optional): Inner ring size of the 2D ring. By default use a heuristic to decide. + + Returns: + out: Output tensor of shape [B, nHeads, Sq, D] or [T, nHeads, D] if pad_output is False. + softmax_lse: (if return_softmax is True) Softmax denominator (logsumexp). + Shape should be [total_q_seqlen, nHeads] + """ + # Check input args + _load_flash_attn() + if RingAttention.ATTN_DONE is None: + RingAttention.ATTN_DONE = torch.cuda.Event() + if RingAttention.SP_STREAM is None: + RingAttention.SP_STREAM = torch.cuda.Stream() + + assert ( + q.shape[2] == k.shape[2] + ), "Q, K and V having different sequence lengths (inference or cross-attn)\ + is not supported yet in training." + assert ( + attention_mask_type in RingAttention.SUPPORTED_MASK_TYPES + ), f"Mask type {attention_mask_type} is not supported yet." + + clone_pg = lambda pg: dist.new_group(dist.get_process_group_ranks(pg)) + + if RingAttention.SP_GROUP is not sp_group: + RingAttention.SP_GROUP = sp_group + inner_ring_group, inter_ring_group = RingAttention.get_double_ring_groups(sp_group, inner_ring_size) + RingAttention.INNER_RING_GROUP = inner_ring_group + RingAttention.INTER_RING_GROUP = inter_ring_group + else: + inner_ring_group = RingAttention.INNER_RING_GROUP + inter_ring_group = RingAttention.INTER_RING_GROUP + + # (B, H, Sq, D) -> (B, Sq, H, D) + q, k, v = [x.transpose(1, 2).contiguous() for x in (q, k, v)] + pad_output = q.dim() == 4 + + # Get sequence length info for varlen forward + if attention_mask_type == AttnMaskType.CAUSAL: + # All sequences share the same length + b, sq, h, d = q.shape + max_seqlen = sq + # Cache to avoid recreation for a single sequence + if sq * b == RingAttention.TOTAL_SEQLEN: + cu_seqlens = RingAttention.CU_SEQLENS + else: + cu_seqlens = torch.arange(0, b * sq + 1, sq, device=q.device, dtype=torch.int32) + RingAttention.TOTAL_SEQLEN = b * sq + + # "Packed" mode where sequences of different lengths are packed into [total_q_seqlen, H, D] + elif attention_mask_type == AttnMaskType.PADDED_CAUSAL: + assert ( + cu_seqlens is not None and max_seqlen is not None and valid_indices is not None + ), "Packed mode requires pre-computed cu_seqlens and max_seq_len." + if pad_output: + b, sq, h, d = q.shape + q, k, v = [_unpad_input(x, valid_indices) for x in (q, k, v)] + + out, softmax_lse = RingAttention.apply( + q, + k, + v, + sp_group, + RingAttention.SP_STREAM, + cu_seqlens, + max_seqlen, + dropout_p, + softmax_scale, + deterministic, + return_softmax, + attention_mask_type == AttnMaskType.PADDED_CAUSAL, + inner_ring_group, + inter_ring_group, + ) + + if attention_mask_type == AttnMaskType.PADDED_CAUSAL: + if pad_output: + out = _pad_input(out, valid_indices, b, sq) # (T, ...) -> (B, Sq, ...) + out = out.transpose(1, 2) # (B, Sq, H, D) -> (B, H, Sq, D) + else: + out = out.transpose(1, 2) + + if return_softmax: + return out, softmax_lse + return out + + @staticmethod + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + sp_group: dist.ProcessGroup, + sp_stream: torch.cuda.Stream, + cu_seqlens: torch.Tensor, + max_seqlen: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + deterministic: Optional[bool] = False, + return_softmax: Optional[bool] = False, + is_packed: Optional[bool] = False, + inner_ring_group: Optional[dist.ProcessGroup] = None, + inter_ring_group: Optional[dist.ProcessGroup] = None, + ): + + cu_seqlens_q = cu_seqlens_kv = cu_seqlens + max_seqlen_q = max_seqlen_kv = max_seqlen + cu_seqlens_half = cu_seqlens // 2 + max_seqlen_half = max_seqlen // 2 + + misc_kwargs = { + "window_size": (-1, -1), + "alibi_slopes": None, + "softmax_scale": q.shape[-1] ** -0.5 if softmax_scale is None else softmax_scale, + "dropout_p": dropout_p, + "block_table": None, + "softcap": 0.0, + "return_softmax": False, + } + + if ( + RingAttention.HALF_INDICES is not None + and cu_seqlens.shape == RingAttention.CU_SEQLENS.shape + and (cu_seqlens == RingAttention.CU_SEQLENS).all() + ): + half_idx_front, half_idx_back = RingAttention.HALF_INDICES + else: + half_idx_front = get_half_index(cu_seqlens, front=True) + half_idx_back = get_half_index(cu_seqlens, front=False) + RingAttention.HALF_INDICES = (half_idx_front, half_idx_back) + RingAttention.CU_SEQLENS = cu_seqlens + + if is_packed: + t, h, d = q.shape + else: + b, sq, h, d = q.shape + t = b * sq + # Be careful about GQA/MQA in reshape + q, k, v = [x.view(t, *x.shape[-2:]) for x in (q, k, v)] + + if inner_ring_group is None or inter_ring_group is None: + # Use one ring if not specified + inner_ring_group = inter_ring_group = sp_group + + sp_size = dist.get_world_size(sp_group) + sp_rank = dist.get_rank(sp_group) + # Attempt to achieve concurrent comm in the two-stream forward + local_kv_comms = [RingComm(inner_ring_group) for _ in range(2)] + inter_ring_comm = RingComm(inter_ring_group) + local_sp_size = dist.get_world_size(inner_ring_group) + local_sp_rank = dist.get_rank(inner_ring_group) + inter_ring_rank = dist.get_rank(inter_ring_group) if inter_ring_group is not sp_group else 0 + num_rings = dist.get_world_size(inter_ring_group) if inter_ring_group is not sp_group else 1 + + # Non-contiguous indexing copies to a new contiguous tensor, + # so only do it once + if sp_rank != sp_size - 1: + q1 = q[half_idx_back] + + # Pre-allocate double buffer for overlapping and receiving next step's inputs + kv_buffers = [torch.stack((k, v))] # (2, B, Sq, H, D) + kv_buffers.append(torch.empty_like(kv_buffers[0])) + + # outputs + out = None + block_out = [None, None] + softmax_lse = [None, None] + block_softmax_lse = [None, None] # log sum exp, the denominator of softmax in attention + rng_states = [None for _ in range(sp_size)] + sp_streams = [torch.cuda.current_stream(), sp_stream] + + def _forward(q, k, v, causal): + ( + _, + _, + _, + _, + out, + softmax_lse, + _, + rng_state, + ) = _flash_attn_forward( + q, + k, + v, + cu_seqlens_q if q.shape[0] == t else cu_seqlens_half, + cu_seqlens_kv if k.shape[0] == t else cu_seqlens_half, + max_seqlen_q if q.shape[0] == t else max_seqlen_half, + max_seqlen_kv if k.shape[0] == t else max_seqlen_half, + causal=causal, + **misc_kwargs, + ) + return out, softmax_lse, rng_state + + def _local_ring_forward(): + # (Hopefully) overlap output correction with next flash attn + for i in range(local_sp_size): + with torch.cuda.stream(sp_streams[i % 2]): + # Wait for current kv from prev rank + # NOTE: waiting outside the current stream will NOT correctly synchronize. + if i > 0: + local_kv_comms[(i + 1) % 2].wait() + + # Avoid overwriting attn input when it shares mem with buffer + if not RingAttention.ATTN_DONE.query(): + kv_buffers[(i + 1) % 2] = torch.empty_like(kv_buffers[i % 2]) + if i < local_sp_size - 1: + local_kv_comms[i % 2].send_recv(kv_buffers[i % 2], kv_buffers[(i + 1) % 2]) + + if i == 0: + # Compute with local KV; no mask + kv_block = kv_buffers[0] + q_block = q + (block_out[i % 2], block_softmax_lse[i % 2], rng_states[i]) = _forward( # (T, H, D) # (H, T) + q_block, kv_block[0], kv_block[1], causal=True + ) + elif i <= local_sp_rank: + # Received the "surrounding" kv chunks + # Drop the second half of received kv + # (2, t // 2, H, D) + kv_block = kv_buffers[i % 2][:, half_idx_front] + q_block = q + ( + block_out[i % 2], # (T, H, D) + block_softmax_lse[i % 2], # (H, T) + rng_states[i], + ) = _forward(q_block, kv_block[0], kv_block[1], causal=False) + else: + # Received the inner kv chunks + # Drop the first half of q + kv_block = kv_buffers[i % 2] + q_block = q1 + ( + block_out[i % 2], # (T, H, D) + block_softmax_lse[i % 2], # (H, T) + rng_states[i], + ) = _forward(q_block, kv_block[0], kv_block[1], causal=False) + RingAttention.ATTN_DONE.record() + + block_softmax_lse[i % 2] = ( + block_softmax_lse[i % 2].transpose(0, 1).unsqueeze(-1).contiguous().float() + ) # (H, T) -> (T, H, 1) + assert block_out[i % 2].shape[:-1] == block_softmax_lse[i % 2].shape[:-1] + # Output and log sum exp correction. Ideally overlap this with the next flash attn kernel. + # In reality this always finishes before next flash attn; no need for extra sync. + if i == 0: + out = block_out[0] + softmax_lse = block_softmax_lse[0] + elif i <= local_sp_rank: + out, softmax_lse = _rescale_out_lse( + out, block_out[i % 2], softmax_lse, block_softmax_lse[i % 2] + ) + else: + out[half_idx_back], softmax_lse[half_idx_back] = _rescale_out_lse( + out[half_idx_back], block_out[i % 2], softmax_lse[half_idx_back], block_softmax_lse[i % 2] + ) + + torch.cuda.current_stream().wait_stream(sp_stream) + return out, softmax_lse + + def _other_ring_forward(ring_num_idx, out, softmax_lse): + # Loop through the inner ring after receiving + # all new KVs from the previous inner ring + for i in range(local_sp_size): + with torch.cuda.stream(sp_streams[i % 2]): + if not RingAttention.ATTN_DONE.query(): + kv_buffers[(i + 1) % 2] = torch.empty_like(kv_buffers[i % 2]) + if i < local_sp_size - 1: + local_kv_comms[i % 2].send_recv(kv_buffers[i % 2], kv_buffers[(i + 1) % 2]) + + # Send & recv KV + if i > 0: + local_kv_comms[(i + 1) % 2].wait() + + if ring_num_idx > inter_ring_rank: + kv_block = kv_buffers[i % 2] + ( + block_out[i % 2], + block_softmax_lse[i % 2], + rng_states[i + local_sp_size * ring_num_idx], + ) = _forward(q1, kv_block[0], kv_block[1], causal=False) + RingAttention.ATTN_DONE.record() + block_softmax_lse[i % 2] = ( + block_softmax_lse[i % 2].transpose(0, 1).unsqueeze(-1).contiguous().float() + ) + out[half_idx_back], softmax_lse[half_idx_back] = _rescale_out_lse( + out[half_idx_back], block_out[i % 2], softmax_lse[half_idx_back], block_softmax_lse[i % 2] + ) + else: + kv_block = kv_buffers[i % 2][:, half_idx_front] + ( + block_out[i % 2], + block_softmax_lse[i % 2], + rng_states[i + local_sp_size * ring_num_idx], + ) = _forward(q, kv_block[0], kv_block[1], causal=False) + RingAttention.ATTN_DONE.record() + block_softmax_lse[i % 2] = ( + block_softmax_lse[i % 2].transpose(0, 1).unsqueeze(-1).contiguous().float() + ) + out, softmax_lse = _rescale_out_lse( + out, block_out[i % 2], softmax_lse, block_softmax_lse[i % 2] + ) + + torch.cuda.current_stream().wait_stream(sp_stream) + return out, softmax_lse + + # Send and recv KV between rings at once to maximize NIC util. + inter_ring_kv = None + for ring_num_idx in range(num_rings): + if ring_num_idx > 0: + inter_ring_comm.wait() + # Reset indices + kv_buffers[0] = inter_ring_kv + + if ring_num_idx < num_rings - 1: + if ring_num_idx == 0: + to_send = kv_buffers[0] + else: + # The last received KV + to_send = kv_buffers[(local_sp_size - 1) % 2] + inter_ring_kv = inter_ring_comm.send_recv(to_send) + + if ring_num_idx == 0: + out, softmax_lse = _local_ring_forward() + else: + out, softmax_lse = _other_ring_forward(ring_num_idx, out, softmax_lse) + + out = out.to(q.dtype) + if not is_packed: + out = out.view(b, sq, h, d) + q, k, v = [x.view(b, sq, *x.shape[-2:]) for x in (q, k, v)] # (T, H, D) -> (B, Sq, H, D) + softmax_lse = softmax_lse.squeeze(-1) + + ctx.sp_group = sp_group + ctx.max_seqlen_q = ctx.max_seqlen_kv = max_seqlen + misc_kwargs["deterministic"] = deterministic + del misc_kwargs["return_softmax"] + ctx.misc_kwargs = misc_kwargs + ctx.is_packed = is_packed + + ctx.kv_group = inner_ring_group + ctx.inter_kv_group = inter_ring_group + + ctx.save_for_backward( + q, + k, + v, + out, + softmax_lse.transpose(0, 1).contiguous(), # (T, H) -> (H, T) + cu_seqlens_q, + cu_seqlens_kv, + half_idx_front, + half_idx_back, + *rng_states, + ) + + if return_softmax: + return out, softmax_lse + return out, None + + def backward(ctx, dout, _): + """ + During backward, we accumulate q grads on each rank locally, but iterate kv and their grads + over all ranks for accumulation. + """ + (q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_kv, half_idx_front, half_idx_back) = ctx.saved_tensors[:9] + rng_states = ctx.saved_tensors[9:] + + is_packed = ctx.is_packed + max_seqlen_q = ctx.max_seqlen_q + max_seqlen_kv = ctx.max_seqlen_kv + cu_seqlens_half = cu_seqlens_q // 2 + max_seqlen_half = max_seqlen_q // 2 + misc_kwargs = ctx.misc_kwargs + del misc_kwargs["block_table"] + + assert ( + out.shape == dout.shape == q.shape + ), f"out {out.shape} and dout {dout.shape} should have the same shape ({q.shape})." + + if is_packed: + t, h, d = q.shape + else: + b, sq, h, d = q.shape + t = b * sq + q, k, v, out, dout = [x.view(t, *x.shape[-2:]) for x in (q, k, v, out, dout)] + + # Sequence parallel args + sp_group = ctx.sp_group + local_kv_group = ctx.kv_group + inter_kv_group = ctx.inter_kv_group + + local_sp_rank = dist.get_rank(sp_group) + sp_size = dist.get_world_size(sp_group) + # Using separate streams (pg) for concurrent kv and dkv comm may + # cause NCCL "software caused connection abort" here... + local_kv_comm = RingComm(local_kv_group) + local_dkv_comm = RingComm(local_kv_group) + inter_kv_comm = RingComm(inter_kv_group) + inter_dkv_comm = RingComm(inter_kv_group) + local_sp_size = dist.get_world_size(local_kv_group) + local_sp_rank = dist.get_rank(local_kv_group) + + if dist.get_world_size(inter_kv_group) != sp_size: + num_rings = dist.get_world_size(inter_kv_group) + inter_ring_rank = dist.get_rank(inter_kv_group) + else: + num_rings = 1 + inter_ring_rank = 0 + + if local_sp_rank != sp_size - 1: + softmax_lse1 = softmax_lse[:, half_idx_back] + dout = dout.contiguous() + + # Double comm buffers for sending and receiving kv + kv_buffers = [torch.stack((k, v))] # (2, T, H, D) + kv_buffers.append(torch.empty_like(kv_buffers[0])) + + dq = None # (T, H, D) + # Intermediate outputs + dq_block = torch.empty_like(q) # (T, H, D) + dk_block = torch.empty_like(k) # (T, H, D) + dv_block = torch.empty_like(v) # (T, H, D) + dkv_buffers = [torch.empty_like(kv, dtype=torch.float32) for kv in kv_buffers] # (T, H, D) + del k, v + + def _backward(dout, q, k, v, out, softmax_lse, dq, dk, dv, rng_state, causal): + _flash_attn_backward( + dout, + q, + k, + v, + out, + softmax_lse, + dq, + dk, + dv, + cu_seqlens_q if dq.shape[0] == t else cu_seqlens_half, + cu_seqlens_kv if dk.shape[0] == t else cu_seqlens_half, + max_seqlen_q if dq.shape[0] == t else max_seqlen_half, + max_seqlen_kv if dk.shape[0] == t else max_seqlen_half, + causal=causal, + rng_state=rng_state, + **misc_kwargs, + ) + + # NOTE: We avoid using two streams due to doubled buffers + # and that backward is more communication intensive. + def _local_ring_backward(): + for i in range(local_sp_size): + if i > 0: + local_kv_comm.wait() + + if i < local_sp_size - 1: + # Send kv to next rank for backward + local_kv_comm.send_recv(kv_buffers[i % 2], kv_buffers[(i + 1) % 2]) + + if i == 0: + # Backward with local kv + k_, v_ = kv_buffers[i % 2] + q_, dout_, out_ = q, dout, out + dq_, dk_, dv_ = dq_block, dk_block, dv_block + _backward(dout_, q_, k_, v_, out_, softmax_lse, dq_, dk_, dv_, rng_states[i], causal=True) + + elif i <= local_sp_rank: + # Drop the second half of kv + # (T, H, D) -> (T // 2, H, D) + k_, v_ = [x[half_idx_front] for x in kv_buffers[i % 2]] + dk_, dv_ = [x[: t // 2] for x in (dk_block, dv_block)] + dq_, q_, out_, dout_ = (dq_block, q, out, dout) + _backward(dout_, q_, k_, v_, out_, softmax_lse, dq_, dk_, dv_, rng_states[i], causal=False) + + else: + # Drop the first half of q + k_, v_ = kv_buffers[i % 2] + dk_, dv_ = dk_block, dv_block + q_, out_, dout_ = [x[half_idx_back] for x in (q, out, dout)] + dq_ = dq_block[: t // 2] + _backward(dout_, q_, k_, v_, out_, softmax_lse1, dq_, dk_, dv_, rng_states[i], causal=False) + + # Accumulate grads + if i == 0: + dq = dq_block.float() + dkv_buffers[i % 2][0] = dk_block.float() + dkv_buffers[i % 2][1] = dv_block.float() + else: + # Accumulate local dq + if i <= local_sp_rank: + dq += dq_ # (T, H, D) + else: + dq[half_idx_back] += dq_ + + # Wait for mobile kv grad accumulators + local_dkv_comm.wait() + + if i <= local_sp_rank: + # q blocks "surrounded" by kv blocks + dkv_buffers[i % 2][0][half_idx_front] += dk_ + dkv_buffers[i % 2][1][half_idx_front] += dv_ + else: + # q blocks "surrounding" kv blocks + dkv_buffers[i % 2][0] += dk_ + dkv_buffers[i % 2][1] += dv_ + local_dkv_comm.send_recv(send_tensor=dkv_buffers[i % 2], recv_tensor=dkv_buffers[(i + 1) % 2]) + + local_dkv_comm.wait() + dkv_recv = dkv_buffers[local_sp_size % 2] + dkv_send = dkv_buffers[(local_sp_size - 1) % 2] + return dq, dkv_recv, dkv_send + + def _other_ring_backward(ring_num_idx, dq): + if ring_num_idx > inter_ring_rank: + # Indexing is expensive + q_, out_, dout_ = [x[half_idx_back] for x in (q, out, dout)] + else: + q_, out_, dout_ = (q, out, dout) + + for i in range(local_sp_size): + if i > 0: + local_kv_comm.wait() + + if i < local_sp_size - 1: + local_kv_comm.send_recv(kv_buffers[i % 2], kv_buffers[(i + 1) % 2]) + + rng_state = rng_states[i + local_sp_size * ring_num_idx] + if ring_num_idx > inter_ring_rank: + k_, v_ = kv_buffers[i % 2] + dk_, dv_ = dk_block, dv_block + dq_ = dq_block[: t // 2] + _backward(dout_, q_, k_, v_, out_, softmax_lse1, dq_, dk_, dv_, rng_state, causal=False) + + dq[half_idx_back] += dq_ + if i > 0: + local_dkv_comm.wait() + else: + inter_dkv_comm.wait() + + dkv_buffers[i % 2][0] += dk_ + dkv_buffers[i % 2][1] += dv_ + else: + k_, v_ = [x[half_idx_front] for x in kv_buffers[i % 2]] + dk_, dv_ = [x[: t // 2] for x in (dk_block, dv_block)] + dq_ = dq_block + _backward(dout_, q_, k_, v_, out_, softmax_lse, dq_, dk_, dv_, rng_state, causal=False) + + dq += dq_ + if i > 0: + local_dkv_comm.wait() + else: + inter_dkv_comm.wait() + + dkv_buffers[i % 2][0][half_idx_front] += dk_ + dkv_buffers[i % 2][1][half_idx_front] += dv_ + + local_dkv_comm.send_recv(send_tensor=dkv_buffers[i % 2], recv_tensor=dkv_buffers[(i + 1) % 2]) + + local_dkv_comm.wait() + dkv_recv = dkv_buffers[local_sp_size % 2] + dkv_send = dkv_buffers[(local_sp_size - 1) % 2] + return dq, dkv_recv, dkv_send + + inter_ring_kv = None + for ring_num_idx in range(num_rings): + if ring_num_idx > 0: + inter_kv_comm.wait() + kv_buffers[0] = inter_ring_kv + + if ring_num_idx < num_rings - 1: + # Re-allocate a buffer in each inter-ring step + inter_ring_kv = inter_kv_comm.send_recv(kv_buffers[0]) + + if ring_num_idx == 0: + dq, dkv_recv, dkv_send = _local_ring_backward() + else: + dq, dkv_recv, dkv_send = _other_ring_backward(ring_num_idx, dq) + + if num_rings > 1: + # Reuse the local buffers + inter_dkv_comm.send_recv(send_tensor=dkv_recv, recv_tensor=dkv_send) + # Reset indices + dkv_buffers[0] = dkv_send + dkv_buffers[1] = dkv_recv + if ring_num_idx == num_rings - 1: + inter_dkv_comm.wait() + dkv_recv = dkv_buffers[0] + + dq, dk, dv = [x.to(q.dtype) for x in (dq, *dkv_recv)] + if not is_packed: + dq, dk, dv = [x.view(b, sq, *x.shape[-2:]) for x in (dq, dk, dv)] + + return (dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None) + + @staticmethod + def prepare_varlen_batch( + attention_mask: torch.Tensor, + sp_group: dist.ProcessGroup, + inputs_embeds: torch.Tensor = None, + position_ids: Optional[torch.Tensor] = None, + is_label: bool = False, + is_2d: bool = True, + ): + """ + Preprocess a batch of padded sequence by splitting input sequence by sp_size + sequence-wise and packing them into one sequence. Updates the mask info accordingly. + Args: + attention_mask (torch.Tensor): Contains the mask [B, Sq], where True means the token is NOT masked. + sp_group (dist.ProcessGroup): Process group for sequence parallelism + inputs_embeds (torch.Tensor): Input embeddings. Shape should be [B, Sq, ...] + position_ids (Optional[torch.Tensor], optional): Position ids of shape [Sq] or [1, Sq]. Defaults to None. + is_label (bool, optional): Whether inputs_embeds is instead a label tensor. If True, mask out the first + token of each sequence. + is_2d (bool, optional): Whether to return 2D outputs padded to max_seqlen // sp_size or flatten + the batch dim to a packed 1d sequence. Contingent on model forward shape definitions. + + Returns: + inputs_embeds: Packed input embeddings of shape [B, Sq // sp_size, ...]. + mask_info: A dictionary of mask info. + position_ids: Packed position ids of shape [..., Sq // sp_size]. + + """ + _load_varlen_helpers() + sp_size = dist.get_world_size(group=sp_group) + sp_rank = dist.get_rank(group=sp_group) + mask_info = {} + mask_info["max_seqlen"], mask_info["cu_seqlens"] = get_pad_info(attention_mask, return_indices=False) + + # Unpad, split seq-wise, then pad back to (B, max_seqlen // sp_size) + # Split mask to compute local nonzero position indices + # (B, Sq) -> (B, max_seqlen // sp_size) + attention_mask = attention_mask[:, : mask_info["max_seqlen"]] + if inputs_embeds is not None: + inputs_embeds = inputs_embeds[:, : mask_info["max_seqlen"]] + inputs_embeds = split_varlen_zigzag( + inputs_embeds, + mask_info["cu_seqlens"], + sp_group, + mask_info["max_seqlen"], + is_2d=is_2d, + is_label=is_label, + ) + attention_mask = split_varlen_zigzag( + attention_mask, mask_info["cu_seqlens"], sp_group, mask_info["max_seqlen"], is_2d=is_2d + ) + + if position_ids is not None: + indices = torch.tensor([sp_rank, 2 * sp_size - sp_rank - 1], device=inputs_embeds.device) + position_ids = ( + position_ids[..., : mask_info["max_seqlen"]] # unpad + .view(-1, sp_size * 2, mask_info["max_seqlen"] // (sp_size * 2)) + .index_select(-2, indices) + .view(-1, mask_info["max_seqlen"] // sp_size) + ) + + mask_info["max_seqlen"] //= sp_size + mask_info["valid_indices"] = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + mask_info["cu_seqlens"] //= sp_size + mask_info["attention_mask_type"] = AttnMaskType.PADDED_CAUSAL + return inputs_embeds, mask_info, position_ids diff --git a/colossalai/shardformer/layer/linear.py b/colossalai/shardformer/layer/linear.py index 37c754241..020e793af 100644 --- a/colossalai/shardformer/layer/linear.py +++ b/colossalai/shardformer/layer/linear.py @@ -200,9 +200,7 @@ class Linear1D_Col(ParallelModule): # Matrix multiply. bias = self.bias if not self.skip_bias_add else None - if self.seq_parallel_mode is None: - output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True) - elif self.seq_parallel_mode == "split_gather": + if self.seq_parallel_mode == "split_gather": input_parallel = gather_forward_reducescatter_backward( input_parallel, self.process_group, self.seq_parallel_dim ) @@ -211,6 +209,8 @@ class Linear1D_Col(ParallelModule): output_parallel = linear_gather_forward_reducescatter_backward( input_parallel, self.weight, bias, self.process_group, True, self.seq_parallel_dim, self.overlap, True ) + else: + output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True) if self.gather_output: # All-gather across the partitions. @@ -416,10 +416,7 @@ class Linear1D_Row(ParallelModule): handle.wait() output = torch.cat(output_parallel_list, dim=-1) else: - if self.seq_parallel_mode is None: - output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False) - output = reduce_forward(output_parallel, self.process_group) - elif self.seq_parallel_mode == "split_gather": + if self.seq_parallel_mode == "split_gather": output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False) output = reducescatter_forward_gather_backward( output_parallel, self.process_group, self.seq_parallel_dim @@ -432,6 +429,9 @@ class Linear1D_Row(ParallelModule): dim=self.seq_parallel_dim, ring=True, ) + else: + output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False) + output = reduce_forward(output_parallel, self.process_group) if not self.skip_bias_add: if self.bias is not None: diff --git a/colossalai/shardformer/layer/loss.py b/colossalai/shardformer/layer/loss.py index cea2da03f..12df824d1 100644 --- a/colossalai/shardformer/layer/loss.py +++ b/colossalai/shardformer/layer/loss.py @@ -4,10 +4,15 @@ from torch.autograd import Function from torch.distributed import ProcessGroup from torch.nn import CrossEntropyLoss +from colossalai.shardformer.layer._operation import reduce_forward from colossalai.shardformer.shard import ShardConfig +from .utils import is_share_sp_tp + __all__ = ["DistCrossEntropy", "cross_entropy_1d", "dist_cross_entropy"] +_IGNORE_IDX = -100 + class DistCrossEntropy(Function): r""" @@ -26,11 +31,12 @@ class DistCrossEntropy(Function): process_group: ProcessGroup, vocab_size: int, dtype=torch.float32, + mode="mean", ): r""" Calculate the cross entropy loss before gather, the origin loss function is as follows: loss = -log(exp(x[class])/sum(exp(x[i])) - and can be rewrite as: + and can be rewriten as: loss = log(sum(exp(x[i])) - x[class] To avoid the `nan` of log(sum(exp(x[i]))), we minus the max of x[i] @@ -44,12 +50,10 @@ class DistCrossEntropy(Function): Returns: :class:`torch.Tensor`: The cross entropy loss """ + assert mode in ["mean", "sum"] # get the max logits_max = torch.max(vocab_logits, dim=-1)[0] - dist.all_reduce(logits_max, op=dist.ReduceOp.MAX, group=process_group) - - # minus the max to avoid the result of sum of exp is too large and the log is nan - vocab_logits = vocab_logits - logits_max.unsqueeze(dim=-1) + handle = dist.all_reduce(logits_max, op=dist.ReduceOp.MAX, group=process_group, async_op=True) # mask the target in the local device rank = dist.get_rank(group=process_group) @@ -70,24 +74,25 @@ class DistCrossEntropy(Function): mask = (target < down_threshold) | (target >= up_threshold) masked_target = target.clone() - down_threshold masked_target[mask] = 0 + masked_target_1d = masked_target.view(-1).contiguous() + # minus the max to avoid the result of sum of exp is too large and the log is nan + handle.wait() + vocab_logits = vocab_logits - logits_max.unsqueeze(dim=-1) # reshape the logits and target # reshape the vocab_logits to [bath_size * seq_len, vocab_size] # reshape the labels to [bath_size * seq_len] self_vocab_size = vocab_logits.size()[-1] logits_2d = vocab_logits.view(-1, self_vocab_size) - masked_target_1d = masked_target.view(-1) # extract the x[class] and set the x[other device] to zero - pred_logits_1d = logits_2d[ - torch.arange(start=0, end=logits_2d.shape[0], device=logits_2d.device), masked_target_1d - ] - pred_logits_1d = pred_logits_1d.clone().contiguous() + idx = torch.arange(start=0, end=logits_2d.shape[0], device=logits_2d.device) + pred_logits_1d = logits_2d[idx, masked_target_1d].contiguous() pred_logits = pred_logits_1d.view_as(target) pred_logits[mask] = 0.0 - # allreduce the get all x(i,y) - dist.all_reduce(pred_logits, op=dist.ReduceOp.SUM, group=process_group) + # all-reduce to get full x[i, y] + handle = dist.all_reduce(pred_logits, op=dist.ReduceOp.SUM, group=process_group, async_op=True) exp_logits = vocab_logits torch.exp(vocab_logits, out=exp_logits) sum_exp_logits = torch.sum(exp_logits, dim=-1, dtype=torch.float32) @@ -95,23 +100,29 @@ class DistCrossEntropy(Function): # calculate the loss # loss = log(sum(exp(x[i]))) - x[class] + handle.wait() loss = torch.where(target == ignore_index, 0.0, torch.log(sum_exp_logits) - pred_logits) - num_non_zero = torch.sum(loss != 0.0) - ctx.inv_num_non_zero = 1.0 / num_non_zero - loss = torch.sum(loss).div_(num_non_zero) + if mode == "mean": + num_non_zero = torch.sum(loss != 0.0) + ctx.inv_num_non_zero = 1.0 / num_non_zero + loss = torch.sum(loss).div_(num_non_zero) + else: + loss = torch.sum(loss) # calculate the softmax exp_logits = exp_logits.div(sum_exp_logits.unsqueeze(dim=-1)).to(dtype) exp_logits[target == ignore_index] = 0.0 ctx.save_for_backward(exp_logits, mask, masked_target_1d) ctx.dtype = dtype + ctx.mode = mode return loss @staticmethod def backward(ctx, grad_output): # retrieve the saved tensors - grad_output = grad_output * ctx.inv_num_non_zero + if ctx.mode == "mean": + grad_output = grad_output * ctx.inv_num_non_zero exp_logits, mask, masked_target_1d = ctx.saved_tensors # use exp logits as the input grad @@ -123,55 +134,113 @@ class DistCrossEntropy(Function): grad_logits_2d[torch.arange(0, grad_logits_2d.shape[0]), masked_target_1d] -= update grad_logits.mul_(grad_output.unsqueeze(dim=-1)) - return grad_logits, None, None, None, None, None + return grad_logits, None, None, None, None, None, None def cross_entropy_1d( vocab_logits: torch.Tensor, labels: torch.Tensor, - ignore_index: int = -100, + ignore_index: int = _IGNORE_IDX, process_group: ProcessGroup = None, vocab_size: int = None, dtype: torch.dtype = None, + mode: str = "mean", ) -> torch.Tensor: - return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group, vocab_size, dtype) + return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group, vocab_size, dtype, mode) def dist_cross_entropy( - labels: torch.Tensor, - logits: torch.Tensor, + labels: torch.Tensor, # [B, S] or [B, S, Vocab_size] + logits: torch.Tensor, # [B, S, Vocab_size] shard_config: ShardConfig, out_features: int, vocab_size: int, dtype: torch.dtype, + seq_dim: int = 1, ) -> torch.Tensor: """ - Helper to compute cross entropy loss for most shardformer models, - compatible with PP, TP and SP. + Helper to compute cross entropy loss for most shardformer models supporting PP, TP and SP. """ - if labels is not None: - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_labels = shift_labels.view(-1) - shift_labels = shift_labels.to(shift_logits.device) - if shard_config.enable_tensor_parallelism and shard_config.parallel_output: - # Cross entropy with all-reduce for TP - new_vocab_size = logits.shape[-1] - shift_logits = shift_logits.view(-1, new_vocab_size) - loss = cross_entropy_1d( - shift_logits, - shift_labels, - process_group=shard_config.tensor_parallel_process_group, - vocab_size=out_features, - dtype=dtype, - ) - else: - # NOTE if use TP and not parallel_output, the output is gathered. - # see VocabParallelLMHead1D - shift_logits = shift_logits.view(-1, vocab_size) - loss = loss_fct(shift_logits, shift_labels) - - return loss + # Split labels if not gather output + sp_group = shard_config.sequence_parallel_process_group + sp_rank = dist.get_rank(sp_group) + sp_size = shard_config.sequence_parallel_size + sp_mode = shard_config.sequence_parallelism_mode + parallel_output = shard_config.parallel_output + is_tp = shard_config.enable_tensor_parallelism + is_packed = labels.dim() == 2 + if is_packed: + bs, seq_len = labels.shape + else: + # padded sequence + seq_len = labels.shape[-1] + logits = logits.reshape(-1, *logits.shape[2:]) + seq_dim = 0 + + # Shift labels to predict the next token, and remove the tail logit predicting + is_sp = sp_size > 1 and (not is_share_sp_tp(sp_mode)) + split_labels_here = seq_len // sp_size == logits.size(seq_dim) # ring attn splits labels before forward + + if sp_mode == "ring_attn": + # For Zigzag Ring Attention, labels should've been split and + # shifted by RingAttention.prepare_varlen_batch() + if sp_rank == 0: + logits = logits[..., :-1, :] + logits = torch.cat([logits, torch.full_like(logits[:, :1, :], _IGNORE_IDX)], dim=seq_dim) + elif is_sp: + # Shift only once: either before splitting or in the last rank without splitting + if split_labels_here or (sp_rank == sp_size - 1): + labels = labels[..., 1:] + if split_labels_here: + labels = labels.split(seq_len // sp_size, dim=-1)[sp_rank] + + if sp_rank == sp_size - 1: + logits = logits[..., :-1, :] + # Pad logits and labels to the same shape across all ranks for TP all_reduce + if is_tp and parallel_output: + # If is packed sequence (label dim is 1), then each seq already has the end label token padded. + # torch.cat is faster than F.pad... + pad_shape = (logits.shape[0], 1, *logits.shape[2:]) if is_packed else (1, *logits.shape[1:]) + padding = torch.full(pad_shape, _IGNORE_IDX, dtype=logits.dtype, device=logits.device) + logits = torch.cat([logits, padding], dim=seq_dim) + pad_shape = (labels.shape[0], 1) if is_packed else (1,) + padding = torch.full(pad_shape, _IGNORE_IDX, dtype=labels.dtype, device=labels.device) + labels = torch.cat([labels, padding], dim=seq_dim) + else: + labels = labels[..., 1:] + logits = logits[..., :-1, :] + labels = labels.contiguous() + logits = logits.contiguous() + num_nonzero = (labels != _IGNORE_IDX).sum() + assert labels.shape == logits.shape[:-1], f"label shape {labels.shape} does not match logit shape {logits.shape}" + + # Flatten the tokens + loss_fct = CrossEntropyLoss(ignore_index=_IGNORE_IDX, reduction="sum") + labels = labels.view(-1) + + if is_tp and parallel_output: + # Cross entropy with all-reduce for TP + new_vocab_size = logits.shape[-1] + logits = logits.view(-1, new_vocab_size) + loss = cross_entropy_1d( + logits, + labels, + process_group=shard_config.tensor_parallel_process_group, + vocab_size=out_features, + dtype=dtype, + mode="sum", + ) + else: + # NOTE if use TP and not parallel_output, the output is gathered in VocabParallelLMHead1D + logits = logits.view(-1, vocab_size) + loss = loss_fct(logits, labels) + + # Reduce loss instead of gathering logits over seq dim for savings + if split_labels_here or sp_mode == "ring_attn": + # Get the global non-zero count + loss = torch.stack((loss, num_nonzero)) + # Rescale to offset the grad / (DP * SP) in HybridParallelPlugin + loss = reduce_forward(loss, sp_group, grad_scale=sp_size) + loss, num_nonzero = loss[0], loss[1].detach() + loss = (loss / num_nonzero).squeeze() + return loss diff --git a/colossalai/shardformer/layer/utils.py b/colossalai/shardformer/layer/utils.py index 9c6ced445..c1a73ce05 100644 --- a/colossalai/shardformer/layer/utils.py +++ b/colossalai/shardformer/layer/utils.py @@ -1,5 +1,5 @@ from contextlib import contextmanager -from typing import List +from typing import List, Optional, Union import torch import torch.distributed as dist @@ -289,3 +289,199 @@ def create_randomizer_with_offset( Randomizer.increment_index() return Randomizer(seed=base_seed) + + +def split_batch_zigzag( + batch: Union[torch.Tensor, List[torch.Tensor]], sp_group: ProcessGroup, seq_dim: int = 1, is_label: bool = False +) -> Union[torch.Tensor, List[torch.Tensor]]: + """ + Split the input along the sequence dimension for Ring Attention. Naively spliting the attention mask + in the causal setting will result in the preceding ranks having much less workload. + We split after "folding" the 2D attention mask in half (https://github.com/zhuzilin/ring-flash-attention/issues/2). + For example, for sp_size = 4 and seq_len = 8, we get | s0, s7 | s1, s6 | s2, s5 | s3, s4 |. + + Args: + batch (List[torch.Tensor] or Tensor): The input tensor(s) to split. + sp_group (ProcessGroup): The process group for sequence parallelism. + seq_dim (int): The sequence dimension to split. + is_label (bool): If True, mask and shift the tensor for next token prediction. + + """ + sp_size = dist.get_world_size(sp_group) + sp_rank = dist.get_rank(sp_group) + if isinstance(batch, torch.Tensor): + batch = [batch] + seq_dim = seq_dim if seq_dim != -1 else batch[0].dim() - 1 + + if sp_size > 1: + for idx, tensor in enumerate(batch): + assert ( + tensor.shape[seq_dim] // (sp_size * 2) > 1 and tensor.shape[seq_dim] % (sp_size * 2) == 0 + ), f"Bro, the seq length {tensor.shape[seq_dim]} for tensor {idx} can't be split by {sp_size * 2}!" + if is_label: + assert tensor.dim() == 2, "Label shape should be (B, Seqlen)" + tensor = torch.cat([tensor[:, 1:], torch.full_like(tensor[:, :1], -100)], dim=1) + + tensor = tensor.view( + *tensor.shape[:seq_dim], + 2 * sp_size, + tensor.shape[seq_dim] // (2 * sp_size), + *tensor.shape[seq_dim + 1 :], + ) + indices = torch.tensor([sp_rank, 2 * sp_size - 1 - sp_rank], device=tensor.device) + tensor = tensor.index_select(seq_dim, indices).contiguous() + # (B, 2, Sq // (2 * sp_size), ...) -> (B, Sq // sp_size, ...) + batch[idx] = tensor.view(*tensor.shape[:seq_dim], -1, *tensor.shape[seq_dim + 2 :]) + + if len(batch) == 1: + return batch[0] + return batch + + +def split_varlen_zigzag( + batch: Union[List[torch.Tensor], torch.Tensor], + cu_seqlens: torch.Tensor, + sp_group: ProcessGroup, + max_seqlen: int = 0, + is_2d: bool = False, + is_label: bool = False, +) -> Union[List[torch.Tensor], torch.Tensor]: + """Split each sequence in a batch of packed sequences in a zigzag fashion. + For each tensor in batch, return packed sequences if is_2d is False; + else return a padded batch of sequences. + + Args: + batch (List[torch.Tensor]): Packed sequences of shape (B * Sq, ...), or (B, Sq, ...) if is_2d. + cu_seqlens (torch.Tensor): Cumulative sequence lengths of shape (B + 1) before splitting. + sp_group (ProcessGroup): The process group for sequence parallelism. + max_seqlen (int): The maximum sequence length in the batch before splitting. + is_2d (bool): If True, then input has batch size and sequence length split into two dimensions. + is_label (bool): If True, mask out the first token in each sequence (). + + Returns: + batch (List[torch.Tensor]): Packed sequences of shape (B * max_seqlen // sp_size) + or (B, max_seqlen // sp_size, ...) if is_2d + """ + sp_size = dist.get_world_size(sp_group) + sp_rank = dist.get_rank(sp_group) + if is_2d: + assert max_seqlen > 0, "max_seqlen must be provided for 2D input" + + if isinstance(batch, torch.Tensor): + batch = [batch] + for i, packed_seq in enumerate(batch): + device = packed_seq.device + dtype = packed_seq.dtype + + if is_2d: + assert max_seqlen % (sp_size * 2) == 0 + # Recreate a padded tensor with the new max seqlen + shape = (packed_seq.shape[0], max_seqlen // sp_size, *packed_seq.shape[2:]) + local_seq = torch.zeros(shape, dtype=dtype, device=device) + else: + total_seqlen = cu_seqlens[-1] + assert ( + total_seqlen % (2 * sp_size) == 0 + ), f"total_seqlen {total_seqlen} must be divisible by 2 * sp_size = {2 * sp_size}" + local_seq = [] + + for j in range(len(cu_seqlens) - 1): + start, end = cu_seqlens[j], cu_seqlens[j + 1] + seqlen = end - start + assert ( + seqlen % (2 * sp_size) == 0 + ), f"batch {i} seq {j}'s length ({seqlen}) must be divisible by 2 * sp_size = {2 * sp_size} for splitting" + + if is_2d: + seq = packed_seq[j][:seqlen] + if is_label: + # Shift one position to the right for next token prediction + seq = torch.cat([seq[1:], torch.tensor([-100], dtype=dtype, device=device)]) + + seq = seq.chunk(2 * sp_size, dim=0) + half = seqlen // sp_size // 2 + local_seq[j][:half] = seq[sp_rank] + local_seq[j][half : seqlen // sp_size] = seq[2 * sp_size - 1 - sp_rank] + else: + seq = packed_seq[start:end] + if is_label: + seq = torch.cat(seq[1:], torch.tensor([-100], dtype=dtype, device=device)) + seq = seq.chunk(sp_size * 2) + local_seq.extend([seq[sp_rank], seq[2 * sp_size - 1 - sp_rank]]) + + if is_2d: + batch[i] = local_seq.contiguous() + else: + batch[i] = torch.cat(local_seq, dim=0) + + if len(batch) == 1: + batch = batch[0] + return batch + + +def is_share_sp_tp(sp_mode: str): + """sp_mode "ring" and "split_gather" use the TP group as SP group + to split both the vocab and sequence, so we must gather the sequence + to correctly get logits at each positions. + """ + return sp_mode in ["ring", "split_gather"] + + +class RingComm: + def __init__(self, process_group: dist.ProcessGroup): + self._process_group = process_group + self._ops = [] + self.rank = dist.get_rank(self._process_group) + self.world_size = dist.get_world_size(self._process_group) + self._reqs = [] + + self.send_rank = (self.rank + 1) % self.world_size + self.recv_rank = (self.rank - 1) % self.world_size + + self.send_rank = dist.get_global_rank(self._process_group, self.send_rank) + self.recv_rank = dist.get_global_rank(self._process_group, self.recv_rank) + + def send_recv( + self, + send_tensor: torch.Tensor, + recv_tensor: Optional[torch.Tensor] = None, + commit: bool = True, + ) -> torch.Tensor: + if recv_tensor is None: + res = torch.empty_like(send_tensor) + else: + res = recv_tensor + + # looks like batch_isend_irecv doesn't deadlock even + # when we don't swap send recv ops based on rank + send_op = dist.P2POp(dist.isend, send_tensor, self.send_rank, group=self._process_group) + recv_op = dist.P2POp(dist.irecv, res, self.recv_rank, group=self._process_group) + self._ops.extend([send_op, recv_op]) + + if commit: + self._reqs = dist.batch_isend_irecv(self._ops) + return res + + def commit(self): + assert len(self._ops) > 0, "No ops to commit" + self._reqs = dist.batch_isend_irecv(self._ops) + + def wait(self): + assert len(self._reqs) > 0, "No requests to wait for" + for req in self._reqs: + req.wait() + self._reqs = [] + self._ops = [] + + +@torch.jit.script +def get_half_index(cu_seqlens, *, front: bool): + index = torch.zeros(cu_seqlens[-1], dtype=torch.bool, device=cu_seqlens.device) + for i in range(len(cu_seqlens) - 1): + start, end = cu_seqlens[i], cu_seqlens[i + 1] + if front: + end = (start + end) // 2 + else: + start = (start + end) // 2 + index[start:end] = True + return index diff --git a/colossalai/shardformer/modeling/command.py b/colossalai/shardformer/modeling/command.py index 5b36fc7db..67c20eed8 100644 --- a/colossalai/shardformer/modeling/command.py +++ b/colossalai/shardformer/modeling/command.py @@ -26,6 +26,8 @@ from colossalai.shardformer.shard import ShardConfig from ..layer import ColoAttention, dist_cross_entropy +_SUPPORTED_SP_MODE = ["all_to_all", "split_gather", "ring", "ring_attn"] + class CommandPipelineForwards: """ @@ -349,7 +351,7 @@ class CommandPipelineForwards: return {"hidden_states": hidden_states} -def get_command_flash_attention_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None): +def get_command_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None): def forward( self, hidden_states: torch.Tensor, @@ -362,7 +364,7 @@ def get_command_flash_attention_forward(shard_config, sp_mode=None, sp_size=None **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: if sp_mode is not None: - assert sp_mode in ["all_to_all", "split_gather", "ring"], "Invalid sp_mode" + assert sp_mode in _SUPPORTED_SP_MODE, f"SP mode {sp_mode} is not supported by {type(self)} yet" assert (sp_size is not None) and ( sp_group is not None ), "Must specify sp_size and sp_group for sequence parallel" @@ -459,7 +461,7 @@ def get_command_flash_attention_forward(shard_config, sp_mode=None, sp_size=None return forward -def get_command_flash_attention_model_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None): +def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None): logger = logging.get_logger(__name__) def forward( diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 9ffbca517..af610500a 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -1,8 +1,9 @@ import math import warnings -from typing import List, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union import torch +import torch.distributed import torch.nn.functional as F import torch.utils.checkpoint from torch import nn @@ -24,14 +25,14 @@ from transformers.models.llama.modeling_llama import ( from transformers.utils import logging from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.shardformer.layer._operation import ( - all_to_all_comm, - gather_forward_split_backward, - split_forward_gather_backward, -) +from colossalai.shardformer.layer import AttnMaskType +from colossalai.shardformer.layer._operation import all_to_all_comm, gather_sp_output, split_forward_gather_backward +from colossalai.shardformer.layer.utils import is_share_sp_tp, split_batch_zigzag from colossalai.shardformer.shard import ShardConfig -from ..layer import ColoAttention, dist_cross_entropy +from ..layer import ColoAttention, RingAttention, dist_cross_entropy + +_SUPPORTED_SP_MODE = ["all_to_all", "split_gather", "ring", "ring_attn"] class LlamaPipelineForwards: @@ -57,6 +58,10 @@ class LlamaPipelineForwards: hidden_states: Optional[torch.FloatTensor] = None, stage_index: Optional[List[int]] = None, shard_config: ShardConfig = None, + # Split output only when computing cross entropy using llama_for_causal_lm_forward + # or get_lm_forward_with_dist_cross_entropy + # Default to True to avoid bug when calling classification forward from huggingface + force_sp_output_gather: bool = True, ): logger = logging.get_logger(__name__) @@ -97,7 +102,7 @@ class LlamaPipelineForwards: sp_group = shard_config.sequence_parallel_process_group sp_size = shard_config.sequence_parallel_size if sp_mode == "all_to_all" and not stage_manager.is_first_stage(): - # For correct positions ids. The states will be gather along the seq dim in the attention layer later. + # For generating full positions ids, as the states will be gather along the seq dim in the attention layer later. seq_length *= sp_size past_seen_tokens = 0 @@ -127,22 +132,36 @@ class LlamaPipelineForwards: position_ids = cache_position.unsqueeze(0) # embed positions, for the first stage, hidden_states is the input embeddings, # for the other stages, hidden_states is the output of the previous stage - if shard_config.enable_flash_attention: + if not stage_manager.is_first_stage() and sp_mode == "ring_attn": + _, attn_kwargs, _ = RingAttention.prepare_varlen_batch(attention_mask, sp_group) + elif shard_config.enable_flash_attention: # in this case, attention_mask is a dict rather than a tensor mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past) - attention_mask = ColoAttention.prepare_attn_kwargs( + attn_kwargs = ColoAttention.prepare_attn_kwargs( mask_shape, hidden_states.dtype, hidden_states.device, q_padding_mask=attention_mask, is_causal=True, + invert=(sp_mode != "ring_attn"), ) else: - attention_mask = self._update_causal_mask(attention_mask, hidden_states, cache_position) + attn_kwargs = self._update_causal_mask(attention_mask, hidden_states, cache_position) # Support SP + PP + # TODO: support padded casual cu_seqlens across stages if stage_manager.is_first_stage(): - if sp_mode in ["ring", "split_gather"]: + # Ring Attention zigzag batch processing + if sp_mode == "ring_attn": + assert shard_config.enable_flash_attention, "Ring Attention inherently requires Flash Attention." + if attn_kwargs["attention_mask_type"] == AttnMaskType.PADDED_CAUSAL: + hidden_states, attn_kwargs, position_ids = RingAttention.prepare_varlen_batch( + attention_mask, sp_group, hidden_states, position_ids + ) + else: + hidden_states, position_ids = split_batch_zigzag([hidden_states, position_ids], sp_group) + + elif is_share_sp_tp(sp_mode): hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group) elif sp_mode == "all_to_all": hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group, 1 / sp_size) @@ -177,12 +196,11 @@ class LlamaPipelineForwards: for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx): if output_hidden_states: all_hidden_states += (hidden_states,) - if idx - start_idx < num_ckpt_layers: layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, - attention_mask, + attn_kwargs, position_ids, past_key_values, output_attentions, @@ -192,14 +210,13 @@ class LlamaPipelineForwards: else: layer_outputs = decoder_layer( hidden_states, - attention_mask=attention_mask, + attention_mask=attn_kwargs, position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, ) - hidden_states = layer_outputs[0] if use_cache: @@ -209,10 +226,8 @@ class LlamaPipelineForwards: if stage_manager.is_last_stage(): hidden_states = self.norm(hidden_states) - if sp_mode == "ring" or sp_mode == "split_gather": - hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group) - elif sp_mode == "all_to_all": - hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size) + if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode): + hidden_states = gather_sp_output(hidden_states, sp_group, sp_mode) # add hidden states from the last decoder layer if output_hidden_states: @@ -298,6 +313,15 @@ class LlamaPipelineForwards: logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") output_hidden_states = False + if shard_config.sequence_parallelism_mode == "ring_attn" and shard_config.parallel_output: + # Split labels in a zigzag fashion too + sp_group = shard_config.sequence_parallel_process_group + if attention_mask.bool().all(): + labels = split_batch_zigzag(labels, sp_group, seq_dim=1) + else: + # [B, max_seqlen // sp_size] + labels, _, _ = RingAttention.prepare_varlen_batch(attention_mask, sp_group, labels, is_label=True) + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = LlamaPipelineForwards.llama_model_forward( self.model, @@ -315,6 +339,7 @@ class LlamaPipelineForwards: hidden_states=hidden_states, stage_index=stage_index, shard_config=shard_config, + force_sp_output_gather=False, ) past_key_values = None @@ -457,11 +482,11 @@ class LlamaPipelineForwards: return {"hidden_states": hidden_states} -def get_llama_flash_attention_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None): +def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None): def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[Union[torch.Tensor, Dict]] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, @@ -470,7 +495,7 @@ def get_llama_flash_attention_forward(shard_config, sp_mode=None, sp_size=None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: if sp_mode is not None: - assert sp_mode in ["all_to_all", "split_gather", "ring"], "Invalid sp_mode" + assert sp_mode in _SUPPORTED_SP_MODE, f"SP mode {sp_mode} is not supported by {type(self)} yet" assert (sp_size is not None) and ( sp_group is not None ), "Must specify sp_size and sp_group for sequence parallel" @@ -481,7 +506,7 @@ def get_llama_flash_attention_forward(shard_config, sp_mode=None, sp_size=None, bsz, q_len, _ = hidden_states.size() # sp: modify sp_len when sequence parallel mode is ring - if sp_mode in ["split_gather", "ring"]: + if is_share_sp_tp(sp_mode): q_len *= sp_size if self.config.pretraining_tp > 1: @@ -526,6 +551,7 @@ def get_llama_flash_attention_forward(shard_config, sp_mode=None, sp_size=None, ) kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states, position_ids) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) @@ -537,12 +563,21 @@ def get_llama_flash_attention_forward(shard_config, sp_mode=None, sp_size=None, key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) - if shard_config.enable_flash_attention: + if sp_mode == "ring_attn": + attn_output = RingAttention.attention( + query_states, + key_states, + value_states, + sp_group, + **attention_mask, + inner_ring_size=shard_config.inner_ring_size, + ) + + elif shard_config.enable_flash_attention: assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict." attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask) else: attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): raise ValueError( f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" @@ -588,7 +623,7 @@ def get_llama_flash_attention_forward(shard_config, sp_mode=None, sp_size=None, return forward -def get_llama_flash_attention_model_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None): +def get_llama_flash_attention_model_forward(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None): logger = logging.get_logger(__name__) def forward( @@ -603,6 +638,10 @@ def get_llama_flash_attention_model_forward(shard_config, sp_mode=None, sp_size= output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + # Split output only when computing cross entropy using llama_for_causal_lm_forward + # or get_lm_forward_with_dist_cross_entropy + # Default to True to avoid bug when calling classification forward from huggingface + force_sp_output_gather: bool = True, ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -629,32 +668,45 @@ def get_llama_flash_attention_model_forward(shard_config, sp_mode=None, sp_size= past_seen_tokens = 0 seq_len = inputs_embeds.shape[1] + batch_size = inputs_embeds.shape[0] if use_cache: # kept for BC (cache positions) if not isinstance(past_key_values, StaticCache): past_key_values = DynamicCache.from_legacy_cache(past_key_values) past_seen_tokens = past_key_values.get_seq_length() + if cache_position is None: if isinstance(past_key_values, StaticCache): raise ValueError("cache_position is a required argument when using StaticCache.") cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_len, device=inputs_embeds.device) - if position_ids is None: position_ids = cache_position.unsqueeze(0) - # in this case, attention_mask is a dict rather than a tensor if shard_config.enable_flash_attention: - mask_shape = (inputs_embeds.shape[0], 1, seq_len, past_seen_tokens + seq_len) - attention_mask = ColoAttention.prepare_attn_kwargs( + mask_shape = (batch_size, 1, seq_len, past_seen_tokens + seq_len) + attn_kwargs: dict = ColoAttention.prepare_attn_kwargs( mask_shape, inputs_embeds.dtype, inputs_embeds.device, q_padding_mask=attention_mask, is_causal=True, + invert=(sp_mode != "ring_attn"), ) + else: - attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position) + attn_kwargs: torch.Tensor = self._update_causal_mask(attention_mask, inputs_embeds, cache_position) + + # Ring Attention zigzag batch processing + if sp_mode == "ring_attn": + assert shard_config.enable_flash_attention, "Ring Attention inherently requires Flash Attention." + if attn_kwargs["attention_mask_type"] == AttnMaskType.PADDED_CAUSAL: + inputs_embeds, attn_kwargs, position_ids = RingAttention.prepare_varlen_batch( + attention_mask, sp_group, inputs_embeds, position_ids + ) + else: + inputs_embeds, position_ids = split_batch_zigzag([inputs_embeds, position_ids], sp_group) + attn_kwargs = {"attention_mask_type": attn_kwargs["attention_mask_type"]} # drop redundant tensors - if sp_mode in ["ring", "split_gather"]: + elif is_share_sp_tp(sp_mode): inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group) elif sp_mode == "all_to_all": inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size) @@ -672,7 +724,7 @@ def get_llama_flash_attention_model_forward(shard_config, sp_mode=None, sp_size= layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, - attention_mask, + attn_kwargs, position_ids, past_key_values, output_attentions, @@ -683,7 +735,7 @@ def get_llama_flash_attention_model_forward(shard_config, sp_mode=None, sp_size= else: layer_outputs = decoder_layer( hidden_states, - attention_mask=attention_mask, + attention_mask=attn_kwargs, position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, @@ -700,11 +752,9 @@ def get_llama_flash_attention_model_forward(shard_config, sp_mode=None, sp_size= all_self_attns += (layer_outputs[1],) hidden_states = self.norm(hidden_states) - - if sp_mode == "ring" or sp_mode == "split_gather": - hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group) - elif sp_mode == "all_to_all": - hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size) + # Cases that don't support parallelizing cross entropy computation along sequence + if (not shard_config.parallel_output) or is_share_sp_tp(sp_mode) or force_sp_output_gather: + hidden_states = gather_sp_output(hidden_states, sp_group, sp_mode) # add hidden states from the last decoder layer if output_hidden_states: @@ -777,6 +827,15 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig): ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if shard_config.sequence_parallelism_mode == "ring_attn" and shard_config.parallel_output: + # Special processing: Split labels in a zigzag fashion too + sp_group = shard_config.sequence_parallel_process_group + if attention_mask.bool().all(): + labels = split_batch_zigzag(labels, sp_group, seq_dim=1, is_label=True) + else: + # [B, max_seq_len // sp_size] + labels, _, _ = RingAttention.prepare_varlen_batch(attention_mask, sp_group, labels, is_label=True) + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( input_ids=input_ids, @@ -789,6 +848,7 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig): output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, + force_sp_output_gather=False, ) hidden_states = outputs[0] @@ -799,7 +859,6 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig): else: logits = self.lm_head(hidden_states) logits = logits.float() - loss = dist_cross_entropy( labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype ) diff --git a/colossalai/shardformer/policies/base_policy.py b/colossalai/shardformer/policies/base_policy.py index 282cf0464..7c1e6f0d7 100644 --- a/colossalai/shardformer/policies/base_policy.py +++ b/colossalai/shardformer/policies/base_policy.py @@ -75,6 +75,7 @@ class Policy(ABC): def __init__(self) -> None: self.shard_config: Optional[ShardConfig] = None self.model: Optional[Module] = None + self.is_causal = None # Whether we're doing causal lm, i.e. using cross entropy def set_model(self, model: nn.Module) -> None: r""" diff --git a/colossalai/shardformer/policies/command.py b/colossalai/shardformer/policies/command.py index a9b915d10..1efd3d017 100644 --- a/colossalai/shardformer/policies/command.py +++ b/colossalai/shardformer/policies/command.py @@ -69,13 +69,18 @@ class CommandPolicy(Policy): sp_size = self.shard_config.sequence_parallel_size or None sp_group = self.shard_config.sequence_parallel_process_group or None sp_partial_derived = sp_mode in ["split_gather", "ring"] + if sp_mode == "ring_attn" and not self.is_causal: + raise ValueError("Ring attention is only meant for causal language modeling.") + tp_size = self.shard_config.tensor_parallel_size or None + num_q_heads = self.model.config.num_attention_heads + num_kv_heads = getattr(self.model.config, "num_key_value_heads", None) if sp_mode == "all_to_all": - decoder_attribute_replacement = { - "num_heads": self.model.config.num_attention_heads // sp_size, - } - if getattr(self.model.config, "num_key_value_heads", False): - decoder_attribute_replacement["num_key_value_heads"] = self.model.config.num_key_value_heads // sp_size + num_q_heads //= sp_size + decoder_attribute_replacement = {"num_heads": num_q_heads} + if num_kv_heads: + num_kv_heads //= sp_size + decoder_attribute_replacement["num_key_value_heads"] = num_kv_heads policy[attn_cls] = ModulePolicyDescription( attribute_replacement=decoder_attribute_replacement, @@ -104,21 +109,18 @@ class CommandPolicy(Policy): if self.shard_config.enable_tensor_parallelism: assert ( - self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0 + num_q_heads % tp_size == 0 ), f"The number of attention heads must be divisible by tensor parallel size." if hasattr(self.model.config, "num_key_value_heads"): assert ( - self.model.config.num_key_value_heads >= self.shard_config.tensor_parallel_size - and self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0 + num_kv_heads >= tp_size and num_kv_heads % tp_size == 0 ), f"The number of key_value heads must be divisible by, and must not be less than tensor parallel size." decoder_attribute_replacement = { - "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, - "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, + "self_attn.hidden_size": self.model.config.hidden_size // tp_size, + "self_attn.num_heads": num_q_heads // tp_size, } if getattr(self.model.config, "num_key_value_heads", False): - decoder_attribute_replacement["self_attn.num_key_value_heads"] = ( - self.model.config.num_key_value_heads // self.shard_config.tensor_parallel_size - ) + decoder_attribute_replacement["self_attn.num_key_value_heads"] = num_kv_heads // tp_size policy[CohereDecoderLayer] = ModulePolicyDescription( attribute_replacement=decoder_attribute_replacement, @@ -290,10 +292,11 @@ class CommandForCausalLMPolicy(CommandPolicy): def module_policy(self): from transformers import CohereForCausalLM + self.is_causal = True policy = super().module_policy() if self.shard_config.enable_tensor_parallelism: - # add a new item for casual lm + # add a new item for causal lm new_item = { CohereForCausalLM: ModulePolicyDescription( sub_module_replacement=[ diff --git a/colossalai/shardformer/policies/deepseek.py b/colossalai/shardformer/policies/deepseek.py index 605f69c4a..ea68649d5 100644 --- a/colossalai/shardformer/policies/deepseek.py +++ b/colossalai/shardformer/policies/deepseek.py @@ -298,7 +298,7 @@ class DeepseekForCausalLMPolicy(DeepseekPolicy): policy = super().module_policy() # TODO: assign pg mesh from plugin to all modules if self.shard_config.enable_tensor_parallelism: - # add a new item for casual lm + # add a new item for causal lm new_item = { "DeepseekForCausalLM": ModulePolicyDescription( sub_module_replacement=[ diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index 36491b4b5..f72a72df0 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -69,13 +69,20 @@ class LlamaPolicy(Policy): sp_size = self.shard_config.sequence_parallel_size or None sp_group = self.shard_config.sequence_parallel_process_group or None sp_partial_derived = sp_mode in ["split_gather", "ring"] + if sp_mode == "ring_attn" and not self.is_causal: + raise ValueError("Ring attention is only meant for causal language modeling.") + + tp_size = self.shard_config.tensor_parallel_size + # Modified by SP and TP + num_q_heads = self.model.config.num_attention_heads + num_kv_heads = getattr(self.model.config, "num_key_value_heads", None) if sp_mode == "all_to_all": - decoder_attribute_replacement = { - "num_heads": self.model.config.num_attention_heads // sp_size, - } - if getattr(self.model.config, "num_key_value_heads", False): - decoder_attribute_replacement["num_key_value_heads"] = self.model.config.num_key_value_heads // sp_size + num_q_heads //= sp_size + decoder_attribute_replacement = {"num_heads": num_q_heads} + if num_kv_heads: + num_kv_heads //= sp_size + decoder_attribute_replacement["num_key_value_heads"] = num_kv_heads policy[attn_cls] = ModulePolicyDescription( attribute_replacement=decoder_attribute_replacement, @@ -104,21 +111,20 @@ class LlamaPolicy(Policy): if self.shard_config.enable_tensor_parallelism: assert ( - self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0 + num_q_heads % tp_size == 0 ), f"The number of attention heads must be divisible by tensor parallel size." if hasattr(self.model.config, "num_key_value_heads"): assert ( - self.model.config.num_key_value_heads >= self.shard_config.tensor_parallel_size - and self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0 + num_kv_heads >= tp_size and num_kv_heads % tp_size == 0 ), f"The number of key_value heads must be divisible by, and must not be less than tensor parallel size." + num_q_heads //= tp_size decoder_attribute_replacement = { - "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, - "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, + "self_attn.hidden_size": self.model.config.hidden_size // tp_size, + "self_attn.num_heads": num_q_heads, } if getattr(self.model.config, "num_key_value_heads", False): - decoder_attribute_replacement["self_attn.num_key_value_heads"] = ( - self.model.config.num_key_value_heads // self.shard_config.tensor_parallel_size - ) + num_kv_heads //= tp_size + decoder_attribute_replacement["self_attn.num_key_value_heads"] = num_kv_heads policy[LlamaDecoderLayer] = ModulePolicyDescription( attribute_replacement=decoder_attribute_replacement, @@ -295,10 +301,11 @@ class LlamaForCausalLMPolicy(LlamaPolicy): def module_policy(self): from transformers import LlamaForCausalLM + self.is_causal = True policy = super().module_policy() if self.shard_config.enable_tensor_parallelism: - # add a new item for casual lm + # add a new item for causal lm new_item = { LlamaForCausalLM: ModulePolicyDescription( sub_module_replacement=[ @@ -313,10 +320,6 @@ class LlamaForCausalLMPolicy(LlamaPolicy): ], ) } - if self.shard_config.parallel_output: - new_item[LlamaForCausalLM].method_replacement = { - "forward": get_lm_forward_with_dist_cross_entropy(self.shard_config) - } else: new_item = { LlamaForCausalLM: ModulePolicyDescription( @@ -336,7 +339,11 @@ class LlamaForCausalLMPolicy(LlamaPolicy): self.set_pipeline_forward( model_cls=LlamaForCausalLM, new_forward=LlamaPipelineForwards.llama_for_causal_lm_forward, policy=policy ) - + elif self.shard_config.enable_tensor_parallelism or self.shard_config.enable_sequence_parallelism: + # Compute loss distributedly along the sequence dimension + new_item[LlamaForCausalLM].method_replacement = { + "forward": get_lm_forward_with_dist_cross_entropy(self.shard_config) + } return policy def get_held_layers(self) -> List[Module]: diff --git a/colossalai/shardformer/policies/mistral.py b/colossalai/shardformer/policies/mistral.py index c5a0277a5..6ea27e210 100644 --- a/colossalai/shardformer/policies/mistral.py +++ b/colossalai/shardformer/policies/mistral.py @@ -271,7 +271,7 @@ class MistralForCausalLMPolicy(MistralPolicy): policy = super().module_policy() if self.shard_config.enable_tensor_parallelism: - # add a new item for casual lm + # add a new item for causal lm new_item = { MistralForCausalLM: ModulePolicyDescription( sub_module_replacement=[ diff --git a/colossalai/shardformer/policies/mixtral.py b/colossalai/shardformer/policies/mixtral.py index 10df143c9..e11edae9f 100644 --- a/colossalai/shardformer/policies/mixtral.py +++ b/colossalai/shardformer/policies/mixtral.py @@ -275,7 +275,7 @@ class MixtralForCausalLMPolicy(MixtralPolicy): policy = super().module_policy() # TODO: assign pg mesh from plugin to all modules if self.shard_config.enable_tensor_parallelism: - # add a new item for casual lm + # add a new item for causal lm new_item = { MixtralForCausalLM: ModulePolicyDescription( sub_module_replacement=[ diff --git a/colossalai/shardformer/policies/qwen2.py b/colossalai/shardformer/policies/qwen2.py index 362c14060..235dc7d56 100644 --- a/colossalai/shardformer/policies/qwen2.py +++ b/colossalai/shardformer/policies/qwen2.py @@ -313,7 +313,7 @@ class Qwen2ForCausalLMPolicy(Qwen2Policy): setattr(self.shard_config, "causal_lm", True) if self.shard_config.enable_tensor_parallelism: - # add a new item for casual lm + # add a new item for causal lm new_item = { Qwen2ForCausalLM: ModulePolicyDescription( sub_module_replacement=[ diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index 163d7a7bb..70eb271c9 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -10,7 +10,7 @@ from colossalai.pipeline.stage_manager import PipelineStageManager from .grad_ckpt_config import GradientCheckpointConfig __all__ = ["ShardConfig"] -SUPPORT_SP_MODE = ["split_gather", "ring", "all_to_all"] +SUPPORT_SP_MODE = ["split_gather", "ring", "all_to_all", "ring_attn"] @dataclass @@ -29,6 +29,8 @@ class ShardConfig: enable_sequence_overlap (bool): Whether to turn on sequence overlap, which overlap the computation and communication in sequence parallelism. It can only be used when enable_sequence_parallelism is True. Defaults to False. gradient_checkpoint_config (Optional[GradientCheckpointConfig]): The gradient checkpoint config. Defaults to None. enable_all_optimization (bool): Whether to turn on all optimization tools including 'fused normalization', 'flash attention', 'JIT fused operators', 'sequence parallelism' and 'sequence overlap'. Defaults to False. + parallel_output (bool): For TP: whether to use parallelize cross entropy computation along the feature dim. + For SP: set to True to NOT gather the output along the seq dim. """ tensor_parallel_process_group: Optional[ProcessGroup] = None @@ -47,10 +49,11 @@ class ShardConfig: gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None extra_kwargs: Dict[str, Any] = field(default_factory=dict) + # For ring attention + inner_ring_size: Optional[int] = None # for moe related moe_dp_group: Optional[ProcessGroup] = None ep_group: Optional[ProcessGroup] = None - # pipeline_parallel_size: int # data_parallel_size: int # tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d'] @@ -80,9 +83,9 @@ class ShardConfig: self.enable_tensor_parallelism ), f"sequence parallelism mode {self.sequence_parallelism_mode} can only be used when enable_tensor_parallelism is True" elif self.sequence_parallelism_mode in ["all_to_all"]: - assert ( - not self.enable_tensor_parallelism - ), f"sequence parallelism mode {self.sequence_parallelism_mode} can only be used when enable_tensor_parallelism is False" + # assert ( + # not self.enable_tensor_parallelism + # ), f"sequence parallelism mode {self.sequence_parallelism_mode} can only be used when enable_tensor_parallelism is False" if self.enable_sequence_overlap: self.enable_sequence_overlap = False warnings.warn( diff --git a/examples/language/llama/benchmark.py b/examples/language/llama/benchmark.py index e530e2d6a..093377e7a 100644 --- a/examples/language/llama/benchmark.py +++ b/examples/language/llama/benchmark.py @@ -28,6 +28,7 @@ warnings.filterwarnings("ignore") # Constants # ============================== +# We have lots of llamas for your choice! MODEL_CONFIGS = { "100m": LlamaConfig( max_position_embeddings=4096, @@ -36,6 +37,7 @@ MODEL_CONFIGS = { intermediate_size=2048, hidden_size=1024, ), + "5b": LlamaConfig(max_position_embeddings=4096, num_key_value_heads=8), "7b": LlamaConfig(max_position_embeddings=4096), "13b": LlamaConfig( hidden_size=5120, @@ -68,9 +70,6 @@ def main(): default="gemini", help="Choose which plugin to use", ) - parser.add_argument( - "--overlap", action="store_true", help="Overlap communication with computation in Pipeline Parallel." - ) parser.add_argument("-b", "--batch_size", type=int, default=2, help="Batch size") parser.add_argument("-s", "--num_steps", type=int, default=5, help="Number of steps to run") parser.add_argument("-i", "--ignore_steps", type=int, default=2, help="Number of steps to ignore") @@ -94,11 +93,24 @@ def main(): parser.add_argument("--pp_style", default="1f1b", choices=["1f1b", "interleaved"]) parser.add_argument("--n_chunks", default=1, help="number of model chunks", type=eval) - parser.add_argument("--profile", action="store_true", help="Profile the code", default=False) + parser.add_argument("--profile", action="store_true", help="Profile the code") + parser.add_argument( + "--nsys", + action="store_true", + help="Use nsys for profiling. \ + You should put something like this before colossalai launch: \ + nsys profile -w true -t cuda,cudnn,cublas -s cpu --capture-range=cudaProfilerApi --capture-range-end=stop --cudabacktrace=true -x true --python-backtrace=cuda -o prof_out", + ) parser.add_argument("--disable-async-reduce", action="store_true", help="Disable the asynchronous reduce operation") parser.add_argument("--prefetch_num", type=int, default=0, help="chunk prefetch max number") parser.add_argument("--no_cache", action="store_true") parser.add_argument("--overlap_allgather", action="store_true") + parser.add_argument( + "--sp_mode", + default="all_to_all", + choices=["all_to_all", "ring_attn", "ring", "split_gather"], + help="Sequence parallelism mode", + ) args = parser.parse_args() colossalai.launch_from_torch() @@ -195,12 +207,12 @@ def main(): num_model_chunks=args.n_chunks, zero_stage=args.zero, sp_size=args.sp, + sequence_parallelism_mode=args.sp_mode, enable_sequence_parallelism=args.sp > 1, enable_fused_normalization=torch.cuda.is_available(), enable_flash_attention=args.xformers, microbatch_size=args.mbs, precision="bf16", - overlap_p2p=args.overlap, enable_metadata_cache=not args.no_cache, overlap_allgather=args.overlap_allgather, **hybrid_kwargs, @@ -218,7 +230,6 @@ def main(): microbatch_size=args.mbs, initial_scale=2**8, precision="bf16", - overlap_p2p=args.overlap, ) else: raise ValueError(f"Unknown plugin {args.plugin}") @@ -295,6 +306,7 @@ def main(): args.ignore_steps, 1, # avoid creating massive log files save_dir=f"profile/{time.strftime('%H:%M', time.localtime())}-{args.plugin}-llama-{args.config}", + nsys=args.nsys, ) as prof: if isinstance(plugin, HybridParallelPlugin) and args.pp > 1: data_iter = iter(dataloader) @@ -320,13 +332,16 @@ def main(): performance_evaluator.on_step_start(step) outputs = model(**batch) loss = outputs[0] + del outputs # free memory + + if dist.get_rank() == dist.get_world_size() - 1: + print(f"Step {step} loss: {loss}") booster.backward(loss, optimizer) optimizer.step() optimizer.zero_grad() performance_evaluator.on_step_end(**batch) prof.step() - performance_evaluator.on_fit_end() coordinator.print_on_master(f"Max CUDA memory usage: {get_accelerator().max_memory_allocated()/1024**2:.2f} MB") diff --git a/examples/language/opt/README.md b/examples/language/opt/README.md index af1e79437..694c5cf91 100644 --- a/examples/language/opt/README.md +++ b/examples/language/opt/README.md @@ -17,7 +17,7 @@ limitations under the License. ## OPT Meta recently released [Open Pretrained Transformer (OPT)](https://github.com/facebookresearch/metaseq), a 175-Billion parameter AI language model, which stimulates AI programmers to perform various downstream tasks and application deployments. -The following example of [Colossal-AI](https://github.com/hpcaitech/ColossalAI) demonstrates fine-tuning Casual Language Modelling at low cost. +The following example of [Colossal-AI](https://github.com/hpcaitech/ColossalAI) demonstrates fine-tuning Causal Language Modelling at low cost. ## Our Modifications diff --git a/examples/language/performance_evaluator.py b/examples/language/performance_evaluator.py index ca4a02cd2..f5ad1d23d 100644 --- a/examples/language/performance_evaluator.py +++ b/examples/language/performance_evaluator.py @@ -28,7 +28,7 @@ def all_reduce_mean(x: float, world_size: int) -> float: return tensor.item() -def get_profile_context(enable_flag, warmup_steps, active_steps, save_dir): +def get_profile_context(enable_flag, warmup_steps, active_steps, save_dir, nsys=False): class DummyProfiler: def __init__(self): self.step_number = 0 @@ -42,7 +42,29 @@ def get_profile_context(enable_flag, warmup_steps, active_steps, save_dir): def __exit__(self, exc_type, exc_value, traceback): pass + class NsysProfiler: + def __init__(self, warmup_steps, active_steps): + self.step_number = 0 + self.warmup_steps = warmup_steps + self.active_steps = active_steps + + def step(self): + if self.step_number == self.warmup_steps: + torch.cuda.cudart().cudaProfilerStart() + elif self.step_number == self.warmup_steps + self.active_steps: + torch.cuda.cudart().cudaProfilerStop() + self.step_number += 1 + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + pass + if enable_flag: + if nsys: + return NsysProfiler(warmup_steps, active_steps) + return profile( activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], schedule=schedule(wait=0, warmup=warmup_steps, active=active_steps), diff --git a/examples/tutorial/opt/opt/README.md b/examples/tutorial/opt/opt/README.md index a01209cbd..3776e0c64 100644 --- a/examples/tutorial/opt/opt/README.md +++ b/examples/tutorial/opt/opt/README.md @@ -19,7 +19,7 @@ limitations under the License. ## OPT Meta recently released [Open Pretrained Transformer (OPT)](https://github.com/facebookresearch/metaseq), a 175-Billion parameter AI language model, which stimulates AI programmers to perform various downstream tasks and application deployments. -The following example of [Colossal-AI](https://github.com/hpcaitech/ColossalAI) demonstrates fine-tuning Casual Language Modelling at low cost. +The following example of [Colossal-AI](https://github.com/hpcaitech/ColossalAI) demonstrates fine-tuning causal Language Modelling at low cost. We are using the pre-training weights of the OPT model provided by Hugging Face Hub on the raw WikiText-2 (no tokens were replaced before the tokenization). This training script is adapted from the [HuggingFace Language Modelling examples](https://github.com/huggingface/transformers/tree/main/examples/pytorch/language-modeling). diff --git a/extensions/pybind/flash_attention/flash_attention_dao_cuda.py b/extensions/pybind/flash_attention/flash_attention_dao_cuda.py index a108377a8..560d952f6 100644 --- a/extensions/pybind/flash_attention/flash_attention_dao_cuda.py +++ b/extensions/pybind/flash_attention/flash_attention_dao_cuda.py @@ -57,14 +57,14 @@ class FlashAttentionDaoCudaExtension(_Extension): q_indices: Optional[torch.Tensor] = None, kv_indices: Optional[torch.Tensor] = None, ): - # [B, N, S, D] -> [B, S, N, D] + # [B, H, S, D] -> [B, S, H, D] q = q.transpose(1, 2) k = k.transpose(1, 2) v = v.transpose(1, 2) b, s_q = q.shape[:2] if cu_seqlens_q is not None: # padded / padded causal - # unpad input: [B, S, N, D] -> [T, N, D] + # unpad input: [B, S, H, D] -> [T, H, D] q = _unpad_input(q, q_indices) kv = _unpad_input(torch.stack(tensors=(k, v), dim=2), kv_indices) attn_output = flash_attn_varlen_kvpacked_func( @@ -78,7 +78,7 @@ class FlashAttentionDaoCudaExtension(_Extension): softmax_scale=scale, causal=is_causal, ) - # pad output: [T, N, D] -> [B, S, N, D] + # pad output: [T, H, D] -> [B, S, H, D] attn_output = pad_input(attn_output, q_indices, b, s_q) else: # causal / no attn mask @@ -90,7 +90,7 @@ class FlashAttentionDaoCudaExtension(_Extension): softmax_scale=scale, causal=is_causal, ) - # [B, S, N, D] -> [B, N, S, D] + # [B, S, H, D] -> [B, H, S, D] return attn_output.transpose(1, 2) return flash_attention diff --git a/tests/kit/model_zoo/__init__.py b/tests/kit/model_zoo/__init__.py index 66c794a7d..9c1a11e7b 100644 --- a/tests/kit/model_zoo/__init__.py +++ b/tests/kit/model_zoo/__init__.py @@ -22,9 +22,9 @@ COMMON_MODELS = [ "transformers_bloom_for_causal_lm", "transformers_falcon_for_causal_lm", "transformers_chatglm_for_conditional_generation", - "transformers_llama_for_casual_lm", + "transformers_llama_for_causal_lm", "transformers_vit_for_masked_image_modeling", - "transformers_mistral_for_casual_lm", + "transformers_mistral_for_causal_lm", ] IS_FAST_TEST = os.environ.get("FAST_TEST", "0") == "1" diff --git a/tests/kit/model_zoo/transformers/command.py b/tests/kit/model_zoo/transformers/command.py index a8b8842c5..3f4ea4583 100644 --- a/tests/kit/model_zoo/transformers/command.py +++ b/tests/kit/model_zoo/transformers/command.py @@ -32,8 +32,8 @@ if HAS_COMMAND: return dict(input_ids=input_ids, attention_mask=attention_mask) - # label is needed for casual lm - def data_gen_for_casual_lm(): + # label is needed for causal lm + def data_gen_for_causal_lm(): data = data_gen() labels = data["input_ids"].clone() data["labels"] = labels @@ -44,7 +44,7 @@ if HAS_COMMAND: # function to get the loss loss_fn = lambda output: output["last_hidden_state"].mean() - loss_fn_for_casual_lm = lambda output: output["loss"] + loss_fn_for_causal_lm = lambda output: output["loss"] loss_fn_for_seq_classification = lambda output: output["logits"].mean() config = CohereConfig( @@ -70,10 +70,10 @@ if HAS_COMMAND: model_attribute=ModelAttribute(has_control_flow=True), ) model_zoo.register( - name="transformers_command_for_casual_lm", + name="transformers_command_for_causal_lm", model_fn=lambda: transformers.CohereForCausalLM(config), - data_gen_fn=data_gen_for_casual_lm, + data_gen_fn=data_gen_for_causal_lm, output_transform_fn=output_transform_fn, - loss_fn=loss_fn_for_casual_lm, + loss_fn=loss_fn_for_causal_lm, model_attribute=ModelAttribute(has_control_flow=True), ) diff --git a/tests/kit/model_zoo/transformers/llama.py b/tests/kit/model_zoo/transformers/llama.py index 61fa56050..05ac9d8d2 100644 --- a/tests/kit/model_zoo/transformers/llama.py +++ b/tests/kit/model_zoo/transformers/llama.py @@ -33,20 +33,21 @@ if HAS_LLAMA: [1, 15043, 29892, 590, 11203, 338, 274, 1082, 1, 15043, 29892, 590, 11203, 338, 274, 1082], ] ).long() - - attention_mask = torch.Tensor( - [ - [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], - [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], - ] - ).long() - + attention_mask = torch.ones_like(input_ids) return dict(input_ids=input_ids, attention_mask=attention_mask) - # label is needed for casual lm - def data_gen_for_casual_lm(): + # label is needed for causal lm + def data_gen_for_causal_lm(): data = data_gen() + + # Test padded sequence + padding = torch.zeros(2, data["input_ids"].shape[1] // 2, dtype=torch.long) + data["input_ids"] = torch.cat([data["input_ids"], padding], dim=1) + data["attention_mask"] = torch.cat([data["attention_mask"], padding], dim=1) + + ignore_idx = -100 labels = data["input_ids"].clone() + labels[~data["attention_mask"].bool()] = ignore_idx data["labels"] = labels return data @@ -55,7 +56,7 @@ if HAS_LLAMA: # function to get the loss loss_fn = lambda output: output["last_hidden_state"].mean() - loss_fn_for_casual_lm = lambda output: output["loss"] + loss_fn_for_causal_lm = lambda output: output["loss"] loss_fn_for_seq_classification = lambda output: output["logits"].mean() config = LlamaConfig( @@ -70,23 +71,23 @@ if HAS_LLAMA: config.pad_token_id = config.eos_token_id # register the following models - # transformers.LlamaModel, # transformers.LlamaForCausalLM, + # transformers.LlamaModel, # transformers.LlamaForSequenceClassification, model_zoo.register( - name="transformers_llama", - model_fn=lambda: transformers.LlamaModel(config), - data_gen_fn=data_gen, + name="transformers_llama_for_causal_lm", + model_fn=lambda: transformers.LlamaForCausalLM(config), + data_gen_fn=data_gen_for_causal_lm, output_transform_fn=output_transform_fn, - loss_fn=loss_fn, + loss_fn=loss_fn_for_causal_lm, model_attribute=ModelAttribute(has_control_flow=True), ) model_zoo.register( - name="transformers_llama_for_casual_lm", - model_fn=lambda: transformers.LlamaForCausalLM(config), - data_gen_fn=data_gen_for_casual_lm, + name="transformers_llama", + model_fn=lambda: transformers.LlamaModel(config), + data_gen_fn=data_gen, output_transform_fn=output_transform_fn, - loss_fn=loss_fn_for_casual_lm, + loss_fn=loss_fn, model_attribute=ModelAttribute(has_control_flow=True), ) model_zoo.register( diff --git a/tests/kit/model_zoo/transformers/mistral.py b/tests/kit/model_zoo/transformers/mistral.py index ae5a97002..43fc662cc 100644 --- a/tests/kit/model_zoo/transformers/mistral.py +++ b/tests/kit/model_zoo/transformers/mistral.py @@ -64,7 +64,7 @@ model_zoo.register( model_attribute=ModelAttribute(has_control_flow=True), ) model_zoo.register( - name="transformers_mistral_for_casual_lm", + name="transformers_mistral_for_causal_lm", model_fn=lambda: transformers.MistralForCausalLM(config), data_gen_fn=data_gen_for_lm, output_transform_fn=output_transform_fn, diff --git a/tests/kit/model_zoo/transformers/qwen2.py b/tests/kit/model_zoo/transformers/qwen2.py index 1c26af698..83bc9f941 100644 --- a/tests/kit/model_zoo/transformers/qwen2.py +++ b/tests/kit/model_zoo/transformers/qwen2.py @@ -33,8 +33,8 @@ if HAS_QWEN2: attention_mask = torch.Tensor([[1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1]]).long() return dict(input_ids=input_ids, attention_mask=attention_mask) - # label is needed for casual lm - def data_gen_for_casual_lm(): + # label is needed for causal lm + def data_gen_for_causal_lm(): data = data_gen() labels = data["input_ids"].clone() data["labels"] = labels @@ -45,7 +45,7 @@ if HAS_QWEN2: # function to get the loss loss_fn = lambda output: output["last_hidden_state"].mean() - loss_fn_for_casual_lm = lambda output: output["loss"] + loss_fn_for_causal_lm = lambda output: output["loss"] loss_fn_for_seq_classification = lambda output: output["logits"].mean() config = Qwen2Config( @@ -72,11 +72,11 @@ if HAS_QWEN2: model_attribute=ModelAttribute(has_control_flow=True), ) model_zoo.register( - name="transformers_qwen2_for_casual_lm", + name="transformers_qwen2_for_causal_lm", model_fn=lambda: transformers.Qwen2ForCausalLM(config), - data_gen_fn=data_gen_for_casual_lm, + data_gen_fn=data_gen_for_causal_lm, output_transform_fn=output_transform_fn, - loss_fn=loss_fn_for_casual_lm, + loss_fn=loss_fn_for_causal_lm, model_attribute=ModelAttribute(has_control_flow=True), ) model_zoo.register( diff --git a/tests/test_booster/test_plugin/test_3d_plugin.py b/tests/test_booster/test_plugin/test_3d_plugin.py index e57cadfd8..3e8532955 100644 --- a/tests/test_booster/test_plugin/test_3d_plugin.py +++ b/tests/test_booster/test_plugin/test_3d_plugin.py @@ -97,7 +97,7 @@ def check_3d_plugin(init_method: str = "none", early_stop: bool = True): # TODO(ver217): add more models for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.get_sub_registry( - "transformers_llama_for_casual_lm" + "transformers_llama_for_causal_lm" ).items(): err = run_fn(init_method, model_fn, data_gen_fn, output_transform_fn) diff --git a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py index 8c59f430c..c2a08a541 100644 --- a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py +++ b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py @@ -105,7 +105,7 @@ def check_low_level_zero_lora(stage, model_name, early_stop: bool = True): sub_model_zoo = model_zoo.get_sub_registry(model_name) for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): task_type = None - if name == "transformers_llama_for_casual_lm": + if name == "transformers_llama_for_causal_lm": task_type = "CAUSAL_LM" if name == "transformers_llama_for_sequence_classification": task_type = "SEQ_CLS" diff --git a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py index fd13ce0bf..b133be948 100644 --- a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py @@ -74,7 +74,7 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: b @clear_cache_before_run() @parameterize("placement_config", OPTIM_PLACEMENT_CONFIGS) @parameterize("shard", [True, False]) -@parameterize("model_name", ["transformers_llama_for_casual_lm"]) +@parameterize("model_name", ["transformers_llama_for_causal_lm"]) @parameterize("size_per_shard", [32]) @parameterize("tp_size", [1, 2]) @parameterize("zero_size", [2]) diff --git a/tests/test_checkpoint_io/test_gemini_torch_compability.py b/tests/test_checkpoint_io/test_gemini_torch_compability.py index 4897907ff..ce4d10322 100644 --- a/tests/test_checkpoint_io/test_gemini_torch_compability.py +++ b/tests/test_checkpoint_io/test_gemini_torch_compability.py @@ -20,7 +20,7 @@ from tests.kit.model_zoo import model_zoo @clear_cache_before_run() @parameterize("shard", [False, True]) -@parameterize("model_name", ["transformers_llama_for_casual_lm"]) +@parameterize("model_name", ["transformers_llama_for_causal_lm"]) def exam_torch_load_from_gemini(shard: bool, model_name: str): (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values())) criterion = lambda x: x.mean() diff --git a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py index 4f8f26041..86d7924fb 100644 --- a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py @@ -39,7 +39,7 @@ else: @parameterize("shard", [True, False]) -@parameterize("model_name", ["transformers_llama_for_casual_lm"]) +@parameterize("model_name", ["transformers_llama_for_causal_lm"]) @parameterize("size_per_shard", [32]) @parameterize("test_config", TEST_CONFIGS) @clear_cache_before_run() diff --git a/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py b/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py index ab48944d4..a8e05a25a 100644 --- a/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py @@ -149,7 +149,7 @@ def check_low_level_zero_lora_checkpointIO( if name != "transformers_llama": continue task_type = None - if name == "transformers_llama_for_casual_lm": + if name == "transformers_llama_for_causal_lm": task_type = "CAUSAL_LM" if name == "transformers_llama_for_sequence_classification": task_type = "SEQ_CLS" diff --git a/tests/test_checkpoint_io/test_plugins_huggingface_compatibility.py b/tests/test_checkpoint_io/test_plugins_huggingface_compatibility.py index df8636141..6f8eb2ad2 100644 --- a/tests/test_checkpoint_io/test_plugins_huggingface_compatibility.py +++ b/tests/test_checkpoint_io/test_plugins_huggingface_compatibility.py @@ -18,7 +18,7 @@ from tests.kit.model_zoo import model_zoo @clear_cache_before_run() -@parameterize("model_name", ["transformers_llama_for_casual_lm"]) +@parameterize("model_name", ["transformers_llama_for_causal_lm"]) @parameterize("plugin_type", ["ddp", "zero", "gemini"]) def exam_from_pretrained(plugin_type: str, model_name: str, shard=True, size_per_shard=32): (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) = next( diff --git a/tests/test_lora/test_lora.py b/tests/test_lora/test_lora.py index 1ae17025d..b0ec767cc 100644 --- a/tests/test_lora/test_lora.py +++ b/tests/test_lora/test_lora.py @@ -91,7 +91,7 @@ def run_lora_test(): sub_model_zoo = model_zoo.get_sub_registry("transformers_llama") for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): task_type = None - if name == "transformers_llama_for_casual_lm": + if name == "transformers_llama_for_causal_lm": task_type = "CAUSAL_LM" if name == "transformers_llama_for_sequence_classification": task_type = "SEQ_CLS" diff --git a/tests/test_pipeline/test_schedule/test_interleaved.py b/tests/test_pipeline/test_schedule/test_interleaved.py index a626b834a..04a1296e6 100644 --- a/tests/test_pipeline/test_schedule/test_interleaved.py +++ b/tests/test_pipeline/test_schedule/test_interleaved.py @@ -6,6 +6,7 @@ import pytest import torch import torch.distributed as dist import torch.nn as nn +from torch.testing import assert_close import colossalai from colossalai.cluster import ProcessGroupMesh @@ -107,13 +108,13 @@ def run_pp( # check loss if stage_manager.is_last_stage(ignore_chunk=True): - assert torch.allclose(torch_loss, pp_ret["loss"]) + assert_close(torch_loss, pp_ret["loss"]) # check gradients for i in range(num_model_chunk): idx = world_size * i + rank - assert torch.allclose(torch_model.layers[idx].weight.grad, sharded_model[i].weight.grad) - assert torch.allclose(torch_model.layers[idx].bias.grad, sharded_model[i].bias.grad) + assert_close(torch_model.layers[idx].weight.grad, sharded_model[i].weight.grad) + assert_close(torch_model.layers[idx].bias.grad, sharded_model[i].bias.grad) # step torch_optimizer.step() @@ -123,8 +124,8 @@ def run_pp( # check updated param for i in range(num_model_chunk): idx = world_size * i + rank - assert torch.allclose(torch_model.layers[idx].weight, sharded_model[i].weight) - assert torch.allclose(torch_model.layers[idx].bias, sharded_model[i].bias) + assert_close(torch_model.layers[idx].weight, sharded_model[i].weight) + assert_close(torch_model.layers[idx].bias, sharded_model[i].bias) # forward only with torch.no_grad(): @@ -135,14 +136,14 @@ def run_pp( sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True ) if stage_manager.is_last_stage(ignore_chunk=True): - assert torch.allclose(torch_loss, pp_ret["loss"]) + assert_close(torch_loss, pp_ret["loss"]) for layer in sharded_model: if layer.weight.grad is None: assert layer.weight.grad is None and layer.bias.grad is None else: - assert torch.allclose(layer.weight.grad, torch.zeros_like(layer.weight.grad)) - assert torch.allclose(layer.bias.grad, torch.zeros_like(layer.bias.grad)) + assert_close(layer.weight.grad, torch.zeros_like(layer.weight.grad)) + assert_close(layer.bias.grad, torch.zeros_like(layer.bias.grad)) @pytest.mark.dist diff --git a/tests/test_pipeline/test_schedule/test_oneF_oneB.py b/tests/test_pipeline/test_schedule/test_oneF_oneB.py index c4bfa7b69..8ae4f6daa 100644 --- a/tests/test_pipeline/test_schedule/test_oneF_oneB.py +++ b/tests/test_pipeline/test_schedule/test_oneF_oneB.py @@ -6,6 +6,7 @@ import pytest import torch import torch.distributed as dist import torch.nn as nn +from torch.testing import assert_close import colossalai from colossalai.cluster import ProcessGroupMesh @@ -103,13 +104,13 @@ def examine_pp(num_microbatch: int, batch_size: int): # check loss if stage_manager.is_last_stage(): - assert torch.allclose(torch_loss, pp_ret["loss"]) + assert_close(torch_loss, pp_ret["loss"]) # check gradients for i in range(len(sharded_model)): idx = rank * num_local_layer + i - assert torch.allclose(torch_model.layers[idx].weight.grad, sharded_model[i].weight.grad) - assert torch.allclose(torch_model.layers[idx].bias.grad, sharded_model[i].bias.grad) + assert_close(torch_model.layers[idx].weight.grad, sharded_model[i].weight.grad) + assert_close(torch_model.layers[idx].bias.grad, sharded_model[i].bias.grad) # step torch_optimizer.step() @@ -119,8 +120,8 @@ def examine_pp(num_microbatch: int, batch_size: int): # check updated param for i in range(len(sharded_model)): idx = rank * num_local_layer + i - assert torch.allclose(torch_model.layers[idx].weight, sharded_model[i].weight) - assert torch.allclose(torch_model.layers[idx].bias, sharded_model[i].bias) + assert_close(torch_model.layers[idx].weight, sharded_model[i].weight) + assert_close(torch_model.layers[idx].bias, sharded_model[i].bias) # forward only with torch.no_grad(): @@ -131,14 +132,14 @@ def examine_pp(num_microbatch: int, batch_size: int): sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True ) if stage_manager.is_last_stage(): - assert torch.allclose(torch_loss, pp_ret["loss"]) + assert_close(torch_loss, pp_ret["loss"]) for layer in sharded_model: if layer.weight.grad is None: assert layer.weight.grad is None and layer.bias.grad is None else: - assert torch.allclose(layer.weight.grad, torch.zeros_like(layer.weight.grad)) - assert torch.allclose(layer.bias.grad, torch.zeros_like(layer.bias.grad)) + assert_close(layer.weight.grad, torch.zeros_like(layer.weight.grad)) + assert_close(layer.bias.grad, torch.zeros_like(layer.bias.grad)) def run_dist( diff --git a/tests/test_shardformer/test_flash_attention.py b/tests/test_shardformer/test_flash_attention.py index 9aa24a166..42ca6b198 100644 --- a/tests/test_shardformer/test_flash_attention.py +++ b/tests/test_shardformer/test_flash_attention.py @@ -88,6 +88,7 @@ def check_attn_func(dtype: torch.dtype, attn_func, attn_kwargs: dict, padding_ma padding_mask = padding_mask[:, None, :, None].logical_not() ref_output = ref_output.masked_fill(padding_mask, 0) output = output.masked_fill(padding_mask, 0) + assert_close(output, ref_output, **tols) output.mean().backward() ref_output.mean().backward() @@ -128,6 +129,8 @@ def test_flash_attn_func(dtype: torch.dtype): attn_kwargs, padding_mask = gen_kwargs_func(dtype) for attn_func, name, need_postprocess in attn_funcs: print(f"{dtype}, {name}, {mask_type}") + if mask_type == "padded": + pass if need_postprocess: check_attn_func(dtype, attn_func, post_process_kwargs_for_raw_attn(attn_kwargs), padding_mask) else: diff --git a/tests/test_shardformer/test_layer/test_ring_attn.py b/tests/test_shardformer/test_layer/test_ring_attn.py new file mode 100644 index 000000000..1c7647a7d --- /dev/null +++ b/tests/test_shardformer/test_layer/test_ring_attn.py @@ -0,0 +1,186 @@ +import torch +import torch.distributed as dist +import torch.nn.functional as F +from flash_attn import flash_attn_qkvpacked_func, flash_attn_varlen_qkvpacked_func +from torch.testing import assert_close + +import colossalai +from colossalai.shardformer.layer import AttnMaskType +from colossalai.shardformer.layer.attn import AttnMaskType, RingAttention +from colossalai.shardformer.layer.utils import split_batch_zigzag, split_varlen_zigzag +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.utils import get_current_device + + +@parameterize("seq_len", [4096]) +@parameterize("bs", [2]) +@parameterize("nheads", [5]) +@parameterize("d", [128]) +@parameterize("dtype", [torch.bfloat16, torch.float16]) +def check_ring_attn(seq_len, bs, nheads, d, dtype): + torch.cuda.manual_seed(2) + device = get_current_device() + sp_group = dist.group.WORLD + sp_size = dist.get_world_size() + # Some outliers may seem large, but our errors are still lower than + # than Megatron-LM context parallel's + # (https://github.com/NVIDIA/TransformerEngine/blob/33a3d02f81c56e6f7b542c09bfa86657078d57fb/tests/pytorch/fused_attn/run_fused_attn_with_cp.py#L215) + # and the original zigzag implementation's (https://github.com/zhuzilin/ring-flash-attention/tree/main) + atol = rtol = 7e-3 + + # Setup inputs + qkv = torch.randn(bs, seq_len, 3, nheads, d, device=device, dtype=dtype, requires_grad=True) + local_qkv = split_batch_zigzag(qkv, sp_group) + q, k, v = local_qkv.unbind(dim=-3) + q, k, v = [x.squeeze(2).detach().clone().transpose(1, 2) for x in (q, k, v)] # (B, nHeads, Sq, D) + q.requires_grad = k.requires_grad = v.requires_grad = True + + # Ring attention vs single GPU + ring_out, ring_lse = RingAttention.attention( + q, + k, + v, + sp_group, + AttnMaskType.CAUSAL, + return_softmax=True, + inner_ring_size=max(2, sp_size // 2), + # inner_ring_size=4 + ) + ring_out = ring_out.transpose(1, 2) + out, lse, _ = flash_attn_qkvpacked_func( + qkv, dropout_p=0.0, causal=True, window_size=(-1, -1), alibi_slopes=None, return_attn_probs=True + ) + + # Checkout out and softmax denominator + local_out = split_batch_zigzag(out, sp_group) + local_lse = split_batch_zigzag(lse, sp_group, seq_dim=-1) + local_lse = local_lse.transpose(1, 2).contiguous().view(-1, ring_lse.shape[-1]) # (B, nHeads, Sq) -> (T, nHeads) + assert_close(ring_lse, local_lse, atol=atol, rtol=rtol) + assert_close(ring_out, local_out, atol=atol, rtol=rtol) + + # Check grads + ring_out.sum().backward() + out.sum().backward() + ring_dq, ring_dk, ring_dv = [x.transpose(1, 2) for x in (q.grad, k.grad, v.grad)] + dqkv = qkv.grad + local_dqkv = split_batch_zigzag(dqkv, sp_group) + + assert_close(ring_dq, local_dqkv[:, :, 0], atol=atol, rtol=rtol) + assert_close(ring_dk, local_dqkv[:, :, 1], atol=atol, rtol=rtol) + assert_close(ring_dv, local_dqkv[:, :, 2], atol=atol, rtol=rtol) + if dist.get_rank() == 0: + print( + f"sp_size {dist.get_world_size()}, inner ring size {dist.get_world_size(RingAttention.INNER_RING_GROUP)} passed." + ) + + +@parameterize("seqlen", [4096]) +@parameterize("bs", [2]) +@parameterize("nheads", [5]) +@parameterize("d", [128]) +@parameterize("dtype", [torch.bfloat16, torch.float16]) +def check_packed_seq(seqlen, bs, nheads, d, dtype): + device = get_current_device() + sp_group = dist.group.WORLD + sp_size = dist.get_world_size() + atol = rtol = 7e-3 + torch.cuda.manual_seed(2) + # Prepare varlen attention mask + padding_mask = torch.ones((bs, seqlen), dtype=torch.int, device=device) + padding_mask[: bs // 2, (seqlen // 4) * 3 :] = 0 + padding_mask[:, seqlen // 2 :] = 0 + + input_embeds = torch.randn(bs, seqlen, nheads, d, device=device, dtype=dtype, requires_grad=True) + + # Forward + # out = ColoAttention.attention(q, k, v, **mask_info) + flat_input = input_embeds.view(-1, nheads, d)[padding_mask.flatten().nonzero().squeeze()] + qkv = torch.stack([flat_input] * 3, dim=1) + qkv.retain_grad() + + input_embeds, mask_info, _ = RingAttention.prepare_varlen_batch(padding_mask, sp_group, input_embeds) + out, lse, _ = flash_attn_varlen_qkvpacked_func( + qkv, + mask_info["cu_seqlens"] * sp_size, + mask_info["max_seqlen"] * sp_size, + return_attn_probs=True, + causal=True, + # deterministic=True + ) + # Test the splitting function + local_input = split_varlen_zigzag( + flat_input, mask_info["cu_seqlens"] * sp_size, sp_group, mask_info["max_seqlen"] * sp_size + ) + assert (local_input == input_embeds.view(-1, nheads, d)[mask_info["valid_indices"]]).all() + del local_input, flat_input + + q_ring, k_ring, v_ring = [input_embeds.clone().transpose(1, 2) for _ in range(3)] + q_ring.retain_grad() + k_ring.retain_grad() + v_ring.retain_grad() + + ring_out, ring_lse = RingAttention.attention( + q_ring, + k_ring, + v_ring, + sp_group, + **mask_info, + pad_output=False, + return_softmax=True, + # deterministic=True + ) + ring_out = ring_out.transpose(1, 2).reshape(-1, nheads, d) + # Check output + lse = lse.transpose(0, 1) + out, lse = split_varlen_zigzag( + [out, lse], mask_info["cu_seqlens"] * sp_size, sp_group, mask_info["max_seqlen"] * sp_size + ) + assert_close(lse, ring_lse, atol=atol, rtol=rtol) + assert_close(out, ring_out, atol=atol, rtol=rtol) + + # Check grads + labels = torch.ones(out.shape[0], dtype=dtype, device=device) + F.mse_loss(out.sum((-2, -1)), labels).backward() + F.mse_loss(ring_out.sum((-2, -1)), labels[: ring_out.shape[0]]).backward() + dq, dk, dv = [ + split_varlen_zigzag( + qkv.grad[:, i], mask_info["cu_seqlens"] * sp_size, sp_group, mask_info["max_seqlen"] * sp_size + ) + for i in range(3) + ] + dq_ring, dk_ring, dv_ring = [ + x.transpose(1, 2).reshape(-1, nheads, d)[mask_info["valid_indices"]] + for x in (q_ring.grad, k_ring.grad, v_ring.grad) + ] + + assert_close(dq, dq_ring, atol=atol, rtol=rtol) + assert_close(dk, dk_ring, atol=atol, rtol=rtol) + assert_close(dv, dv_ring, atol=atol, rtol=rtol) + + +def launch_single_ring(rank, world_size, port): + colossalai.launch(rank, world_size, "localhost", port) + check_packed_seq() + check_ring_attn() + + +def launch_double_ring(rank, world_size, port): + colossalai.launch(rank, world_size, "localhost", port) + check_ring_attn() + + +@rerun_if_address_is_in_use() +@parameterize("world_size", [2]) +def test_ring_attn(world_size): + spawn(launch_single_ring, nprocs=world_size) + + +@rerun_if_address_is_in_use() +@parameterize("world_size", [4]) +def test_double_ring(world_size): + spawn(launch_double_ring, nprocs=world_size) + + +if __name__ == "__main__": + test_ring_attn() + test_double_ring() diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index 190fee129..9ad84341a 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -10,6 +10,7 @@ from torch.distributed import ProcessGroup from torch.nn import Module from torch.optim import Adam, Optimizer from torch.testing import assert_close +from transformers.modeling_outputs import BaseModelOutputWithPast from colossalai.accelerator import get_accelerator from colossalai.booster import Booster @@ -259,7 +260,6 @@ def run_forward_backward_with_hybrid_plugin( org_output = org_model(**unshard_test_data) org_loss = criterion(org_output) org_loss.backward() - return org_loss, org_output, sharded_loss, sharded_output @@ -302,11 +302,12 @@ def run_forward_backward_with_low_level_zero_plugin( def check_output_hidden_state( - org_output: Tensor, - sharded_output: Tensor, + org_output: BaseModelOutputWithPast, + sharded_output: BaseModelOutputWithPast, stage_manager: Optional[PipelineStageManager] = None, atol: float = 1e-5, rtol: float = 1e-3, + shard_config: Optional[ShardConfig] = None, ): org_hidden_state = org_output.last_hidden_state @@ -315,6 +316,14 @@ def check_output_hidden_state( else: sharded_hidden_state = sharded_output.last_hidden_state + # Check if the output sequence is gathered before cross entropy + if shard_config is not None: + seq_dim = 1 + sp_group = shard_config.sequence_parallel_process_group + sp_size = shard_config.sequence_parallel_size + if org_hidden_state.shape[seq_dim] == sharded_hidden_state.shape[seq_dim] * sp_size: + org_hidden_state = org_hidden_state.chunk(sp_size, dim=seq_dim)[dist.get_rank(sp_group)] + assert_close(org_hidden_state.float(), sharded_hidden_state.float(), atol=atol, rtol=rtol) @@ -374,8 +383,11 @@ def get_grad_tensors_for_check( shard_grad = torch.cat(shard_grad_list, dim=dim) # embedding may be resized when using tensor parallel - if shard_grad.shape[0] > org_grad.shape[0]: - shard_grad = shard_grad[: org_grad.shape[0], :] + try: + if shard_grad.shape[0] > org_grad.shape[0]: + shard_grad = shard_grad[: org_grad.shape[0], :] + except: + pass if verbose and dist.get_rank() == 0: print(f"'{suffix}' grad: {org_grad}, {shard_grad}") @@ -404,9 +416,6 @@ def check_grad( org_grad = getattr_(org_model, suffix).weight.grad shard_grad = getattr_(sharded_model, suffix).weight.grad shard_weight = getattr_(sharded_model, suffix).weight - # if verbose and dist.get_rank() == 0: - # print("shard_weight", shard_weight) - # print("org_grad", org_grad) if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): shard_grad_list = [torch.zeros_like(shard_grad).to("cuda") for _ in range(dist.get_world_size(tp_group))] dist.all_gather(shard_grad_list, shard_grad, tp_group) @@ -440,7 +449,7 @@ def check_all_grad_tensors(check_tensors): "org_grad": tensor to be compared from the original model "shard_grad": tensor to be compared from the sharded model """ - for suffix, check_info in check_tensors.items(): + for idx, (suffix, check_info) in enumerate(check_tensors.items()): org_grad = check_info["org_grad"] shard_grad = check_info["shard_grad"] rtol = check_info["rtol"] diff --git a/tests/test_shardformer/test_model/test_shard_command.py b/tests/test_shardformer/test_model/test_shard_command.py index 3281b50e1..efe5cee2a 100644 --- a/tests/test_shardformer/test_model/test_shard_command.py +++ b/tests/test_shardformer/test_model/test_shard_command.py @@ -271,7 +271,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, ], ) def run_command_test(test_config): - sub_model_zoo = model_zoo.get_sub_registry("transformers_command", "transformers_command_for_casual_lm") + sub_model_zoo = model_zoo.get_sub_registry("transformers_command", "transformers_command_for_causal_lm") for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) @@ -321,7 +321,7 @@ def run_command_test(test_config): ], ) def run_command_3d_test(test_config): - sub_model_zoo = model_zoo.get_sub_registry("transformers_command", "transformers_command_for_casual_lm") + sub_model_zoo = model_zoo.get_sub_registry("transformers_command", "transformers_command_for_causal_lm") for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index 88e54176b..3c66f6097 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -63,7 +63,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, and booster.plugin.shard_config.sequence_parallelism_mode == "all_to_all" ): master2working = sharded_optimizer.get_master_to_working_map() - for p1, p2 in zip(llama_model.parameters(), sharded_optimizer._master_param_groups_of_current_rank[0]): + for (name, p1), p2 in zip( + llama_model.named_parameters(), sharded_optimizer._master_param_groups_of_current_rank[0] + ): working_p = master2working[id(p2)] grads = sharded_optimizer.get_partitioned_gradients_by_param_id(0, id(working_p)) grad_index = ( @@ -73,7 +75,10 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, ) grad = grads[grad_index] sharded_grad = p1.grad.view(-1).chunk(dist.get_world_size())[dist.get_rank()] - assert_close(sharded_grad, grad[: sharded_grad.shape[0]], atol=5e-3, rtol=5e-3, check_dtype=False) + try: + assert_close(sharded_grad, grad[: sharded_grad.shape[0]], atol=5e-3, rtol=5e-3, check_dtype=False) + except Exception as e: + raise RuntimeError(f"Failed to check grad for {name}") from e # Save gradient tensors for comparison between the original model and the sharded model before optimizer step. grads_to_check = {} @@ -114,75 +119,103 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, atol, rtol = 5e-3, 5e-3 if org_model.__class__.__name__ == "LlamaModel": - check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) - + check_output_hidden_state( + org_output, + sharded_output, + stage_manager, + atol=atol, + rtol=rtol, + shard_config=booster.plugin.shard_config, + ) check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) - # check weights if stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True): if test_config["precision"] == "fp32": atol, rtol = 1e-4, 1e-3 else: atol, rtol = 5e-3, 5e-3 - try: - check_weight( - llama_model, - shard_llama_model, - col_layer_for_check, - tp_group, - atol=atol, - rtol=rtol, - dim=1, - verbose=False, - ) - except Exception as e: - print(f"Failed config: {test_config}") - raise e + check_weight( + llama_model, + shard_llama_model, + col_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=1, + verbose=False, + ) # check grads check_all_grad_tensors(grads_to_check) - torch.cuda.empty_cache() @parameterize( "test_config", [ - { # Ulysess + Flash attention + # Double Ring Attention + { + "tp_size": 1, + "pp_size": 1, + "sp_size": 4, + "num_microbatches": 1, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "ring_attn", + "use_lazy_init": True, + "zero_stage": 0, + "precision": "fp16", + "initial_scale": 1, + "inner_ring_size": 2, + }, + # Ring Attention + PP + { "tp_size": 1, "pp_size": 2, "sp_size": 2, "num_microbatches": 2, "enable_sequence_parallelism": True, - "sequence_parallelism_mode": "all_to_all", - "enable_flash_attention": True, + "sequence_parallelism_mode": "ring_attn", "use_lazy_init": True, - "zero_stage": 0, + "zero_stage": 1, "precision": "fp16", "initial_scale": 1, }, - { # Test ring + Flash attention + # Ring Attention + TP + { "tp_size": 2, "pp_size": 1, "sp_size": 2, "num_microbatches": 1, "enable_sequence_parallelism": True, - "sequence_parallelism_mode": "ring", - "enable_flash_attention": True, + "sequence_parallelism_mode": "ring_attn", "use_lazy_init": True, "zero_stage": 2, "precision": "fp16", "initial_scale": 1, }, - { - "tp_size": 1, + { # Ulysess + TP + "tp_size": 2, "pp_size": 1, "sp_size": 2, "num_microbatches": 1, "enable_sequence_parallelism": True, "sequence_parallelism_mode": "all_to_all", + "enable_all_optimization": True, "use_lazy_init": True, - "zero_stage": 1, + "zero_stage": 0, + "precision": "fp16", + "initial_scale": 1, + }, + { # Ulysess + PP + "tp_size": 1, + "pp_size": 2, + "sp_size": 2, + "num_microbatches": 2, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "all_to_all", + "enable_all_optimization": True, + "use_lazy_init": True, + "zero_stage": 0, "precision": "fp16", "initial_scale": 1, }, @@ -192,8 +225,21 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "num_microbatches": 1, "enable_sequence_parallelism": True, "sequence_parallelism_mode": "split_gather", - "enable_flash_attention": False, + "enable_flash_attention": True, + "use_lazy_init": True, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 2, + "pp_size": 1, + "sp_size": 1, + "num_microbatches": 1, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "ring", + "enable_flash_attention": True, "use_lazy_init": True, + "zero_stage": 2, "precision": "fp16", "initial_scale": 1, }, @@ -240,12 +286,13 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, ) def run_llama_test(test_config): sub_model_zoo = model_zoo.get_sub_registry("transformers_llama") - for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + if test_config.get("sequence_parallelism_mode", None) == "ring_attn" and "causal" not in name: + continue try: check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) except Exception as e: - print(f"Failed config: {test_config}") + print(f"Failed config: {test_config}, model name: {name}") raise e clear_layout_converter() Randomizer.reset_index()