From 8e412a548e5366d1c42bcf386bd185091bd0c280 Mon Sep 17 00:00:00 2001 From: Zhongkai Zhao Date: Wed, 3 Apr 2024 17:15:47 +0800 Subject: [PATCH] [shardformer] Sequence Parallelism Optimization (#5533) * sequence parallel optimization * validate sequence parallel in llama (code to be polished) * shardformer api writing * integrate sequence parallel in ShardFormer * fix pp bugs and sp bugs for LlaMa model * integrating ring-based sequence parallelism into ShardFormer * [sequence parallelism]: Add fused megatron function * integrating ring-based sequence parallelism into ShardFormer --------- Co-authored-by: linsj20 * fix bugs when useing sp and flashattention together * fix operation function name * support flash attention for ulysses-style sp * clarify sp process group * fix compatibility bugs in moe plugin * fix fused linear bugs * fix linear layer test * support gpt model all-to-all sp * modify shard data dimension (meant to be dim=-1) * support megtron-style sp and distributed attn for llama model * [shardformer] add megatron sp to llama * support llama7B 128k with distributed attention * [shardformer] robustness enhancement * add block attn * sp mode 1: keep input as a complete sequence * fix sp compatability * finish sp mode 3 support for gpt * using all_to_all_single when batch size is 1 * support mode 2 sp in gpt2 (#5) * [shardformer] add megatron sp to llama * support llama7B 128k with distributed attention * [shardformer] robustness enhancement * add block attn * sp mode 1: keep input as a complete sequence * fix sp compatability * refactor ring implementation * support mode 2 sp in gpt2 * polish code * enable distributed attn mask when using sp mode 2 and 3 in llama * automatically enable flash attn when using sp mode 2 and 3 in llama * inplace attn mask * add zero2 support for sequence parallel * polish code * fix bugs * fix gemini checkpoint io * loose tensor checking atol and rtol * add comment * fix llama layernorm grad * fix zero grad * fix zero grad * fix conflict * update split and gather auto grad func * sequence parallel: inside text split (#6) * polish code (part 1) * polish code (part 2) * polish code (part 2.5) * polish code (part 3) * sequence parallel: inside text split * miscellaneous minor fixes * polish code * fix ulysses style ZeRO * sequence parallel: inside text split * miscellaneous minor fixes * disaggregate sp group and dp group for sp * fix llama and gpt sp * polish code * move ulysses grad sync to ddp (#9) * remove zero_stage and unbind the grad sync for alltoall sp * add 2d group creation test * move ulysses grad sync to ddp * add 2d group creation test * remove useless code * change shard config not to enable sp when enable_all_optimizations * add sp warnings for several model * remove useless code --------- Co-authored-by: linsj20 --- .../booster/plugin/hybrid_parallel_plugin.py | 89 +++- .../plugin/moe_hybrid_parallel_plugin.py | 4 + colossalai/cluster/process_group_mesh.py | 37 +- colossalai/shardformer/layer/__init__.py | 2 + colossalai/shardformer/layer/_operation.py | 389 ++++++++++++++++-- colossalai/shardformer/layer/linear.py | 49 ++- .../shardformer/layer/qkv_fused_linear.py | 44 +- colossalai/shardformer/layer/utils.py | 26 +- colossalai/shardformer/modeling/bert.py | 20 +- colossalai/shardformer/modeling/bloom.py | 18 +- colossalai/shardformer/modeling/chatglm2.py | 22 +- colossalai/shardformer/modeling/gpt2.py | 30 +- colossalai/shardformer/modeling/llama.py | 301 +++++++++++++- colossalai/shardformer/policies/bert.py | 30 +- 
colossalai/shardformer/policies/bloom.py | 27 +- colossalai/shardformer/policies/chatglm2.py | 22 +- colossalai/shardformer/policies/gpt2.py | 40 +- colossalai/shardformer/policies/llama.py | 87 +++- colossalai/shardformer/shard/shard_config.py | 63 ++- colossalai/zero/low_level/low_level_optim.py | 2 +- tests/kit/model_zoo/transformers/gpt.py | 46 ++- tests/kit/model_zoo/transformers/llama.py | 14 +- .../test_gemini_checkpoint_io.py | 17 +- tests/test_cluster/test_process_group_mesh.py | 30 ++ .../test_gpt2_qkv_fused_linear_1d.py | 27 +- .../test_layer/test_linear_1d.py | 42 +- .../test_layer/test_sequence_parallel.py | 178 ++++++++ tests/test_shardformer/test_model/_utils.py | 40 +- .../test_model/test_shard_bert.py | 23 +- .../test_model/test_shard_bloom.py | 22 + .../test_model/test_shard_chatglm2.py | 22 + .../test_model/test_shard_gpt2.py | 22 + .../test_model/test_shard_llama.py | 101 +++++ 33 files changed, 1630 insertions(+), 256 deletions(-) create mode 100644 tests/test_shardformer/test_layer/test_sequence_parallel.py diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index eba7d1c1f..29cec7cfd 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -34,7 +34,8 @@ from colossalai.zero.low_level import LowLevelZeroOptimizer from .pp_plugin_base import PipelinePluginBase -DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2 +DP_AXIS, PP_AXIS, TP_AXIS, SP_AXIS = 0, 1, 2, 3 +SUPPORT_SP_MODE = ["split_gather", "ring", "all_to_all"] PRECISION_TORCH_TYPE = {"fp16": torch.float16, "fp32": torch.float32, "bf16": torch.bfloat16} @@ -53,6 +54,7 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin): shard_config: ShardConfig, dp_group: ProcessGroup, tp_group: ProcessGroup, + sp_group: ProcessGroup, use_ddp: bool, ddp_config: dict, custom_policy: Policy, @@ -61,6 +63,7 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin): self.shard_config = shard_config self.dp_group = dp_group self.tp_group = tp_group + self.sp_group = sp_group self.use_dpp = use_ddp self.require_grad_sync = True @@ -168,13 +171,24 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin): Returns: None """ - if self.tp_group.size() > 1 and self.shard_config.enable_sequence_parallelism: + + if self.shard_config.enable_sequence_parallelism: + if self.shard_config.sequence_parallelism_mode == "all_to_all": + return + + if self.shard_config.sequence_parallelism_mode in ["split_gather", "ring"]: + # If sequence parallelism is enabled and mode is split_gather or ring, gradients are synchronized + # across the tensor parallelism group. + group = self.tp_group + else: + raise ValueError(f"Unknown sequence parallelism mode: {self.shard_config.sequence_parallelism_mode}") + if grads is not None: # Synchronize provided gradient tensors across the tensor parallelism group. - SeqParallelUtils.allreduce_partial_data_grad(tp_group=self.tp_group, grads=grads) + SeqParallelUtils.allreduce_partial_data_grad(process_group=group, grads=grads) else: # Synchronize gradients from the model across the tensor parallelism group. - SeqParallelUtils.allreduce_partial_data_grad(tp_group=self.tp_group, model=self.module) + SeqParallelUtils.allreduce_partial_data_grad(process_group=group, model=self.module) def forward(self, *args, **kwargs): if self.convert_fn is not None: @@ -727,10 +741,9 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer): # Get all working gradients and gradients to be synchronized. 
all_working_grads = _get_all_working_grads() grads_to_sync = _get_grads_to_sync(all_working_grads) - if self.require_grad_sync and grads_to_sync is not None: # Synchronize sequence parallelism gradients if required. - SeqParallelUtils.allreduce_partial_data_grad(tp_group=self.tp_pg, grads=grads_to_sync) + SeqParallelUtils.allreduce_partial_data_grad(process_group=self.tp_pg, grads=grads_to_sync) else: return @@ -891,6 +904,7 @@ class HybridParallelPlugin(PipelinePluginBase): Args: tp_size (int): The size of tensor parallelism. Tensor parallelism will not be used when tp_size is set to 1. pp_size (int): The number of pipeline stages in pipeline parallelism. Pipeline parallelism will not be used when pp_size is set to 1. + sp_size (int): The size of sequence parallelism. precision (str, optional): Specifies the precision of parameters during training. Auto-mixied precision will be used when this argument is set to 'fp16' or 'bf16', otherwise model is trained with 'fp32'. Defaults to 'fp16'. @@ -903,6 +917,7 @@ class HybridParallelPlugin(PipelinePluginBase): enable_flash_attention (bool, optional): Whether to switch on flash attention in Shardformer. Defaults to False. enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Default to False. enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False. + sequence_parallelism_mode (str): The Sequence parallelism mode. Can only be choosed from ["split_gather", "ring", "all_to_all"]. Defaults to "split_gather". enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False. parallel_output (bool): Whether to keep the output parallel when enabling tensor parallelism. Default to True. num_microbatches (int, optional): Number of microbatches when using pipeline parallelism. Defaults to None. 
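The two constructor arguments documented above (sp_size and sequence_parallelism_mode) combine as follows; this is a minimal usage sketch, not part of the patch — distributed initialization (e.g. via colossalai.launch) is omitted and the parallel sizes are hypothetical:

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

# Megatron-style sequence parallelism ("split_gather" or "ring") shards the sequence
# inside the tensor-parallel group, so tp_size > 1 is required and sp_size is ignored.
plugin_sp_tp = HybridParallelPlugin(
    tp_size=2,
    pp_size=1,
    enable_sequence_parallelism=True,
    sequence_parallelism_mode="split_gather",  # or "ring"
)

# Ulysses-style sequence parallelism ("all_to_all") uses its own sp group and is
# incompatible with tensor/pipeline parallelism, so tp_size and pp_size must be 1.
plugin_sp_a2a = HybridParallelPlugin(
    tp_size=1,
    pp_size=1,
    sp_size=4,  # defaults to the world size when omitted
    enable_sequence_parallelism=True,
    sequence_parallelism_mode="all_to_all",
)

booster = Booster(plugin=plugin_sp_a2a)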
@@ -938,6 +953,7 @@ class HybridParallelPlugin(PipelinePluginBase): self, tp_size: int, pp_size: int, + sp_size: int = None, precision: str = "fp16", zero_stage: int = 0, enable_all_optimization: bool = False, @@ -945,6 +961,7 @@ class HybridParallelPlugin(PipelinePluginBase): enable_flash_attention: bool = False, enable_jit_fused: bool = False, enable_sequence_parallelism: bool = False, + sequence_parallelism_mode: str = None, enable_sequence_overlap: bool = False, parallel_output: bool = True, num_microbatches: Optional[int] = None, @@ -976,14 +993,41 @@ class HybridParallelPlugin(PipelinePluginBase): super().__init__() assert ( dist.get_world_size() % (tp_size * pp_size) == 0 - ), f"world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}" + ), f"World size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}" if enable_sequence_parallelism: - assert tp_size > 1, "Sequence parallelism must be enabled when using tensor parallelism" + self.sequence_parallelism_mode = sequence_parallelism_mode if sequence_parallelism_mode is not None else "1" + assert ( + self.sequence_parallelism_mode in SUPPORT_SP_MODE + ), f"Sequence parallelism mode {self.sequence_parallelism_mode} is not in the supported list {SUPPORT_SP_MODE}" + if self.sequence_parallelism_mode in ["split_gather", "ring"]: + assert ( + tp_size > 1 + ), f"Sequence parallelism mode {self.sequence_parallelism_mode} must be enabled when using tensor parallelism" + if sp_size != 1: + warnings.warn( + f"The sp_size will be the same as tp_size in sequence parallelism mode {self.sequence_parallelism_mode}, will ignore the given sequence parallelism size." + ) + self.sp_size = 1 + self.dp_size = dist.get_world_size() // (tp_size * pp_size) + elif self.sequence_parallelism_mode in ["all_to_all"]: + assert ( + tp_size == 1 + ), f"Sequence parallelism mode {self.sequence_parallelism_mode} cannot be used with tensor parallelism" + assert ( + pp_size == 1 + ), f"Sequence parallelism mode {self.sequence_parallelism_mode} cannot be used with pipeline parallelism" + self.sp_size = dist.get_world_size() if sp_size is None else sp_size + self.dp_size = dist.get_world_size() // (self.sp_size * pp_size) + else: + self.dp_size = dist.get_world_size() // (tp_size * pp_size) + assert ( + sp_size == 1 or sp_size is None + ), f"sp_size can only be set to a >1 number when enable_sequence_parallelism is True" + self.sp_size = 1 self.tp_size = tp_size self.pp_size = pp_size - self.dp_size = dist.get_world_size() // (tp_size * pp_size) self.precision = precision self.zero_stage = zero_stage self.cpu_offload = cpu_offload @@ -992,7 +1036,7 @@ class HybridParallelPlugin(PipelinePluginBase): self.enable_flash_attention = enable_flash_attention self.enable_jit_fused = enable_jit_fused self.enable_sequence_parallelism = enable_sequence_parallelism - self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size) + self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size) self.stage_manager = None self.schedule = None self.custom_policy = custom_policy @@ -1033,9 +1077,14 @@ class HybridParallelPlugin(PipelinePluginBase): self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS) self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS) self.pp_group = self.pg_mesh.get_group_along_axis(PP_AXIS) + if self.enable_sequence_parallelism and self.sequence_parallelism_mode in ["split_gather", "ring"]: + self.sp_group = self.pg_mesh.get_group_along_axis(TP_AXIS) + else: + 
self.sp_group = self.pg_mesh.get_group_along_axis(SP_AXIS) self.shard_config = ShardConfig( tensor_parallel_process_group=self.tp_group, + sequence_parallel_process_group=self.sp_group, pipeline_stage_manager=self.stage_manager, enable_tensor_parallelism=self.tp_size > 1, enable_all_optimization=self.enable_all_optimization, @@ -1043,6 +1092,7 @@ class HybridParallelPlugin(PipelinePluginBase): enable_flash_attention=self.enable_flash_attention, enable_jit_fused=self.enable_jit_fused, enable_sequence_parallelism=enable_sequence_parallelism, + sequence_parallelism_mode=sequence_parallelism_mode, enable_sequence_overlap=enable_sequence_overlap, parallel_output=parallel_output, gradient_checkpoint_config=gradient_checkpoint_config, @@ -1113,13 +1163,23 @@ class HybridParallelPlugin(PipelinePluginBase): ) -> Tuple[Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]: param_info = get_param_info(optimizer) if not isinstance(model, ModelWrapper): - use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0 + use_ddp = (self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0) or ( + self.dp_size == 1 + and self.pp_size == 1 + and self.enable_sequence_parallelism + and self.sequence_parallelism_mode == "all_to_all" + ) + if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all": + dp_group = self.pg_mesh.create_group_along_axis([DP_AXIS, SP_AXIS]) + else: + dp_group = self.dp_group model = HybridParallelModule( model, precision=self.precision, shard_config=self.shard_config, - dp_group=self.dp_group, + dp_group=dp_group, tp_group=self.tp_group, + sp_group=self.sp_group, use_ddp=use_ddp, ddp_config=self.ddp_config, custom_policy=self.custom_policy, @@ -1149,7 +1209,8 @@ class HybridParallelPlugin(PipelinePluginBase): tp_process_group=self.tp_group, ) else: - if self.dp_size == 1: + zero_dp_size = dist.get_world_size(dp_group) + if zero_dp_size == 1: warnings.warn( "Use Zero Optimizer when data parallel size is 1 may introduce unnecessary overhead. " "If you are not intended to use cpu_offload, please consider set zero_stage=0." 
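Why the "all_to_all" branch above builds the DDP/ZeRO group from the fused [DP_AXIS, SP_AXIS] axes: outside the attention all-to-all, parameters are replicated across both the dp and sp ranks, so their gradients must be averaged over the combined dp x sp block. Below is a single-process sketch, not part of the patch, of which ranks end up in that fused group; the mesh shape is hypothetical and row-major rank numbering is assumed to match ProcessGroupMesh's flattening:

import itertools

# 8 GPUs laid out as (dp, pp, tp, sp) = (2, 1, 1, 4), matching the axis order above
mesh_shape = (2, 1, 1, 4)

def fused_dp_sp_groups(shape):
    dp, pp, tp, sp = shape
    groups = []
    # one fused gradient-sync group per (pp, tp) coordinate
    for pp_idx, tp_idx in itertools.product(range(pp), range(tp)):
        coords = [(d, pp_idx, tp_idx, s) for d in range(dp) for s in range(sp)]
        ranks = [((d * pp + p) * tp + t) * sp + s for d, p, t, s in coords]
        groups.append(ranks)
    return groups

print(fused_dp_sp_groups(mesh_shape))  # -> [[0, 1, 2, 3, 4, 5, 6, 7]]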
@@ -1161,7 +1222,7 @@ class HybridParallelPlugin(PipelinePluginBase): model, use_pipeline=self.enable_pipeline_parallelism, param_info=param_info, - dp_process_group=self.dp_group, + dp_process_group=dp_group, tp_process_group=self.tp_group, pp_process_group=self.pp_group, verbose=True, diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py index ae372dd03..83888e506 100644 --- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py @@ -254,6 +254,9 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS) self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS) self.pp_group = self.pg_mesh.get_group_along_axis(PP_AXIS) + # TODO: Currently moe only support partially sequence parallel + self.sp_group = self.pg_mesh.get_group_along_axis(TP_AXIS) + self.shard_config = ShardConfig( tensor_parallel_process_group=self.tp_group, pipeline_stage_manager=self.stage_manager, @@ -365,6 +368,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): shard_config=self.shard_config, dp_group=self.dp_group, tp_group=self.tp_group, + sp_group=self.sp_group, use_ddp=use_ddp, ddp_config=self.ddp_config, custom_policy=self.custom_policy, diff --git a/colossalai/cluster/process_group_mesh.py b/colossalai/cluster/process_group_mesh.py index ae3956c69..ccf122695 100644 --- a/colossalai/cluster/process_group_mesh.py +++ b/colossalai/cluster/process_group_mesh.py @@ -161,7 +161,7 @@ class ProcessGroupMesh: @staticmethod def get_coords_along_axis( - base_coord: Tuple[int, ...], axis: int, indices_at_axis: List[int] + base_coord: Tuple[int, ...], axis: Union[int, List[int]], indices_at_axis: Union[List[int], List[List[int]]] ) -> List[Tuple[int, ...]]: """Get coordinates along the given axis. @@ -173,13 +173,28 @@ class ProcessGroupMesh: Returns: List[Tuple[int, ...]]: Coordinates along the axis. """ - coords_in_group = [] - for idx in indices_at_axis: - coords_in_group.append(base_coord[:axis] + (idx,) + base_coord[axis + 1 :]) + if isinstance(axis, int): + axis = [axis,] + assert isinstance(indices_at_axis[0], int) + indices_at_axis = [indices_at_axis,] + + def add_index(base_coord, axis, indices_at_axis): + coords_in_group = [] + for idx in indices_at_axis: + coords_in_group.append(base_coord[:axis] + (idx,) + base_coord[axis + 1 :]) + return coords_in_group + + coords_in_group = [base_coord] + for ax, indices_at_ax in zip(axis, indices_at_axis): + new_coords_in_group = [] + for coords in coords_in_group: + new_coords_in_group += add_index(coords, ax, indices_at_ax) + coords_in_group = new_coords_in_group + return coords_in_group def create_group_along_axis( - self, axis: int, indices_at_axis: Optional[List[int]] = None, backend: Optional[str] = None + self, axis: Union[int, List[int]], indices_at_axis: Optional[Union[List[int], List[List[int]]]] = None, backend: Optional[str] = None ) -> ProcessGroup: """Create all process groups along the given axis, and return the one which the current process belongs to. @@ -191,10 +206,17 @@ class ProcessGroupMesh: Returns: ProcessGroup: The process group along the given axis which the current process belongs to. 
""" - indices_at_axis = indices_at_axis or list(range(self._shape[axis])) + if isinstance(axis, int): + axis = [axis,] + if indices_at_axis is not None: + assert isinstance(indices_at_axis[0], int) + indices_at_axis = [indices_at_axis,] + + indices_at_axis = indices_at_axis or [list(range(self._shape[ax])) for ax in axis] reduced_shape = list(self._shape) # the choices on the axis are reduced to 1, since it's determined by `indices_at_axis` - reduced_shape[axis] = 1 + for ax in axis: + reduced_shape[ax] = 1 target_group = None # use Cartesian product to generate all combinations of coordinates for base_coord in itertools.product(*[range(s) for s in reduced_shape]): @@ -225,4 +247,3 @@ class ProcessGroupMesh: # no need to cache it explicitly, since it will be cached in `create_group_along_axis` return self.create_group_along_axis(axis, indices_at_axis, backend=backend) return self._ranks_to_group[ranks_in_group] - \ No newline at end of file diff --git a/colossalai/shardformer/layer/__init__.py b/colossalai/shardformer/layer/__init__.py index c9b4317a6..0e368dbf9 100644 --- a/colossalai/shardformer/layer/__init__.py +++ b/colossalai/shardformer/layer/__init__.py @@ -1,4 +1,5 @@ from .attn import AttnMaskType, ColoAttention +from ._operation import all_to_all_comm from .dropout import DropoutForParallelInput, DropoutForReplicatedInput from .embedding import Embedding1D, VocabParallelEmbedding1D from .linear import Linear1D_Col, Linear1D_Row @@ -26,4 +27,5 @@ __all__ = [ "ParallelModule", "AttnMaskType", "ColoAttention", + "all_to_all_comm", ] diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index 241770901..82d37bb4c 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -167,6 +167,97 @@ class LinearWithAsyncCommunication(torch.autograd.Function): return grad_input, grad_weight, grad_bias, None, None, None +def _ring_as_gather(func, input_to_gather=None, input_local=None, process_group=None, gather_dim=1, keep_item=False): + # currently only support one single tensor as output + group_size = dist.get_world_size(process_group) + cur_rank = dist.get_rank(process_group) + + # output_tensors = [torch.empty((input_shape[0], input_shape[1], weight_shape[0])) for _ in range(group_size)] + + # initialization of ring communication + recv_rank = cur_rank + 1 if cur_rank + 1 < group_size else 0 + send_rank = cur_rank - 1 if cur_rank > 0 else group_size - 1 + rank_map = list(dist.get_process_group_ranks(process_group)) + recv_rank = rank_map[recv_rank] + send_rank = rank_map[send_rank] + recv_tensors = {} + send_tensors = {} + for k, v in input_to_gather.items(): + recv_tensors[k] = torch.empty_like(v) + send_tensors[k] = v.clone() + + def communicate_step(): + comm_ops = [] + for k in recv_tensors: + comm_ops.append(dist.P2POp(dist.irecv, recv_tensors[k], recv_rank, group=process_group)) + comm_ops.append(dist.P2POp(dist.isend, send_tensors[k], send_rank, group=process_group)) + return dist.batch_isend_irecv(comm_ops) + + def switch_step(): + for k in recv_tensors: + send_tensors[k], recv_tensors[k] = recv_tensors[k], send_tensors[k] + + output_tensors = [] + + handles = communicate_step() + # first round: special case, retrive from local tensor + output_tensors.append(func(**input_to_gather, **input_local)) + for i in range(group_size - 2): + for handle in handles: + handle.wait() + + switch_step() + + handles = communicate_step() + + # actual computation + 
output_tensors.append(func(**send_tensors, **input_local)) + + # final round: special case, no need to send/recv again + for handle in handles: + handle.wait() + output_tensors.append(func(**recv_tensors, **input_local)) + + return torch.cat(output_tensors[group_size - cur_rank :] + output_tensors[: group_size - cur_rank], dim=gather_dim) + + +class _GatherForwardReduceScatterBackward(torch.autograd.Function): + """Gather input from sequence parallel in forward and reduce-scatter gradient in backward + + Args: + input_ (`torch.Tensor`): The input tensor from sequence parallel region. + process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication. + overlap (`bool`): Whther to overlap the all_gather op and gradient calculate in backward. + + """ + + @staticmethod + def forward(ctx, input_, process_group, dim): + ctx.process_group = process_group + ctx.dim = dim + + return _gather(input_, dim, process_group) + + @staticmethod + def backward(ctx, grad_output): + dim = ctx.dim + process_group = ctx.process_group + + # do reduce-scatter + new_shape = list(grad_output.shape) + assert ( + new_shape[dim] % dist.get_world_size(process_group) == 0 + ), f"The dimension to split ({new_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). " + new_shape[dim] = new_shape[dim] // dist.get_world_size(process_group) + grad_list = [ + item.contiguous() for item in torch.chunk(grad_output, dist.get_world_size(process_group), dim=dim) + ] + output = torch.empty(new_shape, dtype=grad_output.dtype, device=grad_output.device) + dist.reduce_scatter(output, grad_list, group=process_group) + + return output, None, None + + class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function): """Gather input from sequence parallel in forward and reduce-scatter gradient in backward @@ -178,7 +269,7 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function): """ @staticmethod - def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap=True): + def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap=True, ring=False): ctx.save_for_backward(input_, weight, bias) ctx.use_bias = bias is not None ctx.process_group = process_group @@ -186,12 +277,25 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function): ctx.dim = dim ctx.overlap = overlap - input_parallel = _gather(input_, dim, process_group) + if ring is True: + input_to_gather = {"input": input_} + input_local = {"weight": weight} - if bias is not None: - output = F.linear(input_parallel, weight, bias) + output = _ring_as_gather( + F.linear, + input_to_gather=input_to_gather, + input_local=input_local, + process_group=process_group, + ) + + if bias is not None: + output += bias else: - output = F.linear(input_parallel, weight) + input_parallel = _gather(input_, dim, process_group) + if bias is not None: + output = F.linear(input_parallel, weight, bias) + else: + output = F.linear(input_parallel, weight) return output @@ -294,11 +398,146 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function): # wait until reduce-scatter finished reducescatter_handle.wait() - return output, grad_weight, grad_bias, None, None, None, None + return output, grad_weight, grad_bias, None, None, None, None, None + + +def _ring_as_reducescatter( + func, input_to_reducescatter=None, input_local=None, process_group=None, reducescatter_dim=1 +): + # currently only 
support one single tensor as output + group_size = dist.get_world_size(process_group) + cur_rank = dist.get_rank(process_group) + + # initialization of ring communication + recv_rank = cur_rank - 1 if cur_rank > 0 else group_size - 1 + send_rank = cur_rank + 1 if cur_rank + 1 < group_size else 0 + rank_map = list(dist.get_process_group_ranks(process_group)) + recv_rank = rank_map[recv_rank] + send_rank = rank_map[send_rank] + input_tensors = [] + for _ in range(group_size): + input_tensors.append({}) + for k, v in input_to_reducescatter.items(): + input_shape = v.shape + assert input_shape[reducescatter_dim] % group_size == 0 + _input_tensors = list(torch.split(v, input_shape[reducescatter_dim] // group_size, dim=reducescatter_dim)) + for i in range(group_size): + input_tensors[i][k] = _input_tensors[i] + input_tensors = input_tensors[cur_rank:] + input_tensors[:cur_rank] + input_tensors.reverse() + + output_tensor = func(**input_tensors[0], **input_local) + recv_tensor = torch.empty_like(output_tensor) + send_tensor = output_tensor.clone() + + def communicate_step(): + recv_op = dist.P2POp(dist.irecv, recv_tensor, recv_rank, group=process_group) + send_op = dist.P2POp(dist.isend, send_tensor, send_rank, group=process_group) + return dist.batch_isend_irecv([recv_op, send_op]) + + handles = communicate_step() + # first round: special case, retrive from local tensor + for i in range(group_size - 2): + # actual computation + output_tensor = func(**input_tensors[i + 1], **input_local) + + for handle in handles: + handle.wait() + output_tensor += recv_tensor + + tmp_tensor = send_tensor + send_tensor = output_tensor + output_tensor = tmp_tensor + + handles = communicate_step() + + # final round: special case, no need to send/recv again + output_tensor = func(**input_tensors[-1], **input_local) + for handle in handles: + handle.wait() + output_tensor += recv_tensor + return output_tensor class _LinearWithReduceScatterForwardGatherBackward(torch.autograd.Function): - """Gather input from sequence parallel in forward and reduce-scatter gradient in backward + """Reduce-scatter input from sequence parallel in forward and gather gradient in backward with ring + + Args: + input_ (`torch.Tensor`): The input tensor from sequence parallel region. + process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication. + overlap (`bool`): Whther to overlap the all_gather op and gradient calculate in backward. + + """ + + @staticmethod + def forward(ctx, input_, weight, bias, process_group, dim, ring): + ctx.save_for_backward(input_, weight, bias) + ctx.use_bias = bias is not None + ctx.process_group = process_group + ctx.dim = dim + + if ring is True: + input_to_reducescatter = {"input": input_} + input_local = {"weight": weight} + + if bias is not None: + input_to_reducescatter["bias"] = bias + + output = _ring_as_reducescatter( + F.linear, + input_to_reducescatter=input_to_reducescatter, + input_local=input_local, + process_group=process_group, + ) + else: + if bias is not None: + partial_output = F.linear(input_, weight, bias) + else: + partial_output = F.linear(input_, weight) + + output_shape = list(partial_output.shape) + assert ( + output_shape[dim] % dist.get_world_size(process_group) == 0 + ), f"The dimension to split ({output_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). 
" + output_shape[dim] = output_shape[dim] // dist.get_world_size(process_group) + + output_list = [ + item.contiguous() for item in torch.chunk(partial_output, dist.get_world_size(process_group), dim=dim) + ] + output = torch.empty(output_shape, dtype=partial_output.dtype, device=partial_output.device).contiguous() + dist.reduce_scatter(output, output_list, group=process_group) + + return output + + @staticmethod + def backward(ctx, grad_output): + input_, weight, bias = ctx.saved_tensors + use_bias = ctx.use_bias + dim = ctx.dim + process_group = ctx.process_group + + # In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm + if use_bias: + bias = bias.view(bias.shape) + + grad_output = _gather(grad_output, dim, process_group) + + # TODO Need to fully optimize + total_input = input_ + grad_input = grad_output.matmul(weight) + grad_output = grad_output.contiguous() + # Convert the tensor shapes to 2D for execution compatibility + if len(grad_output.shape) > 2: + grad_output = grad_output.view(-1, grad_output.shape[-1]) + total_input = total_input.view(-1, total_input.shape[-1]) + grad_weight = grad_output.t().matmul(total_input) + grad_bias = grad_output.sum(dim=0) if use_bias else None + + return grad_input, grad_weight, grad_bias, None, None, None + + +class _ReduceScatterForwardGatherBackward(torch.autograd.Function): + """Reduce-scatter input from sequence parallel in forward and gather gradient in backward Args: input_ (`torch.Tensor`): The input tensor from sequence parallel region. @@ -343,7 +582,7 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function): """ @staticmethod - def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap): + def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring): ctx.save_for_backward(input_, weight, bias) ctx.use_bias = bias is not None ctx.process_group = process_group @@ -351,9 +590,24 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function): ctx.dim = dim ctx.overlap = overlap - input_parallel = _gather(input_, dim, process_group) + if ring is True: + input_to_gather = {} + input_local = {} + input_to_gather["input"] = input_ + input_local["other"] = weight - output = torch.matmul(input_parallel, weight) + output = _ring_as_gather( + torch.matmul, + input_to_gather=input_to_gather, + input_local=input_local, + process_group=process_group, + gather_dim=dim, + ) + + else: + input_parallel = _gather(input_, dim, process_group) + + output = torch.matmul(input_parallel, weight) if bias is not None: output = output + bias @@ -433,7 +687,7 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function): # wait until reduce-scatter finished reducescatter_handle.wait() - return output, grad_weight, grad_bias, None, None, None, None + return output, grad_weight, grad_bias, None, None, None, None, None class _SplitForwardGatherBackward(torch.autograd.Function): @@ -448,14 +702,17 @@ class _SplitForwardGatherBackward(torch.autograd.Function): """ @staticmethod - def forward(ctx, input_, dim, process_group): + def forward(ctx, input_, dim, process_group, grad_scale=None): ctx.process_group = process_group ctx.dim = dim + ctx.grad_scale = grad_scale return _split(input_, dim, process_group) @staticmethod def backward(ctx, grad_output): - return _gather(grad_output, ctx.dim, ctx.process_group), None, None + if ctx.grad_scale is not None: + grad_output = 
grad_output * ctx.grad_scale + return _gather(grad_output, ctx.dim, ctx.process_group), None, None, None class _ReduceForward(torch.autograd.Function): @@ -505,14 +762,50 @@ class _GatherForwardSplitBackward(torch.autograd.Function): """ @staticmethod - def forward(ctx, input_, dim, process_group): + def forward(ctx, input_, dim, process_group, grad_scale=None): ctx.process_group = process_group ctx.dim = dim + ctx.grad_scale = grad_scale return _gather(input_, dim, process_group) @staticmethod def backward(ctx, grad_output): - return _split(grad_output, ctx.dim, ctx.process_group), None, None + if ctx.grad_scale is not None: + grad_output = grad_output * ctx.grad_scale + return _split(grad_output, ctx.dim, ctx.process_group), None, None, None + + +class _AllToAll(torch.autograd.Function): + """All-to-all communication. + + Args: + input_: input matrix + process_group: communication group + scatter_dim: scatter dimension + gather_dim: gather dimension + """ + + @staticmethod + def forward(ctx, input_, process_group, scatter_dim, gather_dim): + ctx.process_group = process_group + ctx.scatter_dim = scatter_dim + ctx.gather_dim = gather_dim + world_size = dist.get_world_size(process_group) + bsz, _, _ = input_.shape + + # using all_to_all_single when batch size is 1 + if bsz == 1: + return _all_to_all_single(input_, world_size, process_group, scatter_dim, gather_dim) + else: + return _all_to_all(input_, world_size, process_group, scatter_dim, gather_dim) + + @staticmethod + def backward(ctx, *grad_output): + process_group = ctx.process_group + scatter_dim = ctx.gather_dim + gather_dim = ctx.scatter_dim + return_grad = _AllToAll.apply(*grad_output, process_group, scatter_dim, gather_dim) + return (return_grad, None, None, None) class HookParameter(torch.autograd.Function): @@ -608,6 +901,40 @@ def _reduce_scatter(input_, dim=1, process_group=None): return output +def _all_to_all(input_, world_size, group, scatter_dim, gather_dim): + input_list = [t.contiguous() for t in torch.tensor_split(input_, world_size, scatter_dim)] + output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)] + dist.all_to_all(output_list, input_list, group=group) + return torch.cat(output_list, dim=gather_dim).contiguous() + + +def _all_to_all_single(input_, seq_world_size, group, scatter_dim, gather_dim): + inp_shape = list(input_.shape) + inp_shape[scatter_dim] = inp_shape[scatter_dim] // seq_world_size + if scatter_dim < 2: + input_t = input_.reshape([seq_world_size, inp_shape[scatter_dim]] + inp_shape[scatter_dim + 1 :]).contiguous() + else: + input_t = ( + input_.reshape([-1, seq_world_size, inp_shape[scatter_dim]] + inp_shape[scatter_dim + 1 :]) + .transpose(0, 1) + .contiguous() + ) + + output = torch.empty_like(input_t) + dist.all_to_all_single(output, input_t, group=group) + + if scatter_dim < 2: + output = output.transpose(0, 1).contiguous() + + return output.reshape( + inp_shape[:gather_dim] + + [ + inp_shape[gather_dim] * seq_world_size, + ] + + inp_shape[gather_dim + 1 :] + ).contiguous() + + def matmul_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce): return MatmulWithAsyncCommunication.apply(input_, weight, bias, process_group, async_grad_allreduce) @@ -617,31 +944,39 @@ def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allre def linear_gather_forward_reducescatter_backward( - input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap + input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, 
ring=False ): return _LinearWithGatherForwardReduceScatterBackward.apply( - input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap + input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring ) -def linear_reducescatter_forward_gather_backward(input_, process_group, dim): - return _LinearWithReduceScatterForwardGatherBackward.apply(input_, process_group, dim) +def gather_forward_reducescatter_backward(input_, process_group, dim): + return _GatherForwardReduceScatterBackward.apply(input_, process_group, dim) + + +def reducescatter_forward_gather_backward(input_, process_group, dim): + return _ReduceScatterForwardGatherBackward.apply(input_, process_group, dim) + + +def linear_reducescatter_forward_gather_backward(input_, weight, bias=None, process_group=None, dim=1, ring=False): + return _LinearWithReduceScatterForwardGatherBackward.apply(input_, weight, bias, process_group, dim, ring) def matmul_gather_forward_reducescatter_backward( - input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap + input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring=False ): return _MatmulWithGatherForwardReduceScatterBackward.apply( - input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap + input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring ) -def gather_forward_split_backward(input_, dim, process_group): - return _GatherForwardSplitBackward.apply(input_, dim, process_group) +def gather_forward_split_backward(input_, dim, process_group, grad_scale=None): + return _GatherForwardSplitBackward.apply(input_, dim, process_group, grad_scale) -def split_forward_gather_backward(input_, dim, process_group): - return _SplitForwardGatherBackward.apply(input_, dim, process_group) +def split_forward_gather_backward(input_, dim, process_group, grad_scale=None): + return _SplitForwardGatherBackward.apply(input_, dim, process_group, grad_scale) def reduce_forward(input_, process_group): @@ -650,3 +985,7 @@ def reduce_forward(input_, process_group): def reduce_backward(input_, process_group): return _ReduceBackward.apply(input_, process_group) + + +def all_to_all_comm(input_, process_group=None, scatter_dim=2, gather_dim=1): + return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim) diff --git a/colossalai/shardformer/layer/linear.py b/colossalai/shardformer/layer/linear.py index eeb0ef399..7c8619ad8 100644 --- a/colossalai/shardformer/layer/linear.py +++ b/colossalai/shardformer/layer/linear.py @@ -23,11 +23,13 @@ from colossalai.tensor.d_tensor.api import ( ) from ._operation import ( + gather_forward_reducescatter_backward, gather_forward_split_backward, linear_gather_forward_reducescatter_backward, linear_reducescatter_forward_gather_backward, linear_with_async_comm, reduce_forward, + reducescatter_forward_gather_backward, split_forward_gather_backward, ) from .parallel_module import ParallelModule @@ -74,7 +76,7 @@ class Linear1D_Col(ParallelModule): device: torch.device = None, process_group: ProcessGroup = None, gather_output: bool = False, - seq_parallel: bool = False, + seq_parallel_mode: str = None, seq_parallel_dim: int = 1, overlap: torch.cuda.Stream = None, skip_bias_add: bool = False, @@ -89,7 +91,7 @@ class Linear1D_Col(ParallelModule): self.in_features = in_features self.out_features = out_features self.gather_output = gather_output - self.seq_parallel = seq_parallel + self.seq_parallel_mode = seq_parallel_mode self.seq_parallel_dim = 
seq_parallel_dim self.overlap = overlap self.skip_bias_add = skip_bias_add @@ -196,12 +198,18 @@ class Linear1D_Col(ParallelModule): # Matrix multiply. bias = self.bias if not self.skip_bias_add else None - if self.seq_parallel: - output_parallel = linear_gather_forward_reducescatter_backward( - input_parallel, self.weight, bias, self.process_group, True, self.seq_parallel_dim, self.overlap - ) - else: + + if self.seq_parallel_mode is None: output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True) + elif self.seq_parallel_mode == "split_gather": + input_parallel = gather_forward_reducescatter_backward( + input_parallel, self.process_group, self.seq_parallel_dim + ) + output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, False) + elif self.seq_parallel_mode == "ring": + output_parallel = linear_gather_forward_reducescatter_backward( + input_parallel, self.weight, bias, self.process_group, True, self.seq_parallel_dim, self.overlap, True + ) if self.gather_output: # All-gather across the partitions. @@ -225,7 +233,8 @@ class Linear1D_Row(ParallelModule): dtype (`torch.dtype`): The dtype of parameters, defaults to None. parallel_input (bool): If set to ``True``, it's assumed that the input is split, defaults to False. process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None. - seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False. + seq_parallel_mode (`str`): The type of sp mode, it will use sequence parallel when `seq_parallel_mode` is not None. Defaults to None. + seq_parallel_dim (`int`): Which dim will sequence parallelism split and gather the sequence. 
skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer, which is preserved for kernel fusion, defaults to False weight_initializer (:class:`typing.Callable`, optional): @@ -245,7 +254,7 @@ class Linear1D_Row(ParallelModule): dtype: torch.dtype = None, device: torch.device = None, process_group: ProcessGroup = None, - seq_parallel: bool = False, + seq_parallel_mode: str = None, seq_parallel_dim: int = 1, parallel_input: bool = True, skip_bias_add: bool = False, @@ -265,7 +274,7 @@ class Linear1D_Row(ParallelModule): self.parallel_input = parallel_input self.skip_bias_add = skip_bias_add self.process_group = process_group - self.seq_parallel = seq_parallel + self.seq_parallel_mode = seq_parallel_mode self.seq_parallel_dim = seq_parallel_dim self.num_partitions = dist.get_world_size(self.process_group) @@ -403,18 +412,26 @@ class Linear1D_Row(ParallelModule): output_parallel_list[i], group=self.process_group, async_op=True ) handle_list.append(handle) - # output_parallel_list[i] = reduce_input(output_parallel_list[i], ParallelMode.PARALLEL_1D) for handle in handle_list: handle.wait() output = torch.cat(output_parallel_list, dim=-1) else: - output_parallel = linear_with_async_comm(input_, self.weight, None, None, False) - if self.seq_parallel: - output = linear_reducescatter_forward_gather_backward( + if self.seq_parallel_mode is None: + output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False) + output = reduce_forward(output_parallel, self.process_group) + elif self.seq_parallel_mode == "split_gather": + output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False) + output = reducescatter_forward_gather_backward( output_parallel, self.process_group, self.seq_parallel_dim ) - else: - output = reduce_forward(output_parallel, self.process_group) + elif self.seq_parallel_mode == "ring": + output = linear_reducescatter_forward_gather_backward( + input_, + self.weight, + process_group=self.process_group, + dim=self.seq_parallel_dim, + ring=True, + ) if not self.skip_bias_add: if self.bias is not None: diff --git a/colossalai/shardformer/layer/qkv_fused_linear.py b/colossalai/shardformer/layer/qkv_fused_linear.py index 12476d050..dc3634238 100644 --- a/colossalai/shardformer/layer/qkv_fused_linear.py +++ b/colossalai/shardformer/layer/qkv_fused_linear.py @@ -25,12 +25,12 @@ from colossalai.tensor.d_tensor.api import ( from ._operation import ( gather_forward_split_backward, - linear_reducescatter_forward_gather_backward, linear_with_async_comm, matmul_gather_forward_reducescatter_backward, matmul_with_async_comm, reduce_backward, reduce_forward, + reducescatter_forward_gather_backward, split_forward_gather_backward, ) from .parallel_module import ParallelModule @@ -150,7 +150,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule): device (`torch.device`): The device of parameters, defaults to None. n_fused (int): The number items fused, defaults to 3 (QKV). process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None. - seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False. + seq_parallel_mode (str): If set to ``None``, it will not use sequence parallel, otherwise will use corresponding mode of sequence parallel, defaults to None. 
gather_output (bool, optional): If true, call all-gather on output and make Y available to all GPUs, otherwise, every GPU will have its output which is :math:`Y_i = XA_i`, defaults to False @@ -175,7 +175,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule): process_group: ProcessGroup = None, async_communication: bool = False, gather_output: bool = False, - seq_parallel: bool = False, + seq_parallel_mode: str = None, overlap: bool = False, skip_bias_add: bool = False, n_fused: int = 3, @@ -190,7 +190,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule): self.in_features = in_features self.out_features = out_features self.gather_output = gather_output - self.seq_parallel = seq_parallel + self.seq_parallel_mode = seq_parallel_mode self.overlap = overlap self.skip_bias_add = skip_bias_add self.device = device @@ -312,17 +312,22 @@ class GPT2FusedLinearConv1D_Col(ParallelModule): # Matrix multiply. bias = self.bias if not self.skip_bias_add else None - if self.seq_parallel: - input_parallel = input_ - output_parallel = matmul_gather_forward_reducescatter_backward( - input_parallel, self.weight, bias, self.process_group, True, 1, self.overlap - ) - else: + if self.seq_parallel_mode is None: # Set up backprop all-reduce. input_parallel = reduce_backward(input_, self.process_group) output_parallel = matmul_with_async_comm( input_parallel, self.weight, bias, self.process_group, self.async_communication ) + elif self.seq_parallel_mode == "split_gather": + input_parallel = input_ + output_parallel = matmul_gather_forward_reducescatter_backward( + input_parallel, self.weight, bias, self.process_group, True, 1, self.overlap + ) + elif self.seq_parallel_mode == "ring": + input_parallel = input_ + output_parallel = matmul_gather_forward_reducescatter_backward( + input_parallel, self.weight, bias, self.process_group, True, 1, self.overlap, True + ) if self.gather_output: # All-gather across the partitions. @@ -347,7 +352,7 @@ class GPT2FusedLinearConv1D_Row(ParallelModule): dtype (`torch.dtype`): The dtype of parameters, defaults to None. parallel_input (bool): If set to ``True``, it's assumed that the input is split, defaults to False. skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer, - seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False. + seq_parallel_mode (str): If set to ``None``, it will not use sequence parallel, otherwise will use corresponding mode of sequence parallel, defaults to None. which is preserved for kernel fusion, defaults to False weight_initializer (:class:`typing.Callable`, optional): The initializer of weight, defaults to kaiming uniform initializer. 
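A single-process shape sketch, not part of the patch, of the "split_gather" path used by the column/row pairs above: the column side all-gathers the sequence-sharded activation before the matmul, and the row side reduce-scatters its partial sums back along the sequence dimension. Collectives are replaced by shape-equivalent local ops and all sizes are hypothetical:

import torch

tp = 4                                               # tensor-parallel (= sequence-parallel) size
bsz, full_seq, hidden = 2, 128, 1024
x_local = torch.randn(bsz, full_seq // tp, hidden)   # input sharded along dim 1 (sequence)

# Column side: all-gather along the sequence dim, then multiply by the column shard
x_full = torch.cat([x_local] * tp, dim=1)            # stands in for the all-gather
w_col = torch.randn(hidden, 3 * hidden // tp)        # fused QKV shard, Conv1D layout (in, out)
y_col = x_full @ w_col                               # [bsz, full_seq, 3 * hidden / tp]

# Row side: multiply by the row shard, then reduce-scatter back along the sequence dim
# (attention is omitted; the slice below is only a stand-in for the rank-local features)
w_row = torch.randn(hidden // tp, hidden)
partial = y_col[..., : hidden // tp] @ w_row         # rank-local partial sum
y_local = partial.chunk(tp, dim=1)[0]                # stands in for the reduce-scatter

print(x_local.shape, y_col.shape, y_local.shape)
# torch.Size([2, 32, 1024]) torch.Size([2, 128, 768]) torch.Size([2, 32, 1024])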
@@ -366,7 +371,7 @@ class GPT2FusedLinearConv1D_Row(ParallelModule): dtype: torch.dtype = None, device: torch.device = None, process_group: ProcessGroup = None, - seq_parallel: bool = False, + seq_parallel_mode: str = None, parallel_input: bool = True, skip_bias_add: bool = False, weight: Optional[Parameter] = None, @@ -385,7 +390,7 @@ class GPT2FusedLinearConv1D_Row(ParallelModule): self.parallel_input = parallel_input self.skip_bias_add = skip_bias_add self.process_group = process_group - self.seq_parallel = seq_parallel + self.seq_parallel_mode = seq_parallel_mode self.num_partitions = dist.get_world_size(self.process_group) if skip_bias_add and not bias: @@ -528,11 +533,15 @@ class GPT2FusedLinearConv1D_Row(ParallelModule): handle.wait() output = torch.cat(output_parallel_list, dim=-1) else: - output_parallel = torch.matmul(input_, self.weight) - if self.seq_parallel: - output = linear_reducescatter_forward_gather_backward(output_parallel, self.process_group, 1) - else: + if self.seq_parallel_mode is None: + output_parallel = torch.matmul(input_, self.weight) output = reduce_forward(output_parallel, self.process_group) + elif self.seq_parallel_mode == "split_gather": + output_parallel = torch.matmul(input_, self.weight) + output = reducescatter_forward_gather_backward(output_parallel, self.process_group, 1) + elif self.seq_parallel_mode == "ring": + output_parallel = torch.matmul(input_, self.weight) + output = reducescatter_forward_gather_backward(output_parallel, self.process_group, 1) if not self.skip_bias_add: if self.bias is not None: @@ -702,7 +711,6 @@ class FusedLinear1D_Col(ParallelModule): # process_group=process_group, # is_transposed=False) # linear_1d.bias.data.copy_(sharded_bias.data) - print(linear_1d.weight.shape) return linear_1d def reset_parameters(self, weight_initializer, bias_initializer) -> None: diff --git a/colossalai/shardformer/layer/utils.py b/colossalai/shardformer/layer/utils.py index 0d2cc1b33..9c6ced445 100644 --- a/colossalai/shardformer/layer/utils.py +++ b/colossalai/shardformer/layer/utils.py @@ -35,17 +35,21 @@ class SeqParallelUtils: return getattr(param, "partial_derived", False) @staticmethod - def allreduce_partial_data_grad(tp_group: ProcessGroup, model: nn.Module = None, grads: List[torch.Tensor] = None): + def allreduce_partial_data_grad( + process_group: ProcessGroup, + model: nn.Module = None, + grads: List[torch.Tensor] = None, + ): """ Allreduce partial derived gradients across the specified process group. This function performs gradient synchronization for parameters that are marked as partially derived in sequence parallelism. Args: - tp_group (ProcessGroup): The process group for gradient synchronization. + process_group (ProcessGroup): The process group for gradient synchronization. model (nn.Module): The model from which gradients will be synchronized. grads (List[torch.Tensor]): The list of gradients to be synchronized. - + only_sp_partial (bool): Whether handle all the parameters or only parameters marked as partial derived. Raises: AssertionError: If both `model` and `grads` are provided or neither is provided. """ @@ -53,22 +57,26 @@ class SeqParallelUtils: assert (model is not None) ^ (grads is not None), "Exactly one of model and grads must be not None." # Get the size of the process group, which determines whether synchronization is needed. 
- tp_size = get_world_size(tp_group) if tp_group is not None else 1 + group_size = get_world_size(process_group) if process_group is not None else 1 - if tp_size == 1: + if group_size == 1: # If the process group size is 1, no synchronization is required. return if model is not None: # If `model` is provided, extract partial derived gradients from the model's parameters. grads = [] + for p in model.parameters(): - if p.grad is not None and SeqParallelUtils.is_sp_partial_derived_param(p): - grads.append(p.grad.data) + if p.grad is not None: + if SeqParallelUtils.is_sp_partial_derived_param(p): + grads.append(p.grad.data) # Flatten and reduce the gradients using the specified process group. + if len(grads) == 0: + return coalesced = _flatten_dense_tensors(grads) - dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=tp_group) + dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=process_group) # Unflatten the synchronized gradients and update the model's gradients. for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)): @@ -76,7 +84,7 @@ class SeqParallelUtils: else: # If `grads` are provided explicitly, synchronize those gradients directly. coalesced = _flatten_dense_tensors(grads) - dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=tp_group) + dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=process_group) for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)): buf.copy_(synced) diff --git a/colossalai/shardformer/modeling/bert.py b/colossalai/shardformer/modeling/bert.py index 7411e1d0e..0838fcee6 100644 --- a/colossalai/shardformer/modeling/bert.py +++ b/colossalai/shardformer/modeling/bert.py @@ -186,13 +186,14 @@ class BertPipelineForwards: # split the input tensor along sequence dimension # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] if shard_config is not None and shard_config.enable_sequence_parallelism: - hidden_states = split_forward_gather_backward( - hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group - ) - if encoder_hidden_states is not None: - encoder_hidden_states = split_forward_gather_backward( - encoder_hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group + if shard_config.sequence_parallelism_mode == "split_gather": + hidden_states = split_forward_gather_backward( + hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group ) + if encoder_hidden_states is not None: + encoder_hidden_states = split_forward_gather_backward( + encoder_hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group + ) for idx, encoder_layer in enumerate(self.encoder.layer[start_idx:end_idx], start=start_idx): if stage_manager.is_first_stage() and idx == 0: @@ -240,9 +241,10 @@ class BertPipelineForwards: # When sequence parallelism done, gather the output tensor in forward and split it in backward if shard_config is not None and shard_config.enable_sequence_parallelism: - hidden_states = gather_forward_split_backward( - hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group - ) + if shard_config.sequence_parallelism_mode == "split_gather": + hidden_states = gather_forward_split_backward( + hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group + ) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) diff --git a/colossalai/shardformer/modeling/bloom.py b/colossalai/shardformer/modeling/bloom.py index d94c30d29..fe70376e1 100644 --- 
a/colossalai/shardformer/modeling/bloom.py +++ b/colossalai/shardformer/modeling/bloom.py @@ -213,10 +213,11 @@ class BloomPipelineForwards: # split the input tensor along sequence dimension # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] - if shard_config.enable_sequence_parallelism: - hidden_states = split_forward_gather_backward( - hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group - ) + if shard_config and shard_config.enable_sequence_parallelism: + if shard_config.sequence_parallelism_mode == "split_gather": + hidden_states = split_forward_gather_backward( + hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group + ) start_idx, end_idx = stage_index[0], stage_index[1] for i, (block, layer_past) in enumerate( @@ -261,10 +262,11 @@ class BloomPipelineForwards: all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) # When sequence parallelism done, gather the output tensor in forward and split it in backward - if shard_config.enable_sequence_parallelism: - hidden_states = gather_forward_split_backward( - hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group - ) + if shard_config and shard_config.enable_sequence_parallelism: + if shard_config.sequence_parallelism_mode == "split_gather": + hidden_states = gather_forward_split_backward( + hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group + ) if stage_manager.is_last_stage(): # Add last hidden state diff --git a/colossalai/shardformer/modeling/chatglm2.py b/colossalai/shardformer/modeling/chatglm2.py index a3e000e6e..9207b34d0 100644 --- a/colossalai/shardformer/modeling/chatglm2.py +++ b/colossalai/shardformer/modeling/chatglm2.py @@ -191,12 +191,11 @@ class ChatGLMPipelineForwards: all_hidden_states = () if output_hidden_states else None start_idx, end_idx = stage_index[0], stage_index[1] - if shard_config.enable_sequence_parallelism: - hidden_states = split_forward_gather_backward( - hidden_states, - dim=0, - process_group=shard_config.tensor_parallel_process_group, - ) + if shard_config and shard_config.enable_sequence_parallelism: + if shard_config.sequence_parallelism_mode == "split_gather": + hidden_states = split_forward_gather_backward( + hidden_states, dim=0, process_group=shard_config.tensor_parallel_process_group + ) for idx in range(start_idx, end_idx): layer = self.encoder._get_layer(idx) if output_hidden_states: @@ -222,12 +221,11 @@ class ChatGLMPipelineForwards: if use_cache: presents = presents + (kv_cache,) - if shard_config.enable_sequence_parallelism: - hidden_states = gather_forward_split_backward( - hidden_states, - dim=0, - process_group=shard_config.tensor_parallel_process_group, - ) + if shard_config and shard_config.enable_sequence_parallelism: + if shard_config.sequence_parallelism_mode == "split_gather": + hidden_states = gather_forward_split_backward( + hidden_states, dim=0, process_group=shard_config.tensor_parallel_process_group + ) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) if stage_manager.is_last_stage(): diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index ea22cfb15..1306c8aa6 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -218,12 +218,13 @@ class GPT2PipelineForwards: # split the input tensor along sequence dimension # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] - if 
shard_config.enable_sequence_parallelism: - hidden_states = split_forward_gather_backward( - hidden_states, - dim=1, - process_group=shard_config.tensor_parallel_process_group, - ) + if shard_config and shard_config.enable_sequence_parallelism: + if shard_config.sequence_parallelism_mode == "split_gather": + hidden_states = split_forward_gather_backward( + hidden_states, + dim=1, + process_group=shard_config.tensor_parallel_process_group, + ) # Going through held blocks. start_idx, end_idx = stage_index[0], stage_index[1] @@ -278,12 +279,13 @@ class GPT2PipelineForwards: all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) # When sequence parallelism done, gather the output tensor in forward and split it in backward - if shard_config.enable_sequence_parallelism: - hidden_states = gather_forward_split_backward( - hidden_states, - dim=1, - process_group=shard_config.tensor_parallel_process_group, - ) + if shard_config and shard_config.enable_sequence_parallelism: + if shard_config.sequence_parallelism_mode == "split_gather": + hidden_states = gather_forward_split_backward( + hidden_states, + dim=1, + process_group=shard_config.tensor_parallel_process_group, + ) if stage_manager.is_last_stage(): hidden_states = self.ln_f(hidden_states) @@ -1141,7 +1143,7 @@ def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig): hidden_states = split_forward_gather_backward( hidden_states, dim=1, - process_group=shard_config.tensor_parallel_process_group, + process_group=shard_config.sequence_parallel_process_group, ) for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): @@ -1208,7 +1210,7 @@ def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig): hidden_states = gather_forward_split_backward( hidden_states, dim=1, - process_group=shard_config.tensor_parallel_process_group, + process_group=shard_config.sequence_parallel_process_group, ) hidden_states = self.ln_f(hidden_states) diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index eb421c92b..0f1b4ad0a 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -1,18 +1,32 @@ +import math import warnings from typing import List, Optional, Tuple, Union import torch import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from transformers.modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast, ) -from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel +from transformers.models.llama.modeling_llama import ( + LlamaForCausalLM, + LlamaForSequenceClassification, + LlamaModel, + apply_rotary_pos_emb, + repeat_kv, +) from transformers.utils import logging from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer.layer._operation import ( + all_to_all_comm, + gather_forward_split_backward, + split_forward_gather_backward, +) from colossalai.shardformer.shard import ShardConfig from ..layer import ColoAttention, cross_entropy_1d @@ -438,7 +452,7 @@ class LlamaPipelineForwards: return {"hidden_states": hidden_states} -def get_llama_flash_attention_forward(shard_config: ShardConfig): +def get_llama_flash_attention_forward(shard_config, sp_mode, sp_group, sp_size): from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb llama_version = 2 @@ 
-459,18 +473,30 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig): **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() + + if sp_mode in ["split_gather", "ring"]: + q_len *= sp_size assert q_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4." - query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # sp: all-to-all communication when using sequence parallelism + if sp_mode == "all_to_all": + query_states = all_to_all_comm(query_states, sp_group) + key_states = all_to_all_comm(key_states, sp_group) + value_states = all_to_all_comm(value_states, sp_group) + bsz, q_len, _ = query_states.size() + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) kv_seq_len = key_states.shape[-2] if past_key_value is not None: kv_seq_len += past_key_value[0].shape[-2] - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: @@ -490,6 +516,9 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig): attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + # sp: all-to-all communication when using sequence parallelism + if sp_mode == "all_to_all": + attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2) attn_output = self.o_proj(attn_output) return attn_output, None, past_key_value @@ -726,3 +755,261 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig): ) return forward + + +def get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group): + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + # sp: modify sp_len when sequence parallel mode is ring + if sp_mode in ["split_gather", "ring"]: + q_len *= sp_size + if self.config.pretraining_tp > 1: + key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp + query_slices = self.q_proj.weight.split( + (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 + ) + key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) + value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) + + query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] + query_states = torch.cat(query_states, dim=-1) + + key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] + key_states =
torch.cat(key_states, dim=-1) + + value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] + value_states = torch.cat(value_states, dim=-1) + + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # sp: all-to-all communication when using sequence parallelism + if sp_mode == "all_to_all": + query_states = all_to_all_comm(query_states, sp_group) + key_states = all_to_all_comm(key_states, sp_group) + value_states = all_to_all_comm(value_states, sp_group) + bsz, q_len, _ = query_states.size() + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + # sp: all-to-all communication when using sequence parallelism + if sp_mode == "all_to_all": + attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim) + attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2) + else: + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + if self.config.pretraining_tp > 1: + attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) + o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) + attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) + else: + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + return attn_output, attn_weights, past_key_value + + return forward + +
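Note on the "all_to_all" branches above: they call all_to_all_comm from colossalai/shardformer/layer/_operation.py, whose implementation is outside this hunk. The snippet below is only a hedged reference sketch of the layout change such a call performs in Ulysses-style sequence parallelism, written against plain torch.distributed; the helper name and details are illustrative assumptions, not the patched code.

import torch
import torch.distributed as dist

def ulysses_all_to_all_sketch(x: torch.Tensor, group: dist.ProcessGroup, scatter_dim: int = 2, gather_dim: int = 1) -> torch.Tensor:
    # Per-rank q/k/v input : [bsz, seq_len / sp_size, num_heads * head_dim]
    # Per-rank q/k/v output: [bsz, seq_len, (num_heads / sp_size) * head_dim]
    sp_size = dist.get_world_size(group)
    # Split the scatter dimension (the head/hidden dim for q/k/v) into one chunk per rank.
    send_chunks = [t.contiguous() for t in x.chunk(sp_size, dim=scatter_dim)]
    recv_chunks = [torch.empty_like(send_chunks[0]) for _ in range(sp_size)]
    # Exchange chunks so every rank ends up with its own head slice of the full sequence.
    dist.all_to_all(recv_chunks, send_chunks, group=group)
    # Reassemble the gather dimension (the full sequence for q/k/v) from the received chunks.
    return torch.cat(recv_chunks, dim=gather_dim).contiguous()

Under this reading, the second call site, all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2), applies the inverse mapping: it scatters the now-complete sequence dimension and gathers the heads back, so attn_output returns to a [bsz, seq_len / sp_size, hidden_size] shard before o_proj.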
+def get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group): + logger = logging.get_logger(__name__) + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + seq_length_with_past = seq_length + past_key_values_length = 0 + + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + # modify past_key_values_length when using sequence parallel + past_key_values_length *= sp_size + seq_length_with_past = seq_length_with_past + past_key_values_length + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if sp_mode in ["ring", "split_gather"]: + inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group) + elif sp_mode == "all_to_all": + inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size) + + if attention_mask is None: + attention_mask = torch.ones( + (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device + ) + + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, attention_mask.shape, inputs_embeds, past_key_values_length + ) + + hidden_states = inputs_embeds + + if (self.gradient_checkpointing or sp_mode in ["ring", "all_to_all"]) and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + for idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if (self.gradient_checkpointing or sp_mode in ["ring", "all_to_all"]) and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, past_key_value, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + attention_mask, + position_ids, + ) + + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + if sp_mode == "ring" or sp_mode == "split_gather": + hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group) + elif sp_mode == "all_to_all": + hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + return forward diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index cd7bdcdd6..0a61d8cff 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -1,3 +1,4 @@ +import warnings from functools import partial from typing import Callable, Dict, List @@ -66,8 +67,17 @@ class BertPolicy(Policy): else: norm_cls = col_nn.LayerNorm - use_sequence_parallel = self.shard_config.enable_sequence_parallelism + sp_mode = self.shard_config.sequence_parallelism_mode if self.shard_config.enable_sequence_parallelism else None + assert sp_mode != "all_to_all", "all_to_all sequence parallelism is not supported for Bert" + if sp_mode == "ring": + warnings.warn( + f"For Bert, sequence parallelism currently does not support mode {sp_mode}; falling back to split_gather" + ) + sp_mode = "split_gather" + overlap = self.shard_config.enable_sequence_overlap + sp_partial_derived = sp_mode == "split_gather" + if self.shard_config.enable_tensor_parallelism: policy[BertLayer] = ModulePolicyDescription( attribute_replacement={ @@ -85,7 +95,7 @@ class BertPolicy(Policy): suffix="attention.self.query", target_module=col_nn.Linear1D_Col, kwargs={ - "seq_parallel": use_sequence_parallel, + "seq_parallel_mode": sp_mode, "overlap": overlap, }, ), @@ -93,7 +103,7 @@ class BertPolicy(Policy): suffix="attention.self.key", target_module=col_nn.Linear1D_Col, kwargs={ - "seq_parallel": use_sequence_parallel, + "seq_parallel_mode": sp_mode, "overlap":
overlap, }, ), @@ -101,7 +111,7 @@ class BertPolicy(Policy): suffix="attention.self.value", target_module=col_nn.Linear1D_Col, kwargs={ - "seq_parallel": use_sequence_parallel, + "seq_parallel_mode": sp_mode, "overlap": overlap, }, ), @@ -112,7 +122,7 @@ class BertPolicy(Policy): SubModuleReplacementDescription( suffix="attention.output.dense", target_module=col_nn.Linear1D_Row, - kwargs={"seq_parallel": use_sequence_parallel}, + kwargs={"seq_parallel_mode": sp_mode}, ), SubModuleReplacementDescription( suffix="attention.output.dropout", @@ -122,14 +132,14 @@ class BertPolicy(Policy): suffix="intermediate.dense", target_module=col_nn.Linear1D_Col, kwargs={ - "seq_parallel": use_sequence_parallel, + "seq_parallel_mode": sp_mode, "overlap": overlap, }, ), SubModuleReplacementDescription( suffix="output.dense", target_module=col_nn.Linear1D_Row, - kwargs={"seq_parallel": use_sequence_parallel}, + kwargs={"seq_parallel_mode": sp_mode}, ), SubModuleReplacementDescription( suffix="output.dropout", @@ -151,7 +161,7 @@ class BertPolicy(Policy): ] ) - if use_sequence_parallel: + if sp_mode == "split_gather": self.append_or_create_method_replacement( description={"forward": bert_sequence_parallel_forward_fn(self.shard_config)}, policy=policy, @@ -165,12 +175,12 @@ class BertPolicy(Policy): SubModuleReplacementDescription( suffix="attention.output.LayerNorm", target_module=norm_cls, - kwargs={"sp_partial_derived": use_sequence_parallel}, + kwargs={"sp_partial_derived": sp_partial_derived}, ), SubModuleReplacementDescription( suffix="output.LayerNorm", target_module=norm_cls, - kwargs={"sp_partial_derived": use_sequence_parallel}, + kwargs={"sp_partial_derived": sp_partial_derived}, ), ], policy=policy, diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py index 55b69d5f0..2becadc3f 100644 --- a/colossalai/shardformer/policies/bloom.py +++ b/colossalai/shardformer/policies/bloom.py @@ -1,3 +1,4 @@ +import warnings from functools import partial from typing import Callable, Dict, List @@ -55,8 +56,18 @@ class BloomPolicy(Policy): norm_cls = col_nn.FusedLayerNorm else: norm_cls = col_nn.LayerNorm - use_sequence_parallel = self.shard_config.enable_sequence_parallelism + + sp_mode = self.shard_config.sequence_parallelism_mode if self.shard_config.enable_sequence_parallelism else None + assert sp_mode != "all_to_all", "all_to_all sequence parallelism is not supported for BLOOM" + if sp_mode == "ring": + warnings.warn( + f"For BLOOM, sequence parallelism currently does not support mode {sp_mode}; falling back to split_gather" + ) + sp_mode = "split_gather" + overlap = self.shard_config.enable_sequence_overlap + sp_partial_derived = sp_mode == "split_gather" + if self.shard_config.enable_tensor_parallelism: policy[BloomBlock] = ModulePolicyDescription( attribute_replacement={ @@ -70,12 +81,12 @@ class BloomPolicy(Policy): SubModuleReplacementDescription( suffix="self_attention.query_key_value", target_module=col_nn.Linear1D_Col, - kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap}, + kwargs={"seq_parallel_mode": sp_mode, "overlap": overlap}, ), SubModuleReplacementDescription( suffix="self_attention.dense", target_module=col_nn.Linear1D_Row, - kwargs={"seq_parallel": use_sequence_parallel}, + kwargs={"seq_parallel_mode": sp_mode}, ), SubModuleReplacementDescription( suffix="self_attention.attention_dropout", @@ -84,12 +95,12 @@ class BloomPolicy(Policy): SubModuleReplacementDescription( suffix="mlp.dense_h_to_4h", target_module=col_nn.Linear1D_Col, -
kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap}, + kwargs={"seq_parallel_mode": sp_mode, "overlap": overlap}, ), SubModuleReplacementDescription( suffix="mlp.dense_4h_to_h", target_module=col_nn.Linear1D_Row, - kwargs={"seq_parallel": use_sequence_parallel}, + kwargs={"seq_parallel_mode": sp_mode}, ), ], ) @@ -132,19 +143,19 @@ class BloomPolicy(Policy): SubModuleReplacementDescription( suffix="input_layernorm", target_module=norm_cls, - kwargs={"sp_partial_derived": use_sequence_parallel}, + kwargs={"sp_partial_derived": sp_partial_derived}, ), SubModuleReplacementDescription( suffix="post_attention_layernorm", target_module=norm_cls, - kwargs={"sp_partial_derived": use_sequence_parallel}, + kwargs={"sp_partial_derived": sp_partial_derived}, ), ], policy=policy, target_key=BloomBlock, ) - if use_sequence_parallel: + if sp_mode == "split_gather": self.append_or_create_method_replacement( description={"forward": get_bloom_sequence_parallel_forward_fn(self.shard_config)}, policy=policy, diff --git a/colossalai/shardformer/policies/chatglm2.py b/colossalai/shardformer/policies/chatglm2.py index 0830d85f1..dabc14bff 100644 --- a/colossalai/shardformer/policies/chatglm2.py +++ b/colossalai/shardformer/policies/chatglm2.py @@ -1,3 +1,4 @@ +import warnings from functools import partial from typing import Callable, Dict, List, Union @@ -55,8 +56,17 @@ class ChatGLMPolicy(Policy): norm_cls = col_nn.RMSNorm else: norm_cls = col_nn.LayerNorm - use_sequence_parallel = self.shard_config.enable_sequence_parallelism + + sp_mode = self.shard_config.sequence_parallelism_mode if self.shard_config.enable_sequence_parallelism else None + assert sp_mode != "all_to_all", "all_to_all sequence parallelism is not supported for ChatGLM2" + if sp_mode == "ring": + warnings.warn( + f"For ChatGLM2, sequence parallelism is currently not support mode {sp_mode}, will set to be split_gather" + ) + sp_mode = "split_gather" overlap = self.shard_config.enable_sequence_overlap + sp_partial_derived = sp_mode == "split_gather" + if self.shard_config.enable_tensor_parallelism: policy[ChatGLMModel] = ModulePolicyDescription( attribute_replacement={}, @@ -91,12 +101,12 @@ class ChatGLMPolicy(Policy): SubModuleReplacementDescription( suffix="self_attention.query_key_value", target_module=col_nn.Linear1D_Col, - kwargs={"seq_parallel": use_sequence_parallel, "seq_parallel_dim": 0, "overlap": overlap}, + kwargs={"seq_parallel_mode": sp_mode, "seq_parallel_dim": 0, "overlap": overlap}, ), SubModuleReplacementDescription( suffix="self_attention.dense", target_module=col_nn.Linear1D_Row, - kwargs={"seq_parallel": use_sequence_parallel, "seq_parallel_dim": 0}, + kwargs={"seq_parallel_mode": sp_mode, "seq_parallel_dim": 0}, ), SubModuleReplacementDescription( suffix="self_attention.core_attention.attention_dropout", @@ -110,12 +120,12 @@ class ChatGLMPolicy(Policy): SubModuleReplacementDescription( suffix="input_layernorm", target_module=norm_cls, - kwargs={"sp_partial_derived": use_sequence_parallel}, + kwargs={"sp_partial_derived": sp_partial_derived}, ), SubModuleReplacementDescription( suffix="post_attention_layernorm", target_module=norm_cls, - kwargs={"sp_partial_derived": use_sequence_parallel}, + kwargs={"sp_partial_derived": sp_partial_derived}, ), ], policy=policy, @@ -145,7 +155,7 @@ class ChatGLMPolicy(Policy): ) # use sequence parallel - if use_sequence_parallel: + if sp_mode == "split_gather": self.append_or_create_method_replacement( description={"forward": 
get_chatglm_sequence_parallel_forward_fn(self.shard_config)}, policy=policy, diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index 4bcac3951..380a432dc 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -1,3 +1,4 @@ +import warnings from functools import partial from typing import Callable, Dict, List @@ -50,8 +51,25 @@ class GPT2Policy(Policy): norm_cls = col_nn.FusedLayerNorm else: norm_cls = col_nn.LayerNorm - use_sequence_parallel = self.shard_config.enable_sequence_parallelism + + sp_mode = self.shard_config.sequence_parallelism_mode if self.shard_config.enable_sequence_parallelism else None + assert sp_mode != "all_to_all", "all_to_all sequence parallelism is not supported for GPT2" + if sp_mode == "ring": + warnings.warn( + f"For GPT2, sequence parallelism currently does not support mode {sp_mode}; falling back to split_gather" + ) + sp_mode = "split_gather" overlap = self.shard_config.enable_sequence_overlap + sp_partial_derived = sp_mode in ["split_gather", "ring"] + use_flash_attention = self.shard_config.enable_flash_attention + # todo: currently sp cannot be used with flashattention + if sp_mode in ["split_gather", "ring", "all_to_all"]: + if use_flash_attention: + warnings.warn( + f"Sequence parallelism mode {sp_mode} cannot be used with FlashAttention, will disable FlashAttention automatically." + ) + self.shard_config.enable_flash_attention = False + use_flash_attention = False if self.shard_config.enable_tensor_parallelism: policy[GPT2Model] = ModulePolicyDescription( sub_module_replacement=[ @@ -78,7 +96,7 @@ class GPT2Policy(Policy): target_module=col_nn.GPT2FusedLinearConv1D_Col, kwargs={ "n_fused": 3, - "seq_parallel": use_sequence_parallel, + "seq_parallel_mode": sp_mode, "overlap": overlap, }, ), @@ -86,7 +104,7 @@ class GPT2Policy(Policy): suffix="attn.c_proj", target_module=col_nn.GPT2FusedLinearConv1D_Row, kwargs={ - "seq_parallel": use_sequence_parallel, + "seq_parallel_mode": sp_mode, }, ), SubModuleReplacementDescription( @@ -94,14 +112,16 @@ class GPT2Policy(Policy): target_module=col_nn.GPT2FusedLinearConv1D_Col, kwargs={ "n_fused": 1, - "seq_parallel": use_sequence_parallel, + "seq_parallel_mode": sp_mode, "overlap": overlap, }, ), SubModuleReplacementDescription( suffix="mlp.c_proj", target_module=col_nn.GPT2FusedLinearConv1D_Row, - kwargs={"seq_parallel": use_sequence_parallel}, + kwargs={ + "seq_parallel_mode": sp_mode, + }, ), SubModuleReplacementDescription( suffix="attn.attn_dropout", @@ -133,25 +153,25 @@ class GPT2Policy(Policy): SubModuleReplacementDescription( suffix="ln_1", target_module=norm_cls, - kwargs={"sp_partial_derived": use_sequence_parallel}, + kwargs={"sp_partial_derived": sp_partial_derived}, ), SubModuleReplacementDescription( suffix="ln_2", target_module=norm_cls, - kwargs={"sp_partial_derived": use_sequence_parallel}, + kwargs={"sp_partial_derived": sp_partial_derived}, ), SubModuleReplacementDescription( suffix="ln_cross_attn", target_module=norm_cls, ignore_if_not_exist=True, - kwargs={"sp_partial_derived": use_sequence_parallel}, + kwargs={"sp_partial_derived": sp_partial_derived}, ), ], policy=policy, target_key=GPT2Block, ) - if self.shard_config.enable_flash_attention: + if use_flash_attention: self.append_or_create_method_replacement( description={ "forward": get_gpt2_flash_attention_forward(), @@ -164,7 +184,7 @@ class GPT2Policy(Policy): "forward": get_gpt_model_forward_for_flash_attn(self.shard_config) } - if
self.shard_config.enable_sequence_parallelism: + if sp_mode is not None: policy[GPT2Model].method_replacement = {"forward": gpt2_sequence_parallel_forward_fn(self.shard_config)} return policy diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index 18d79f84a..bb4551b2c 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -12,6 +12,8 @@ from ..modeling.llama import ( LlamaPipelineForwards, get_llama_flash_attention_forward, get_llama_model_forward_for_flash_attn, + get_llama_seq_parallel_attention_forward, + get_llama_seq_parallel_model_forward, get_lm_forward_with_dist_cross_entropy, ) from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription @@ -45,9 +47,74 @@ class LlamaPolicy(Policy): else: norm_cls = RMSNorm - if self.shard_config.enable_sequence_parallelism: + if self.pipeline_stage_manager is not None: self.shard_config.enable_sequence_parallelism = False - warnings.warn("Llama doesn't support sequence parallelism now, will ignore the sequence parallelism flag.") + self.shard_config.enable_sequence_overlap = False + self.shard_config.sequence_parallelism_mode = None + warnings.warn( + f"For llama, sequence parallelism is currently not compatible with pipeline parallelism, so it will be disabled" + ) + sp_mode = self.shard_config.sequence_parallelism_mode if self.shard_config.enable_sequence_parallelism else None + sp_size = self.shard_config.sequence_parallel_size if self.shard_config.enable_sequence_parallelism else None + sp_group = ( + self.shard_config.sequence_parallel_process_group if self.shard_config.enable_sequence_parallelism else None + ) + sp_partial_derived = sp_mode in ["split_gather", "ring"] + + use_flash_attention = self.shard_config.enable_flash_attention + # Currently sp cannot be used with flashattention + if sp_mode in ["split_gather", "ring", "all_to_all"]: + if use_flash_attention: + warnings.warn( + f"Sequence parallelism mode {sp_mode} cannot be used with FlashAttention, will disable FlashAttention automatically."
+ ) + use_flash_attention = False + + if sp_mode in ["split_gather", "ring"]: + self.append_or_create_method_replacement( + description={ + "forward": get_llama_seq_parallel_model_forward( + sp_mode=sp_mode, sp_size=sp_size, sp_group=sp_group + ), + }, + policy=policy, + target_key=LlamaModel, + ) + self.append_or_create_method_replacement( + description={ + "forward": get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group), + }, + policy=policy, + target_key=LlamaAttention, + ) + elif sp_mode == "all_to_all": + decoder_attribute_replacement = { + "num_heads": self.model.config.num_attention_heads // sp_size, + } + if getattr(self.model.config, "num_key_value_heads", False): + decoder_attribute_replacement["num_key_value_heads"] = self.model.config.num_key_value_heads // sp_size + + policy[LlamaAttention] = ModulePolicyDescription( + attribute_replacement=decoder_attribute_replacement, + ) + self.append_or_create_method_replacement( + description={ + "forward": get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group), + }, + policy=policy, + target_key=LlamaAttention, + ) + self.append_or_create_method_replacement( + description={ + "forward": get_llama_seq_parallel_model_forward( + sp_mode=sp_mode, + sp_size=sp_size, + sp_group=sp_group, + ), + }, + policy=policy, + target_key=LlamaModel, + ) if self.shard_config.enable_tensor_parallelism: decoder_attribute_replacement = { @@ -65,30 +132,37 @@ class LlamaPolicy(Policy): SubModuleReplacementDescription( suffix="self_attn.q_proj", target_module=Linear1D_Col, + kwargs=dict(seq_parallel_mode=sp_mode), ), SubModuleReplacementDescription( suffix="self_attn.k_proj", target_module=Linear1D_Col, + kwargs=dict(seq_parallel_mode=sp_mode), ), SubModuleReplacementDescription( suffix="self_attn.v_proj", target_module=Linear1D_Col, + kwargs=dict(seq_parallel_mode=sp_mode), ), SubModuleReplacementDescription( suffix="self_attn.o_proj", target_module=Linear1D_Row, + kwargs=dict(seq_parallel_mode=sp_mode), ), SubModuleReplacementDescription( suffix="mlp.gate_proj", target_module=Linear1D_Col, + kwargs=dict(seq_parallel_mode=sp_mode), ), SubModuleReplacementDescription( suffix="mlp.up_proj", target_module=Linear1D_Col, + kwargs=dict(seq_parallel_mode=sp_mode), ), SubModuleReplacementDescription( suffix="mlp.down_proj", target_module=Linear1D_Row, + kwargs=dict(seq_parallel_mode=sp_mode), ), ], ) @@ -108,10 +182,12 @@ class LlamaPolicy(Policy): SubModuleReplacementDescription( suffix="input_layernorm", target_module=norm_cls, + kwargs={"sp_partial_derived": sp_partial_derived}, ), SubModuleReplacementDescription( suffix="post_attention_layernorm", target_module=norm_cls, + kwargs={"sp_partial_derived": sp_partial_derived}, ), ], policy=policy, @@ -122,16 +198,17 @@ class LlamaPolicy(Policy): description=SubModuleReplacementDescription( suffix="norm", target_module=norm_cls, + kwargs={"sp_partial_derived": sp_partial_derived}, ), policy=policy, target_key=LlamaModel, ) # use flash attention - if self.shard_config.enable_flash_attention: + if use_flash_attention: self.append_or_create_method_replacement( description={ - "forward": get_llama_flash_attention_forward(self.shard_config), + "forward": get_llama_flash_attention_forward(self.shard_config, sp_mode, sp_group, sp_size), }, policy=policy, target_key=LlamaAttention, @@ -243,7 +320,7 @@ class LlamaForCausalLMPolicy(LlamaPolicy): policy = super().module_policy() - if self.shard_config.enable_tensor_parallelism: + if self.shard_config.enable_tensor_parallelism and not 
self.shard_config.enable_sequence_parallelism: # add a new item for casual lm new_item = { LlamaForCausalLM: ModulePolicyDescription( diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index ce78a7e94..7489873c2 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -1,3 +1,4 @@ +import warnings from dataclasses import dataclass, field from typing import Any, Dict, Optional @@ -9,6 +10,7 @@ from colossalai.pipeline.stage_manager import PipelineStageManager from .grad_ckpt_config import GradientCheckpointConfig __all__ = ["ShardConfig"] +SUPPORT_SP_MODE = ["split_gather", "ring", "all_to_all"] @dataclass @@ -29,13 +31,15 @@ class ShardConfig: enable_all_optimization (bool): Whether to turn on all optimization tools including 'fused normalization', 'flash attention', 'JIT fused operators', 'sequence parallelism' and 'sequence overlap'. Defaults to False. """ tensor_parallel_process_group: Optional[ProcessGroup] = None + sequence_parallel_process_group: Optional[ProcessGroup] = None pipeline_stage_manager: Optional[PipelineStageManager] = None enable_tensor_parallelism: bool = True + enable_all_optimization: bool = False enable_fused_normalization: bool = False enable_flash_attention: bool = False enable_jit_fused: bool = False - enable_all_optimization: bool = False enable_sequence_parallelism: bool = False + sequence_parallelism_mode: str = None enable_sequence_overlap: bool = False parallel_output: bool = True gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None @@ -50,22 +54,57 @@ class ShardConfig: def tensor_parallel_size(self): return self._tensor_parallel_size + @property + def sequence_parallel_size(self): + return self._sequence_parallel_size + def __post_init__(self): - if not self.enable_tensor_parallelism and self.enable_sequence_parallelism: - raise ValueError( - "enable_sequence_parallelism can only be set to True when enable_tensor_parallelism is True" - ) - if not self.enable_sequence_parallelism and self.enable_sequence_overlap: - raise ValueError("enable_sequence_overlap can only be set to True when enable_sequence_parallelism is True") - if not self.enable_tensor_parallelism: - self._tensor_parallel_size = 1 - else: - # get the parallel size - self._tensor_parallel_size = dist.get_world_size(self.tensor_parallel_process_group) # turn on all optimization if all_optimization is set to True if self.enable_all_optimization: self._turn_on_all_optimization() + if self.enable_sequence_parallelism: + self.sequence_parallelism_mode = ( + "split_gather" if self.sequence_parallelism_mode is None else self.sequence_parallelism_mode + ) + assert ( + self.sequence_parallelism_mode in SUPPORT_SP_MODE + ), f"Sequence parallelism mode {self.sequence_parallelism_mode} is not in the supported list {SUPPORT_SP_MODE}" + if self.sequence_parallelism_mode in ["split_gather", "ring"]: + assert ( + self.enable_tensor_parallelism + ), f"sequence parallelism mode {self.sequence_parallelism_mode} can only be used when enable_tensor_parallelism is True" + elif self.sequence_parallelism_mode in ["all_to_all"]: + assert ( + not self.enable_tensor_parallelism + ), f"sequence parallelism mode {self.sequence_parallelism_mode} can only be used when enable_tensor_parallelism is False" + if self.enable_sequence_overlap: + self.enable_sequence_overlap = False + warnings.warn( + f"The enable_sequence_overlap flag will be ignored in sequence parallelism mode 
{self.sequence_parallelism_mode}" + ) + else: + if self.sequence_parallelism_mode: + self.sequence_parallelism_mode = None + warnings.warn( + f"The sequence_parallelism_mode will be ignored when enable_sequence_parallelism is False" + ) + assert ( + not self.enable_sequence_overlap + ), f"enable_sequence_overlap can only be set to True when enable_sequence_parallelism is True" + + # get the tensor parallel size + if not self.enable_tensor_parallelism: + self._tensor_parallel_size = 1 + else: + self._tensor_parallel_size = dist.get_world_size(self.tensor_parallel_process_group) + + # get the sequence parallel size + if not self.enable_sequence_parallelism: + self._sequence_parallel_size = 1 + else: + self._sequence_parallel_size = dist.get_world_size(self.sequence_parallel_process_group) + def _turn_on_all_optimization(self): """ Turn on all optimization. diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index a2433d1b2..bbbaf13b5 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -79,6 +79,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper): master_weights: bool = True, # master weights ): super(LowLevelZeroOptimizer, self).__init__(optim=optimizer) + self._dtype = self.optim.param_groups[0]["params"][0].dtype self._logger = get_dist_logger() self._verbose = verbose @@ -494,7 +495,6 @@ class LowLevelZeroOptimizer(OptimizerWrapper): # clear reduced grads if self._overlap_communication: get_accelerator().synchronize() - self.zero_grad() def backward_by_grad(self, tensor, grad): diff --git a/tests/kit/model_zoo/transformers/gpt.py b/tests/kit/model_zoo/transformers/gpt.py index 24f9627c2..ab5d97420 100644 --- a/tests/kit/model_zoo/transformers/gpt.py +++ b/tests/kit/model_zoo/transformers/gpt.py @@ -18,8 +18,23 @@ def data_gen(): # tokenized_input = tokenizer(input, return_tensors='pt') # input_ids = tokenized_input['input_ids'] # attention_mask = tokenized_input['attention_mask'] - input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779, 318, 13779]], dtype=torch.int64) - attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64) + # input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779, 318, 13779]], dtype=torch.int64) + # attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64) + input_ids = torch.tensor( + [ + [15496, 11, 616, 3290, 318, 13779, 318, 13779, 15496, 11, 616, 3290, 318, 13779, 318, 13779], + [15496, 11, 616, 3290, 318, 13779, 318, 13779, 15496, 11, 616, 3290, 318, 13779, 318, 13779], + ], + dtype=torch.int64, + ) + attention_mask = torch.tensor( + [ + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + ], + dtype=torch.int64, + ) + return dict(input_ids=input_ids, attention_mask=attention_mask) @@ -35,9 +50,9 @@ def data_gen_for_question_answering(): # question answering data gen # `labels` is the type not the token id for token classification, 0 or 1 data = data_gen() - start_positions = torch.tensor([0], dtype=torch.int64) + start_positions = torch.tensor([[0], [0]], dtype=torch.int64) data["start_positions"] = start_positions - end_positions = torch.tensor([1], dtype=torch.int64) + end_positions = torch.tensor([[1], [1]], dtype=torch.int64) data["end_positions"] = end_positions return data @@ -46,14 +61,20 @@ def data_gen_for_token_classification(): # token classification data gen # `labels` is the type not the token id for token classification, 0 or 1 data = 
data_gen() - data["labels"] = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 1]], dtype=torch.int64) + data["labels"] = torch.tensor( + [ + [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1], + ], + dtype=torch.int64, + ) return data def data_gen_for_sequence_classification(): # sequence classification data gen data = data_gen() - data["labels"] = torch.tensor([1], dtype=torch.int64) + data["labels"] = torch.tensor([[1], [1]], dtype=torch.int64) return data @@ -61,12 +82,18 @@ def date_gen_for_double_heads(): num_choices = 2 batch_size = 2 input_ids = torch.tensor( - [[15496, 11, 616, 3290, 318, 13779, 318, 13779], [15496, 11, 616, 3290, 318, 13779, 318, 13779]], + [ + [15496, 11, 616, 3290, 318, 13779, 318, 13779, 15496, 11, 616, 3290, 318, 13779, 318, 13779], + [15496, 11, 616, 3290, 318, 13779, 318, 13779, 15496, 11, 616, 3290, 318, 13779, 318, 13779], + ], + dtype=torch.int64, + ) + attention_mask = torch.tensor( + [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64, ) - attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64) - mc_labels = torch.zeros(input_ids.shape[0], dtype=torch.int64) + mc_labels = torch.zeros(input_ids.shape[0], dtype=torch.int64) mc_token_ids = torch.arange(0, num_choices, dtype=torch.int64) mc_token_ids = mc_token_ids.expand((batch_size, num_choices)) multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, num_choices, -1).contiguous() @@ -103,6 +130,7 @@ config = transformers.GPT2Config( hidden_dropout=0, problem_type="single_label_classification", pad_token_id=50256, + tie_word_embeddings=True, ) config_for_token_classification = copy.deepcopy(config) diff --git a/tests/kit/model_zoo/transformers/llama.py b/tests/kit/model_zoo/transformers/llama.py index 9f801e0cc..58b5b0487 100644 --- a/tests/kit/model_zoo/transformers/llama.py +++ b/tests/kit/model_zoo/transformers/llama.py @@ -28,9 +28,19 @@ if HAS_LLAMA: # ----------------------------------- input_ids = torch.Tensor( - [[1, 15043, 29892, 590, 11203, 338, 274, 1082], [1, 15043, 29892, 590, 11203, 338, 274, 1082]] + [ + [1, 15043, 29892, 590, 11203, 338, 274, 1082, 1, 15043, 29892, 590, 11203, 338, 274, 1082], + [1, 15043, 29892, 590, 11203, 338, 274, 1082, 1, 15043, 29892, 590, 11203, 338, 274, 1082], + ] ).long() - attention_mask = torch.Tensor([[1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1]]).long() + + attention_mask = torch.Tensor( + [ + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + ] + ).long() + return dict(input_ids=input_ids, attention_mask=attention_mask) # label is needed for casual lm diff --git a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py index ece3b4036..ac6f8caef 100644 --- a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py @@ -44,7 +44,10 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: b (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values())) bert_model = model_fn() - enable_all_optimization = True if tp_size > 1 else False + + enable_flash_attention = True if tp_size > 1 else False + enable_fused_normalization = True if tp_size > 1 else False + enable_jit_fused = True if tp_size > 1 else False with shared_tempdir() as tempdir: pretrained_path = 
os.path.join(tempdir, "pretrained") @@ -54,7 +57,9 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: b plugin = GeminiPlugin( **placement_config, tp_size=tp_size, - enable_all_optimization=enable_all_optimization, + enable_flash_attention=enable_flash_attention, + enable_fused_normalization=enable_fused_normalization, + enable_jit_fused=enable_jit_fused, extra_dp_size=extra_dp_size, ) booster = Booster(plugin=plugin) @@ -80,7 +85,9 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: b def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_shard: int, tp_size: int, zero_size: int): (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values())) criterion = lambda x: x.mean() - enable_all_optimization = True if tp_size > 1 else False + enable_flash_attention = True if tp_size > 1 else False + enable_fused_normalization = True if tp_size > 1 else False + enable_jit_fused = True if tp_size > 1 else False extra_dp_size = dist.get_world_size() // (zero_size * tp_size) plugin = GeminiPlugin( **placement_config, @@ -88,7 +95,9 @@ def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_sha initial_scale=(2**14), tp_size=tp_size, extra_dp_size=extra_dp_size, - enable_all_optimization=enable_all_optimization, + enable_flash_attention=enable_flash_attention, + enable_fused_normalization=enable_fused_normalization, + enable_jit_fused=enable_jit_fused, ) booster = Booster(plugin=plugin) diff --git a/tests/test_cluster/test_process_group_mesh.py b/tests/test_cluster/test_process_group_mesh.py index 08542d1f6..3d206622d 100644 --- a/tests/test_cluster/test_process_group_mesh.py +++ b/tests/test_cluster/test_process_group_mesh.py @@ -84,6 +84,30 @@ def check_process_group_mesh_with_cases(): 2: [2], 3: [3], } + TPxPP_RANKS_IN_GROUP = { + 0: [0, 1, 2, 3], + 1: [0, 1, 2, 3], + 2: [0, 1, 2, 3], + 3: [0, 1, 2, 3], + } + DPxTP_RANKS_IN_GROUP = { + 0: [0, 1], + 1: [0, 1], + 2: [2, 3], + 3: [2, 3], + } + TPxPP_PARTIAL_INDICES = { + 0: [[0, 1], [0]], + 1: [[1], [0, 1]], + 2: [[0], [0, 1]], + 3: [[0, 1], [1]], + } + TPxPP_RANKS_IN_GROUP_PARTIAL = { + 0: [0, 1], + 1: [1, 3], + 2: [0, 2], + 3: [2, 3], + } pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE, TP_SIZE) @@ -107,6 +131,12 @@ def check_process_group_mesh_with_cases(): assert pg_mesh.get_ranks_in_group(pp_group) == PP_RANKS_IN_GROUP[rank] dp_group = pg_mesh.get_group_along_axis(DP_DIM) assert pg_mesh.get_ranks_in_group(dp_group) == DP_RANKS_IN_GROUP[rank] + dpxtp_group = pg_mesh.create_group_along_axis([DP_DIM, TP_DIM]) + assert pg_mesh.get_ranks_in_group(dpxtp_group) == DPxTP_RANKS_IN_GROUP[rank] + tpxpp_group = pg_mesh.create_group_along_axis([TP_DIM, PP_DIM]) + assert pg_mesh.get_ranks_in_group(tpxpp_group) == TPxPP_RANKS_IN_GROUP[rank] + tpxpp_group_partial = pg_mesh.create_group_along_axis([TP_DIM, PP_DIM], TPxPP_PARTIAL_INDICES[rank]) + assert pg_mesh.get_ranks_in_group(tpxpp_group_partial) == TPxPP_RANKS_IN_GROUP_PARTIAL[rank] # check prev rank if RANK_TO_COORDINATE[rank][TP_DIM] != 0: diff --git a/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py b/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py index e056860ed..e9aa0dbed 100644 --- a/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py +++ b/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py @@ -56,13 +56,18 @@ def rearrange(tensor: torch.Tensor, dim: int): return rearanged_tensor -def 
check_linear_conv_1d_col(lazy_init: bool, seq_parallel: bool, overlap: bool): +def check_linear_conv_1d_col(lazy_init: bool, seq_parallel_mode: str, overlap: bool): ctx = LazyInitContext() if lazy_init else nullcontext() linear = Conv1D(192, 48).cuda() with ctx: linear_copy = Conv1D(192, 48).cuda() linear_conv_col = GPT2FusedLinearConv1D_Col.from_native_module( - linear_copy, process_group=None, gather_output=True, seq_parallel=seq_parallel, n_fused=3, overlap=overlap + linear_copy, + process_group=None, + gather_output=True, + seq_parallel_mode=seq_parallel_mode, + n_fused=3, + overlap=overlap, ) assert linear.weight.shape == torch.Size([48, 192]) @@ -79,7 +84,9 @@ def check_linear_conv_1d_col(lazy_init: bool, seq_parallel: bool, overlap: bool) # check computation correctness x = torch.rand(1, 4, 48).cuda() out = linear(x) - x_for_shard = x.expand_as(x.clone()) if seq_parallel is False else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()] + x_for_shard = ( + x.expand_as(x.clone()) if seq_parallel_mode is None else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()] + ) gather_out = linear_conv_col(x_for_shard) assert_close(rearrange(out, -1), gather_out) @@ -91,14 +98,14 @@ def check_linear_conv_1d_col(lazy_init: bool, seq_parallel: bool, overlap: bool) assert_close(target_grad, linear_conv_col.weight.grad) -def check_linear_conv_1d_row(lazy_init: bool, seq_parallel: bool): +def check_linear_conv_1d_row(lazy_init: bool, seq_parallel_mode: bool): ctx = LazyInitContext() if lazy_init else nullcontext() linear = Conv1D(192, 48).cuda() with ctx: linear_copy = Conv1D(192, 48).cuda() linear_row = GPT2FusedLinearConv1D_Row.from_native_module( - linear_copy, process_group=None, parallel_input=False, seq_parallel=seq_parallel + linear_copy, process_group=None, parallel_input=False, seq_parallel_mode=seq_parallel_mode ) assert linear.weight.shape == torch.Size([48, 192]) @@ -115,7 +122,7 @@ def check_linear_conv_1d_row(lazy_init: bool, seq_parallel: bool): x = torch.rand(1, 4, 48).cuda() out = linear(x) gather_out = linear_row(x) - target_out = out if seq_parallel is False else torch.chunk(out.clone(), 2, dim=1)[dist.get_rank()] + target_out = out if seq_parallel_mode is None else torch.chunk(out.clone(), 2, dim=1)[dist.get_rank()] assert_close(target_out, gather_out) # check backward correctness @@ -128,11 +135,11 @@ def check_linear_conv_1d_row(lazy_init: bool, seq_parallel: bool): @parameterize("lazy_init", [False, True]) -@parameterize("seq_parallel", [False, True]) +@parameterize("seq_parallel_mode", ["split_gather", None]) @parameterize("overlap", [True]) -def check_gpt2_qkv_fused_linear_1d(lazy_init: bool, seq_parallel: bool, overlap: bool): - check_linear_conv_1d_col(lazy_init, seq_parallel, overlap) - check_linear_conv_1d_row(lazy_init, seq_parallel) +def check_gpt2_qkv_fused_linear_1d(lazy_init: bool, seq_parallel_mode: bool, overlap: bool): + check_linear_conv_1d_col(lazy_init, seq_parallel_mode, overlap) + check_linear_conv_1d_row(lazy_init, seq_parallel_mode) def run_dist(rank, world_size, port): diff --git a/tests/test_shardformer/test_layer/test_linear_1d.py b/tests/test_shardformer/test_layer/test_linear_1d.py index defa4afb9..21d3190de 100644 --- a/tests/test_shardformer/test_layer/test_linear_1d.py +++ b/tests/test_shardformer/test_layer/test_linear_1d.py @@ -15,13 +15,13 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" -def check_linear_1d_col(lazy_init: bool, seq_parallel: bool, overlap: bool): +def 
check_linear_1d_col(lazy_init: bool, seq_parallel_mode: bool, overlap: bool): ctx = LazyInitContext() if lazy_init else nullcontext() linear = nn.Linear(32, 128).cuda() with ctx: linear_copy = nn.Linear(32, 128).cuda() linear_col = Linear1D_Col.from_native_module( - linear_copy, process_group=None, gather_output=True, seq_parallel=seq_parallel, overlap=overlap + linear_copy, process_group=None, gather_output=True, seq_parallel_mode=seq_parallel_mode, overlap=overlap ) # ensure that the parameters are distributed @@ -43,7 +43,9 @@ def check_linear_1d_col(lazy_init: bool, seq_parallel: bool, overlap: bool): x = torch.rand(2, 4, 32).cuda() x_for_unshard = x.expand_as(x.clone()) x_for_unshard.requires_grad_(True) - x_for_shard = x.expand_as(x.clone()) if seq_parallel is False else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()] + x_for_shard = ( + x.expand_as(x.clone()) if seq_parallel_mode is None else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()] + ) x_for_shard.requires_grad_(True) out = linear(x_for_unshard) @@ -63,20 +65,20 @@ def check_linear_1d_col(lazy_init: bool, seq_parallel: bool, overlap: bool): assert x_for_unshard.grad is not None target_unshard_gard = ( x_for_unshard.grad - if seq_parallel is False + if seq_parallel_mode is None else torch.chunk(x_for_unshard.grad.clone(), 2, dim=1)[dist.get_rank()] ) assert_close(target_unshard_gard, x_for_shard.grad) -def check_linear_1d_row(lazy_init: bool, seq_parallel: bool): +def check_linear_1d_row(lazy_init: bool, seq_parallel_mode: bool): ctx = LazyInitContext() if lazy_init else nullcontext() linear = nn.Linear(32, 128).cuda() with ctx: linear_copy = nn.Linear(32, 128).cuda() linear_row = Linear1D_Row.from_native_module( - linear_copy, process_group=None, parallel_input=False, seq_parallel=seq_parallel + linear_copy, process_group=None, parallel_input=False, seq_parallel_mode=seq_parallel_mode ) assert linear_row.weight.shape == torch.Size([128, 16]) @@ -98,7 +100,7 @@ def check_linear_1d_row(lazy_init: bool, seq_parallel: bool): # run forward out = linear(x_for_unshard) gather_out = linear_row(x_for_shard) - target_out = out if seq_parallel is False else torch.chunk(out.clone(), 2, dim=1)[dist.get_rank()] + target_out = out if seq_parallel_mode is None else torch.chunk(out.clone(), 2, dim=1)[dist.get_rank()] assert_close(target_out, gather_out) # check backward correctness @@ -115,7 +117,7 @@ def check_linear_1d_row(lazy_init: bool, seq_parallel: bool): assert_close(x_for_unshard.grad, x_for_shard.grad) -def check_linear_col_plus_row(lazy_init: bool, seq_parallel: bool, overlap: bool): +def check_linear_col_plus_row(lazy_init: bool, seq_parallel_mode: bool, overlap: bool): ctx = LazyInitContext() if lazy_init else nullcontext() linear_1 = nn.Linear(32, 128).cuda() @@ -125,10 +127,10 @@ def check_linear_col_plus_row(lazy_init: bool, seq_parallel: bool, overlap: bool linear_1_copy = nn.Linear(32, 128).cuda() linear_2_copy = nn.Linear(128, 32).cuda() linear_col = Linear1D_Col.from_native_module( - linear_1_copy, process_group=None, gather_output=False, seq_parallel=seq_parallel, overlap=overlap + linear_1_copy, process_group=None, gather_output=False, seq_parallel_mode=seq_parallel_mode, overlap=overlap ) linear_row = Linear1D_Row.from_native_module( - linear_2_copy, process_group=None, parallel_input=True, seq_parallel=seq_parallel + linear_2_copy, process_group=None, parallel_input=True, seq_parallel_mode=seq_parallel_mode ) linear_1.load_state_dict(linear_col.state_dict()) @@ -141,13 +143,17 @@ def 
check_linear_col_plus_row(lazy_init: bool, seq_parallel: bool, overlap: bool x = torch.rand(2, 4, 32).cuda() x_for_unshard = x.expand_as(x.clone()) x_for_unshard.requires_grad_(True) - x_for_shard = x.expand_as(x.clone()) if seq_parallel is False else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()] + x_for_shard = ( + x.expand_as(x.clone()) if seq_parallel_mode is None else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()] + ) x_for_shard.requires_grad_(True) # run forward unshard_out = linear_2(linear_1(x_for_unshard)) shard_out = linear_row(linear_col(x_for_shard)) - target_out = unshard_out if seq_parallel is False else torch.chunk(unshard_out.clone(), 2, dim=1)[dist.get_rank()] + target_out = ( + unshard_out if seq_parallel_mode is None else torch.chunk(unshard_out.clone(), 2, dim=1)[dist.get_rank()] + ) assert_close(target_out, shard_out) # check backward correctness @@ -163,19 +169,19 @@ def check_linear_col_plus_row(lazy_init: bool, seq_parallel: bool, overlap: bool assert x_for_unshard.grad is not None target_unshard_gard = ( x_for_unshard.grad - if seq_parallel is False + if seq_parallel_mode is None else torch.chunk(x_for_unshard.grad.clone(), 2, dim=1)[dist.get_rank()] ) assert_close(target_unshard_gard, x_for_shard.grad) @parameterize("lazy_init", [False, True]) -@parameterize("seq_parallel", [False, True]) +@parameterize("seq_parallel_mode", [None, "split_gather"]) @parameterize("overlap", [True]) -def run_dist_linear_test(lazy_init, seq_parallel, overlap): - check_linear_1d_col(lazy_init, seq_parallel, overlap) - check_linear_1d_row(lazy_init, seq_parallel) - check_linear_col_plus_row(lazy_init, seq_parallel, overlap) +def run_dist_linear_test(lazy_init, seq_parallel_mode, overlap): + check_linear_1d_col(lazy_init, seq_parallel_mode, overlap) + check_linear_1d_row(lazy_init, seq_parallel_mode) + check_linear_col_plus_row(lazy_init, seq_parallel_mode, overlap) def check_dist_linear(rank, world_size, port): diff --git a/tests/test_shardformer/test_layer/test_sequence_parallel.py b/tests/test_shardformer/test_layer/test_sequence_parallel.py new file mode 100644 index 000000000..13b1a13e7 --- /dev/null +++ b/tests/test_shardformer/test_layer/test_sequence_parallel.py @@ -0,0 +1,178 @@ +import copy + +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from torch.testing import assert_close + +import colossalai +from colossalai.shardformer.layer import all_to_all_comm +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn + + +class SequenceParallelAttention(torch.nn.Module): + """Initialization. 
+ + Arguments: + local_attention (Module): local attention with q,k,v + sequence_process_group (ProcessGroup): sequence parallel process group + scatter_idx (int): scatter_idx for all2all comm + gather_idx (int): gather_idx for all2all comm + """ + + def __init__( + self, + heads_num: torch.Tensor, + hidden_dim: torch.Tensor, + enable_sequence_parallellism: bool = False, + sequence_process_group: dist.ProcessGroup = None, + scatter_idx: int = 2, + gather_idx: int = 1, + ) -> None: + super(SequenceParallelAttention, self).__init__() + self.spg = sequence_process_group + self.scatter_idx = scatter_idx + self.gather_idx = gather_idx + self.heads_num = heads_num + self.hidden_dim = hidden_dim + assert hidden_dim % heads_num == 0 + self.head_dim = hidden_dim // heads_num + self.enable_sequence_parallellism = enable_sequence_parallellism + + self.q = nn.Linear(hidden_dim, hidden_dim) + self.k = nn.Linear(hidden_dim, hidden_dim) + self.v = nn.Linear(hidden_dim, hidden_dim) + self.out = nn.Linear(hidden_dim, hidden_dim) + + def attn(self, q, k, v): + batch_size, seq_len = q.shape[0], q.shape[1] + + scale = self.head_dim**0.5 + qk = torch.matmul(q, k.transpose(-2, -1)) / scale + weights = F.softmax(qk, dim=-1) + + attention_score = torch.matmul(weights, v) + + return attention_score + + def forward(self, x) -> Tensor: + bsz, q_len, _ = x.size() + + seq_len = q_len * dist.get_world_size(self.spg) if self.enable_sequence_parallellism else q_len + num_heads = ( + self.heads_num // dist.get_world_size(self.spg) if self.enable_sequence_parallellism else self.heads_num + ) + + # in shape : e.g., [s/p:h:] + query_states = self.q(x) + key_states = self.k(x) + value_states = self.v(x) + + if self.enable_sequence_parallellism: + query_states = all_to_all_comm(query_states, self.spg, self.scatter_idx, self.gather_idx) + key_states = all_to_all_comm(key_states, self.spg, self.scatter_idx, self.gather_idx) + value_states = all_to_all_comm(value_states, self.spg, self.scatter_idx, self.gather_idx) + + query_states = query_states.view(bsz, seq_len, num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, seq_len, num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, seq_len, num_heads, self.head_dim).transpose(1, 2) + # out shape : e.g., [s:h/p:] + attn_score = self.attn(query_states, key_states, value_states) + attn_score = attn_score.transpose(1, 2).contiguous() + attn_score = attn_score.reshape(bsz, seq_len, num_heads * self.head_dim) + if self.enable_sequence_parallellism: + attn_score = all_to_all_comm(attn_score, self.spg, self.gather_idx, self.scatter_idx) + + # output e.g., [s/p::h] + output = self.out(attn_score) + + return output + + +def seq_parallel_attn(seq_len, hidden_dim, head_num, batch_size): + seq_len = seq_len + hidden_dim = hidden_dim + head_num = head_num + batch_size = batch_size + world_size = dist.get_world_size() + + x = torch.randn(batch_size, seq_len, hidden_dim).cuda() + x_unshard = x.clone() + x_unshard.requires_grad_(True) + x_input = torch.chunk(x.clone(), world_size, dim=1)[dist.get_rank()] + x_input.requires_grad_(True) + + # Multi-head Attention + mha = SequenceParallelAttention(head_num, hidden_dim).cuda() + # Multi-head Attention forward + mha_out = mha(x_unshard) + + # Sequence parallel Attention + sp_attn = SequenceParallelAttention(head_num, hidden_dim, True).cuda() + sp_attn.load_state_dict(copy.deepcopy(mha.state_dict())) + # Sequence parallel Attention forward + dist_attn_out = sp_attn(x_input) + + # gather the output of 
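# Why the checks that follow all-gather the outputs along dim 1 and all-reduce the
# q/k/v/out weight gradients (editorial note): each sequence-parallel rank produces only
# its own seq_len / world_size slice of the output, and its weight gradients cover only
# the local tokens, so the full gradient is the sum over ranks. A minimal single-process
# sketch; the layer and sizes below are assumptions, not values from the test.
import copy
import torch
import torch.nn as nn

torch.manual_seed(0)
sp_size, bsz, seq_len, hidden = 2, 1, 8, 16
layer = nn.Linear(hidden, hidden)
x = torch.randn(bsz, seq_len, hidden)

layer(x).sum().backward()
full_grad = layer.weight.grad.clone()

summed = torch.zeros_like(full_grad)
for chunk in torch.chunk(x, sp_size, dim=1):
    replica = copy.deepcopy(layer)
    replica.weight.grad = None
    replica(chunk).sum().backward()
    summed += replica.weight.grad  # dist.all_reduce performs this sum across ranks

torch.testing.assert_close(summed, full_grad)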
sequence parallel attention + out_list = [torch.empty_like(dist_attn_out) for _ in range(world_size)] + dist.all_gather(out_list, dist_attn_out) + seq_out = torch.cat(out_list, dim=1) + + # forward result check + assert_close(seq_out, mha_out) + + # Multi-head Attention backward + mha_out.sum().backward() + q_grad = mha.q.weight.grad + k_grad = mha.k.weight.grad + v_grad = mha.v.weight.grad + o_grad = mha.out.weight.grad + x_grad = x_unshard.grad + + # Sequence parallel Attention backward + dist_attn_out.sum().backward() + q_grad_seq = sp_attn.q.weight.grad + k_grad_seq = sp_attn.k.weight.grad + v_grad_seq = sp_attn.v.weight.grad + o_grad_seq = sp_attn.out.weight.grad + x_grad_seq = x_input.grad + # all_reduce the grad of sequence parallel attention weight + dist.all_reduce(q_grad_seq) + dist.all_reduce(k_grad_seq) + dist.all_reduce(v_grad_seq) + dist.all_reduce(o_grad_seq) + # gather the grad of sequence parallel attention input + x_grad_seq_list = [torch.empty_like(x_grad_seq) for _ in range(world_size)] + dist.all_gather(x_grad_seq_list, x_grad_seq) + x_grad_seq_gather = torch.cat(x_grad_seq_list, dim=1) + + # backward result check + assert_close(q_grad_seq, q_grad) + assert_close(k_grad_seq, k_grad) + assert_close(v_grad_seq, v_grad, atol=1e-4, rtol=1e-4) + assert_close(o_grad_seq, o_grad) + assert_close(x_grad_seq_gather, x_grad) + + +@parameterize("seq_len", [128]) +@parameterize("hidden_dim", [64]) +@parameterize("head_num", [4]) +@parameterize("batch_size", [1]) +def run_seq_parallel_attn(seq_len, hidden_dim, head_num, batch_size): + seq_parallel_attn(seq_len, hidden_dim, head_num, batch_size) + + +def check_all2all_attn(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run_seq_parallel_attn() + + +@rerun_if_address_is_in_use() +def test_all_to_all_attention(): + spawn(check_all2all_attn, nprocs=4) + + +if __name__ == "__main__": + test_all_to_all_attention() diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index 85be9a242..d5fc2c30f 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -1,5 +1,4 @@ import copy -import math from contextlib import nullcontext from typing import Any, Callable, Dict, List, Optional @@ -123,7 +122,6 @@ def build_model_from_hybrid_plugin(model_fn: Callable, loss_fn: Callable, test_c sharded_model = copy.deepcopy(org_model) if use_lazy_init: ctx.materialize(org_model) - org_model = org_model.cuda() org_optimizer = Adam(org_model.parameters(), lr=1e-3) sharded_optimizer = Adam(sharded_model.parameters(), lr=1e-3) @@ -162,24 +160,22 @@ def run_forward_backward_with_hybrid_plugin( data = data_gen_fn() - if booster.plugin.shard_config.enable_sequence_parallelism and booster.plugin.tp_size != 0: - seq_len = data["input_ids"].shape[-1] - lcm = booster.plugin.tp_size * seq_len // math.gcd(booster.plugin.tp_size, seq_len) - times = lcm // seq_len - input_shape = data["input_ids"].shape - for k, v in data.items(): - if v.shape == input_shape: - data[k] = v.repeat((1,) * (v.dim() - 1) + (times,)) + shard_test_data = {} + for k, v in data.items(): + shard_test_data[k] = data[k].clone() + unshard_test_data = {} + for k, v in data.items(): + unshard_test_data[k] = data[k].clone() sharded_model.train() if booster.plugin.stage_manager is not None: - for k, v in data.items(): + for k, v in shard_test_data.items(): if torch.is_tensor(v) or "Tensor" in v.__class__.__name__: 
new_shape = [1] * v.dim() new_shape[0] = 4 - data[k] = v.to("cuda").repeat(*new_shape) + shard_test_data[k] = v.to("cuda").repeat(*new_shape) - data_iter = iter([data]) + data_iter = iter([shard_test_data]) sharded_output = booster.execute_pipeline( data_iter, sharded_model, @@ -189,17 +185,22 @@ def run_forward_backward_with_hybrid_plugin( return_outputs=True, ) sharded_loss = sharded_output["loss"] - else: - data = {k: v.cuda() for k, v in data.items()} - sharded_output = sharded_model(**data) + else: + shard_test_data = {k: v.cuda() for k, v in shard_test_data.items()} + sharded_output = sharded_model(**shard_test_data) sharded_loss = criterion(sharded_output) sharded_optimizer.backward(sharded_loss) org_model.train() - data = {k: v.cuda() for k, v in data.items()} - org_output = org_model(**data) - + if booster.plugin.stage_manager is not None: + for k, v in unshard_test_data.items(): + if torch.is_tensor(v) or "Tensor" in v.__class__.__name__: + new_shape = [1] * v.dim() + new_shape[0] = 4 + unshard_test_data[k] = v.to("cuda").repeat(*new_shape) + unshard_test_data = {k: v.cuda() for k, v in unshard_test_data.items()} + org_output = org_model(**unshard_test_data) org_loss = criterion(org_output) org_loss.backward() @@ -212,7 +213,6 @@ def check_output_hidden_state( stage_manager: Optional[PipelineStageManager] = None, atol: float = 1e-5, rtol: float = 1e-3, - dim: int = 0, ): org_hidden_state = org_output.last_hidden_state diff --git a/tests/test_shardformer/test_model/test_shard_bert.py b/tests/test_shardformer/test_model/test_shard_bert.py index 768bd95bd..919557797 100644 --- a/tests/test_shardformer/test_model/test_shard_bert.py +++ b/tests/test_shardformer/test_model/test_shard_bert.py @@ -100,6 +100,28 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, @parameterize( "test_config", [ + { + "tp_size": 4, + "pp_size": 1, + "num_microbatches": 1, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "ring", + "enable_flash_attention": False, + "use_lazy_init": True, + "precision": "fp32", + "initial_scale": 1, + }, + { + "tp_size": 4, + "pp_size": 1, + "num_microbatches": 1, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "split_gather", + "enable_flash_attention": False, + "use_lazy_init": True, + "precision": "fp16", + "initial_scale": 1, + }, { "tp_size": 2, "pp_size": 1, @@ -154,7 +176,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, ) def run_bert_test(test_config): sub_model_zoo = model_zoo.get_sub_registry("transformers_bert") - for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) diff --git a/tests/test_shardformer/test_model/test_shard_bloom.py b/tests/test_shardformer/test_model/test_shard_bloom.py index b70cba8b4..cc0786618 100644 --- a/tests/test_shardformer/test_model/test_shard_bloom.py +++ b/tests/test_shardformer/test_model/test_shard_bloom.py @@ -99,6 +99,28 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, @parameterize( "test_config", [ + { + "tp_size": 4, + "pp_size": 1, + "num_microbatches": 1, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "ring", + "enable_flash_attention": False, + "use_lazy_init": True, + "precision": "fp32", + "initial_scale": 1, + }, + { + "tp_size": 4, + "pp_size": 1, + "num_microbatches": 1, + "enable_sequence_parallelism": True, + 
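# Editorial sketch of the batch tiling used in _utils.py above: when a pipeline stage
# manager is present, every tensor in the test batch is repeated along dim 0 so the
# scheduler can split it into microbatches. The helper name and repeat factor of 4 are
# assumptions mirroring the test code, not library API.
import torch

def tile_batch_for_microbatches(data: dict, factor: int = 4) -> dict:
    tiled = {}
    for k, v in data.items():
        if torch.is_tensor(v):
            new_shape = [1] * v.dim()
            new_shape[0] = factor  # only the batch dimension is repeated
            tiled[k] = v.repeat(*new_shape)
        else:
            tiled[k] = v
    return tiled

if __name__ == "__main__":
    batch = {"input_ids": torch.randint(0, 100, (1, 16)), "attention_mask": torch.ones(1, 16)}
    assert tile_batch_for_microbatches(batch)["input_ids"].shape == (4, 16)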
"sequence_parallelism_mode": "split_gather", + "enable_flash_attention": False, + "use_lazy_init": True, + "precision": "fp16", + "initial_scale": 1, + }, { "tp_size": 2, "pp_size": 2, diff --git a/tests/test_shardformer/test_model/test_shard_chatglm2.py b/tests/test_shardformer/test_model/test_shard_chatglm2.py index 78d752b69..405ceba32 100644 --- a/tests/test_shardformer/test_model/test_shard_chatglm2.py +++ b/tests/test_shardformer/test_model/test_shard_chatglm2.py @@ -135,6 +135,28 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, @parameterize( "test_config", [ + { + "tp_size": 4, + "pp_size": 1, + "num_microbatches": 1, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "ring", + "enable_flash_attention": False, + "use_lazy_init": True, + "precision": "fp32", + "initial_scale": 1, + }, + { + "tp_size": 4, + "pp_size": 1, + "num_microbatches": 1, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "split_gather", + "enable_flash_attention": False, + "use_lazy_init": True, + "precision": "fp16", + "initial_scale": 1, + }, { "tp_size": 2, "pp_size": 2, diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py index d59d7e4ad..4aac7f3d4 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2.py @@ -131,6 +131,28 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, @parameterize( "test_config", [ + { + "tp_size": 4, + "pp_size": 1, + "num_microbatches": 1, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "ring", + "enable_flash_attention": False, + "use_lazy_init": True, + "precision": "fp32", + "initial_scale": 1, + }, + { + "tp_size": 4, + "pp_size": 1, + "num_microbatches": 1, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "split_gather", + "enable_flash_attention": False, + "use_lazy_init": True, + "precision": "fp16", + "initial_scale": 1, + }, { "tp_size": 2, "pp_size": 2, diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index 55858cbd4..27f904292 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -2,6 +2,8 @@ import os import pytest import torch +import torch.distributed as dist +from torch.testing import assert_close import colossalai from colossalai.logging import disable_existing_loggers @@ -46,6 +48,26 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, row_layer_for_check = ["layers[0].self_attn.q_proj", "embed_tokens"] col_layer_for_check = ["layers[0].self_attn.o_proj"] + # Here we check the grad of layernorm because an all-reduce operation should be performed during sequence parallelism + norm_layer_for_check = ["layers[0].input_layernorm", "layers[0].post_attention_layernorm"] + + # During pipeline parallelism, we cannot get the grad of norm layer during first stage, so we only check this when pp is not enbaled + if stage_manager is None: + norm_layer_for_check.append("norm") + + # Check the grad when using ZeRO-1 and ZeRO-2 + if ( + booster.plugin.zero_stage in [1, 2] + and booster.plugin.shard_config.enable_sequence_parallelism + and booster.plugin.shard_config.sequence_parallelism_mode == "all_to_all" + ): + for p1, p2 in zip(llama_model.parameters(), sharded_optimizer._master_param_groups_of_current_rank[0]): + 
working_p = sharded_optimizer._param_store.master_to_working_param[id(p2)] + grads = sharded_optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(working_p)) + grad_index = 0 if sharded_optimizer._partition_grads else sharded_optimizer._local_rank + grad = grads[grad_index] + sharded_grad = p1.grad.view(-1).chunk(dist.get_world_size())[dist.get_rank()] + assert_close(sharded_grad, grad[: sharded_grad.shape[0]], atol=5e-3, rtol=5e-3, check_dtype=False) # Save gradient tensors for comparison between the original model and the sharded model before optimizer step. grads_to_check = {} @@ -60,8 +82,19 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, col_layer_grads = get_grad_tensors_for_check( llama_model, shard_llama_model, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False ) + norm_layer_grads = get_grad_tensors_for_check( + llama_model, + shard_llama_model, + norm_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=1, + verbose=False, + ) grads_to_check.update(col_layer_grads) grads_to_check.update(row_layer_grads) + grads_to_check.update(norm_layer_grads) # optimizer executes step org_optimizer.step() @@ -98,6 +131,74 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, @parameterize( "test_config", [ + { + "tp_size": 2, + "pp_size": 1, + "num_microbatches": 1, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "ring", + "enable_flash_attention": True, + "use_lazy_init": True, + "zero_stage": 2, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 4, + "pp_size": 1, + "num_microbatches": 1, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "ring", + "enable_flash_attention": False, + "use_lazy_init": True, + "precision": "fp32", + "initial_scale": 1, + }, + { + "tp_size": 4, + "pp_size": 1, + "num_microbatches": 1, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "split_gather", + "enable_flash_attention": False, + "use_lazy_init": True, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 1, + "pp_size": 1, + "sp_size": 2, + "num_microbatches": 1, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "all_to_all", + "use_lazy_init": True, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 1, + "pp_size": 1, + "sp_size": 2, + "num_microbatches": 1, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "all_to_all", + "use_lazy_init": True, + "zero_stage": 2, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 1, + "pp_size": 1, + "num_microbatches": 1, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "all_to_all", + "enable_flash_attention": False, + "use_lazy_init": True, + "precision": "fp16", + "initial_scale": 1, + }, { "tp_size": 2, "pp_size": 2,
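# Editorial sketch of the ZeRO gradient comparison above: the reference model's gradient
# is flattened and chunked evenly across ranks, while the ZeRO partition may carry
# trailing padding, hence the `grad[: sharded_grad.shape[0]]` slice. The padding helper
# and sizes below are assumptions for illustration, not the optimizer's internals.
import torch

def zero_partition(flat_grad: torch.Tensor, world_size: int, rank: int) -> torch.Tensor:
    # Pad the flat gradient so it divides evenly across ranks, as a ZeRO bucket
    # typically would, then return this rank's slice.
    pad = (-flat_grad.numel()) % world_size
    padded = torch.cat([flat_grad, flat_grad.new_zeros(pad)])
    return padded.chunk(world_size)[rank]

if __name__ == "__main__":
    world_size, rank = 4, 1
    full_grad = torch.randn(10, 3)  # 30 elements, not divisible by 4

    partition = zero_partition(full_grad.view(-1), world_size, rank)
    reference = full_grad.view(-1).chunk(world_size)[rank]
    torch.testing.assert_close(reference, partition[: reference.shape[0]])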