From 1a3315e33611a63c5aa2e2d507f1d51c8be0c9d2 Mon Sep 17 00:00:00 2001 From: littsk <1214689160@qq.com> Date: Fri, 3 Nov 2023 13:32:43 +0800 Subject: [PATCH] [hotfix] Add layer norm gradients all-reduce for sequence parallel (#4926) * [hotfix] Add layer norm gradients all-reduce for sequence parallel. (#4915) * Add layer norm gradients all-reduce for sequence parallel. * skip pipeline inference test * [hotfix] fixing polices of sequence parallel (#4922) * Add layer norm gradients all-reduce for sequence parallel. * fix parameter passing when calling get_autopolicy --------- Co-authored-by: littsk <1214689160@qq.com> * Hotfix/add grad all reduce for sequence parallel (#4927) * Add layer norm gradients all-reduce for sequence parallel. * fix parameter passing when calling get_autopolicy * fix bug using wrong variables --------- Co-authored-by: littsk <1214689160@qq.com> * fix policy initialization * fix bloom and chatglm policices * polish code of handling layernorm * fix moe module * polish code of class initializing --------- Co-authored-by: Zhongkai Zhao --- .../booster/plugin/hybrid_parallel_plugin.py | 372 ++++++++++++++++-- .../plugin/moe_hybrid_parallel_plugin.py | 9 +- .../inference/tensor_parallel/engine.py | 4 +- colossalai/shardformer/README.md | 18 + colossalai/shardformer/layer/__init__.py | 5 +- colossalai/shardformer/layer/normalization.py | 154 +++++++- colossalai/shardformer/layer/utils.py | 76 +++- .../shardformer/policies/auto_policy.py | 7 +- .../shardformer/policies/base_policy.py | 6 +- colossalai/shardformer/policies/bert.py | 88 ++--- colossalai/shardformer/policies/blip2.py | 170 ++++---- colossalai/shardformer/policies/bloom.py | 70 ++-- colossalai/shardformer/policies/chatglm2.py | 83 ++-- colossalai/shardformer/policies/gpt2.py | 79 ++-- colossalai/shardformer/policies/llama.py | 56 +-- colossalai/shardformer/policies/opt.py | 54 ++- colossalai/shardformer/policies/sam.py | 104 ++--- colossalai/shardformer/policies/t5.py | 78 ++-- colossalai/shardformer/policies/vit.py | 6 - colossalai/shardformer/policies/whisper.py | 119 +++--- colossalai/shardformer/shard/sharder.py | 4 +- colossalai/zero/low_level/low_level_optim.py | 4 +- tests/test_infer/test_pipeline_infer.py | 3 +- .../test_amp_optimizer.py | 16 +- .../test_naive_optimizer.py | 8 +- .../test_zero_optimizer.py | 16 +- .../test_model/test_shard_bert.py | 21 + .../test_model/test_shard_bloom.py | 14 + .../test_model/test_shard_chatglm2.py | 14 + .../test_model/test_shard_gpt2.py | 14 + 30 files changed, 1120 insertions(+), 552 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 72c3ec46a..f9716ab97 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -1,6 +1,6 @@ import ctypes import random -from contextlib import nullcontext +from contextlib import contextmanager from functools import partial from types import MethodType from typing import Any, Callable, Iterator, List, Optional, OrderedDict, Tuple, Union @@ -25,6 +25,7 @@ from colossalai.interface import ModelWrapper, OptimizerWrapper from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer import ShardConfig, ShardFormer +from colossalai.shardformer.layer.utils import SeqParallelUtils from colossalai.shardformer.policies.base_policy import Policy from colossalai.tensor.d_tensor.api import 
is_distributed_tensor from colossalai.zero.low_level import LowLevelZeroOptimizer @@ -47,12 +48,17 @@ class HybridParallelModule(ModelWrapper): precision: str, shard_config: ShardConfig, dp_group: ProcessGroup, + tp_group: ProcessGroup, use_ddp: bool, ddp_config: dict, custom_policy: Policy, ) -> None: self.stage_manager = shard_config.pipeline_stage_manager + self.shard_config = shard_config self.dp_group = dp_group + self.tp_group = tp_group + self.use_dpp = use_ddp + self.require_grad_sync = True shardformer = ShardFormer(shard_config) if custom_policy is not None: @@ -98,19 +104,75 @@ class HybridParallelModule(ModelWrapper): dist.all_reduce(param.grad, group=group) dist.barrier() - def no_sync(self) -> Iterator[None]: - # no sync grads across data parallel - return nullcontext() + @contextmanager + def no_sync(self): + r""" + A context manager to disable automatic gradient synchronization (all-reduce) and allow manual synchronization + when 'no_sync' is active. Alternatively, synchronization will occur in the first forward-backward pass + when exiting the context. + """ - def sync_grads(self): - # sync grad across data parallel + # Store the current value of 'require_grad_sync' to restore it later. + old_require_grad_sync = self.require_grad_sync + # Disable automatic gradient synchronization. + self.require_grad_sync = False + try: + if self.use_dpp: + # If using data parallel processing (use_dpp), disable synchronization too. + with self.module.no_sync(): + yield + else: + yield + finally: + # Restore the original value of 'require_grad_sync'. + self.require_grad_sync = old_require_grad_sync + + def sync_dp_grads(self): + r""" + Synchronize gradients across data parallelism (DP) if the DP group size is greater than 1. + This function performs an all-reduce operation to combine gradients from different devices in the DP group. + + Args: + None + + Returns: + None + """ + + # Check if the DP group size is 1, meaning no synchronization is needed. if self.dp_group.size() == 1: return + + # Iterate through the model's parameters and perform gradient synchronization. for p in self.module.parameters(): if p.grad is not None: + # Perform all-reduce to combine gradients from different devices. dist.all_reduce(p.grad, group=self.dp_group) + # Normalize the gradient by dividing it by the DP group size. p.grad.div_(self.dp_group.size()) + def sync_sp_grads(self, grads: Optional[List[torch.Tensor]] = None): + r""" + Synchronize gradients that are partially derived within sequence parallelism + if sequence parallelism is enabled. Gradients can be provided explicitly or extracted + from the module. + + Args: + grads (Optional[List[torch.Tensor]]): A list of gradient tensors to synchronize. If not + provided, gradients will be extracted from the model. + + Returns: + None + """ + + if self.tp_group.size() > 1 and self.shard_config.enable_sequence_parallelism: + if grads is not None: + # Synchronize provided gradient tensors across the tensor parallelism group. + SeqParallelUtils.allreduce_partial_data_grad(tp_group=self.tp_group, grads=grads) + else: + # Synchronize gradients from the model across the tensor parallelism group. 
+ SeqParallelUtils.allreduce_partial_data_grad(tp_group=self.tp_group, model=self.module) + def forward(self, *args, **kwargs): if self.convert_fn is not None: args = tree_map(self.convert_fn, args) @@ -166,7 +228,7 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper): def __init__( self, optim: Optimizer, - model: Module, + model: HybridParallelModule, use_pipeline: bool, param_info: OrderedDict, max_norm: float = 0, @@ -176,13 +238,69 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper): self.param_info = param_info if use_pipeline: init_pipeline_optimizer(optim, model) + self.model = model self.stage_manager = model.stage_manager self.shared_params = model.shared_params self.max_norm = max_norm self.tp_pg = tp_process_group self.pp_pg = pp_process_group + self.tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1 + self.pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1 super().__init__(optim) + def backward(self, loss: Tensor, *args, **kwargs): + r""" + Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients. + + This method performs backward pass for gradient computation. If sequence parallelism is enabled + and gradient synchronization is required, it will synchronize gradients that are partially derived + within sequence parallelism across tp parallelism groups. + + Args: + loss (Tensor): The loss tensor to compute gradients with respect to. + *args: Additional positional arguments to be passed to the superclass backward method. + **kwargs: Additional keyword arguments to be passed to the superclass backward method. + + Returns: + None + """ + + # Call the superclass backward method to compute gradients. + super().backward(loss, *args, **kwargs) + + if self.model.require_grad_sync: + # If gradient synchronization is required, sync sequence parallelism gradients. + self.model.sync_sp_grads() + else: + # If gradient synchronization is not required, return. + return + + def backward_by_grad(self, tensor: Tensor, grad: Tensor): + """ + Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients. + + This method performs a backward pass for gradient computation using a precomputed gradient tensor. + If sequence parallelism is enabled and gradient synchronization is required, it will synchronize + gradients that are partially derived within sequence parallelism across tp parallelism groups. + + Args: + tensor (Tensor): The input tensor for which gradients are computed. + grad (Tensor): The precomputed gradient tensor to compute gradients with respect to the input tensor. + + Returns: + None + """ + + # Call the superclass backward method to compute gradients. + super().backward_by_grad(tensor, grad) + + if self.model.require_grad_sync: + # If gradient synchronization is required, sync sequence parallelism gradients. + self.model.sync_sp_grads() + else: + # If gradient synchronization is not required, return. + return + def step(self, *args, **kwargs): r""" Perform an optimization step. @@ -220,8 +338,6 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper): if len(param_gradient_pairs) == 0: return 0.0 - tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1 - pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1 norm_type = float(norm_type) # gradients used for norm calculation.
@@ -230,9 +346,9 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper): if norm_type == inf: total_norm = max(grad.data.abs().max() for grad in gradients) total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) - if tp_size > 1: + if self.tp_size > 1: dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg) - if pp_size > 1: + if self.pp_size > 1: dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.pp_pg) total_norm = total_norm_cuda.item() else: @@ -250,16 +366,16 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper): # Consequently, there is no need to perform an 'all_reduce' operation for 'grad_norm'. # However, we still perform the 'all_reduce' operation for the sake of good coding practices. # To ensure mathematical equivalence, we divide the 'grad_norm' by 'tp_size.' - if tp_size > 1: + if self.tp_size > 1: param_for_grad = grad_to_param_mapping[id(grad)] if not is_distributed_tensor(param_for_grad): - grad_norm_exponentiated /= tp_size + grad_norm_exponentiated /= self.tp_size # If 'pp_size' is greater than 1 and the gradient belongs to shared parameters, # it means that this parameter is used in two different pipeline stages. # To avoid redundant norm calculations, we divide the exponent of this norm by # the number of shared stages. - if pp_size > 1: + if self.pp_size > 1: for shared_param in self.shared_params: if self.stage_manager.stage in shared_param: stage_shared_param = shared_param[self.stage_manager.stage] @@ -269,10 +385,10 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper): total_norm_exponentiated += grad_norm_exponentiated total_norm_exponentiated_cuda = torch.cuda.FloatTensor([float(total_norm_exponentiated)]) - if tp_size > 1: + if self.tp_size > 1: # compute norm in tp process group dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.tp_pg) - if pp_size > 1: + if self.pp_size > 1: # compute norm in pp process group dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.pp_pg) @@ -314,7 +430,7 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer): def __init__( self, optim: Optimizer, - model: Module, + model: HybridParallelModule, use_pipeline: bool, param_info: OrderedDict, precision: str = "fp16", @@ -329,11 +445,14 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer): tp_process_group: Optional[ProcessGroup] = None, # if using tp pp_process_group: Optional[ProcessGroup] = None, # if using pp ): + self.model = model self.param_info = param_info self.stage_manager = model.stage_manager self.shared_params = model.shared_params self.tp_pg = tp_process_group self.pp_pg = pp_process_group + self.tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1 + self.pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1 if use_pipeline: init_pipeline_optimizer(optim, model) super().__init__( @@ -349,6 +468,59 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer): max_norm=max_norm, ) + def backward(self, loss: Tensor, *args, **kwargs): + r""" + Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients. + + This method performs backward pass for gradient computation. If sequence parallelism is enabled + and gradient synchronization is required, it will synchronize gradients that are partially derived + within sequence parallelism across tp parallelism groups. + + Args: + loss (Tensor): The loss tensor to compute gradients with respect to. 
+ *args: Additional positional arguments to be passed to the superclass backward method. + **kwargs: Additional keyword arguments to be passed to the superclass backward method. + + Returns: + None + """ + + # Call the superclass backward method to compute gradients. + super().backward(loss, *args, **kwargs) + + if self.model.require_grad_sync: + # If gradient synchronization is required, sync sequence parallelism gradients. + self.model.sync_sp_grads() + else: + # If gradient synchronization is not required, return. + return + + def backward_by_grad(self, tensor: Tensor, grad: Tensor): + """ + Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients. + + This method performs a backward pass for gradient computation using a precomputed gradient tensor. + If sequence parallelism is enabled and gradient synchronization is required, it will synchronize + gradients that are partially derived within sequence parallelism across tp parallelism groups. + + Args: + tensor (Tensor): The input tensor for which gradients are computed. + grad (Tensor): The precomputed gradient tensor to compute gradients with respect to the input tensor. + + Returns: + None + """ + + # Call the superclass backward method to compute gradients. + super().backward_by_grad(tensor, grad) + + if self.model.require_grad_sync: + # If gradient synchronization is required, sync sequence parallelism gradients. + self.model.sync_sp_grads() + else: + # If gradient synchronization is not required, return. + return + def _compute_grad_norm(self, param_gradient_pairs: List[Tuple[Tensor]], norm_type: int = 2) -> int: r""" Compute and return the gradient norm for gradient clipping. @@ -363,8 +535,6 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer): if len(param_gradient_pairs) == 0: return 0.0 - tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1 - pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1 norm_type = float(norm_type) if norm_type == inf: total_norm = max(grad.data.abs().max() for grad in gradients) total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) - if tp_size > 1: + if self.tp_size > 1: dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg) - if pp_size > 1: + if self.pp_size > 1: dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.pp_pg) total_norm = total_norm_cuda.item() @@ -396,16 +566,16 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer): # Consequently, there is no need to perform an 'all_reduce' operation for 'grad_norm'. # However, we still perform the 'all_reduce' operation for the sake of good coding practices. # To ensure mathematical equivalence, we divide the 'grad_norm' by 'tp_size.' - if tp_size > 1: + if self.tp_size > 1: param_for_grad = grad_to_param_mapping[id(grad)] if not is_distributed_tensor(param_for_grad): - grad_norm_exponentiated /= tp_size + grad_norm_exponentiated /= self.tp_size # If 'pp_size' is greater than 1 and the gradient belongs to shared parameters, # it means that this parameter is used in two different pipeline stages. # To avoid redundant norm calculations, we divide the exponent of this norm by # the number of shared stages.
- if pp_size > 1: + if self.pp_size > 1: for shared_param in self.shared_params: if self.stage_manager.stage in shared_param: stage_working_shared_param = shared_param[self.stage_manager.stage] @@ -416,10 +586,10 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer): total_norm_exponentiated += grad_norm_exponentiated total_norm_exponentiated_cuda = torch.cuda.FloatTensor([float(total_norm_exponentiated)]) - if tp_size > 1: + if self.tp_size > 1: # compute norm in tp process group dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.tp_pg) - if pp_size > 1: + if self.pp_size > 1: # compute norm in pp process group dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.pp_pg) @@ -433,7 +603,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer): def __init__( self, optimizer: Optimizer, - model: Module, + model: HybridParallelModule, use_pipeline: bool, param_info: OrderedDict, initial_scale: int = 2**16, # grad scaler config @@ -455,6 +625,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer): pp_process_group: Optional[ProcessGroup] = None, # if using pp forced_dtype: Optional[torch.dtype] = None, ): + self.model = model self.param_info = param_info self.stage_manager = model.stage_manager self.shared_params = model.shared_params @@ -483,6 +654,123 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer): forced_dtype=forced_dtype, ) + def sync_dp_grads(self): + r""" + Synchronize gradients in the data parallelism dimension. + + This method wraps the existing `_sync_grad` method in order to explicitly synchronize gradients + in the data parallelism dimension. It is necessary due to the introduction of new parallel dimensions, + namely tp (tensor parallelism) and pp (pipeline parallelism). This ensures better code organization + and readability. + + Args: + None + + Returns: + None + """ + + # Call the superclass `_sync_grad` method to synchronize gradients. + super()._sync_grad() + + def _sync_sp_grads(self): + r""" + Synchronize gradients that are partially derived within sequence parallelism. + + This method is responsible for synchronizing partially derived gradients across tp parallelism groups. + It identifies gradients that are partially derived and synchronizes them. + If synchronization is required and such gradients are found, + it performs the synchronization. + + Args: + None + + Returns: + None + """ + + def _get_all_working_grads() -> List[Tensor]: + """Retrieve all working gradients from different parameter groups.""" + all_working_grads = [] + for group_id in range(self.num_param_groups): + working_grads = self._grad_store.get_working_grads_by_group_id(group_id) + all_working_grads.extend(working_grads) + return all_working_grads + + def _get_grads_to_sync(all_working_grads) -> Union[List[Tensor], None]: + """Identify gradients to be synchronized in the sequence parallelism.""" + grads_to_sync = [] + for grad in all_working_grads: + param_id_for_grad = self._grad_store.get_param_id_for_grad(grad) + param_for_grad = ctypes.cast(param_id_for_grad, ctypes.py_object).value + if SeqParallelUtils.is_sp_partial_derived_param(param_for_grad): + grads_to_sync.append(grad) + + if len(grads_to_sync) > 0: + return grads_to_sync + else: + return None + + # Get all working gradients and gradients to be synchronized.
+ all_working_grads = _get_all_working_grads() + grads_to_sync = _get_grads_to_sync(all_working_grads) + + if self.require_grad_sync and grads_to_sync is not None: + # Synchronize sequence parallelism gradients if required. + SeqParallelUtils.allreduce_partial_data_grad(tp_group=self.tp_pg, grads=grads_to_sync) + else: + return + + def backward(self, loss, retain_graph=False): + """ + Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients. + + This method performs the backward pass for gradient computation based on a given loss tensor. + If sequence parallelism is enabled and gradient synchronization is required, it will synchronize + gradients that are partially derived within sequence parallelism across TP parallelism groups. + + Args: + loss: The loss tensor to compute gradients with respect to. + retain_graph (bool): Whether to retain the computation graph. + + Returns: + None + """ + # Call the superclass backward method to compute gradients. + super().backward(loss, retain_graph) + + if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism: + # If gradient synchronization is required, sync sequence parallelism gradients. + self._sync_sp_grads() + else: + # If gradient synchronization is not required, return. + return + + def backward_by_grad(self, tensor, grad): + """ + Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients. + + This method performs a backward pass for gradient computation based on a precomputed gradient tensor. + If sequence parallelism is enabled and gradient synchronization is required, it will synchronize + gradients that are partially derived within sequence parallelism across TP parallelism groups. + + Args: + tensor: The input tensor for which gradients are computed. + grad: The precomputed gradient tensor to compute gradients with respect to the input tensor. + + Returns: + None + """ + # Call the superclass backward_by_grad method to compute gradients. + super().backward_by_grad(tensor, grad) + + if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism: + # If gradient synchronization is required, sync sequence parallelism gradients. + self._sync_sp_grads() + else: + # If gradient synchronization is not required, return. + return + def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> float: r""" Compute and return the gradient norm for gradient clipping. @@ -768,7 +1056,14 @@ class HybridParallelPlugin(PipelinePluginBase): if not isinstance(model, ModelWrapper): use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0 model = HybridParallelModule( - model, self.precision, self.shard_config, self.dp_group, use_ddp, self.ddp_config, self.custom_policy + model, + precision=self.precision, + shard_config=self.shard_config, + dp_group=self.dp_group, + tp_group=self.tp_group, + use_ddp=use_ddp, + ddp_config=self.ddp_config, + custom_policy=self.custom_policy, ) if optimizer is not None and not isinstance(optimizer, OptimizerWrapper): if self.zero_stage == 0: @@ -826,17 +1121,32 @@ class HybridParallelPlugin(PipelinePluginBase): return_outputs: bool = False, ) -> dict: assert self.enable_pipeline_parallelism, "pipeline parallelism is not enabled" - # return loss or outputs if needed + + # Create a context for gradient synchronization based on the optimizer type.
+ # If it's a HybridParallelZeroOptimizer, use optimizer.no_sync(); otherwise, use model.no_sync(). + # This is to avoid redundant gradient reduction in pipeline parallelism (multiple microbatch values should be reduced once), + # so we disable it, performing manual reduction instead. ctx = optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync() + with ctx: outputs = self.schedule.forward_backward_step( model, data_iter, criterion, optimizer, return_loss, return_outputs ) + + # Synchronize the grads of shared parameters of the model. model.sync_shared_params() + + # Synchronize sequence parallelism gradients of the model. + model.sync_sp_grads() + + # Check if the optimizer is a HybridParallelZeroOptimizer and synchronize data parallelism gradients if so. + # Otherwise, synchronize data parallelism gradients of the model. + # This is because these are two different forms of data parallelism. if isinstance(optimizer, HybridParallelZeroOptimizer): - optimizer.sync_grad() + optimizer.sync_dp_grads() else: - model.sync_grads() + model.sync_dp_grads() + return outputs def prepare_dataloader( diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py index b67642b0d..3f0e95d39 100644 --- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py @@ -338,7 +338,14 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): if not isinstance(model, ModelWrapper): use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0 model = HybridParallelModule( - model, self.precision, self.shard_config, self.dp_group, use_ddp, self.ddp_config, self.custom_policy + module=model, + precision=self.precision, + shard_config=self.shard_config, + dp_group=self.dp_group, + tp_group=self.tp_group, + use_ddp=use_ddp, + ddp_config=self.ddp_config, + custom_policy=self.custom_policy, ) if optimizer is not None and not isinstance(optimizer, OptimizerWrapper): if self.zero_stage == 0: diff --git a/colossalai/inference/tensor_parallel/engine.py b/colossalai/inference/tensor_parallel/engine.py index 1c203140c..283f719e5 100644 --- a/colossalai/inference/tensor_parallel/engine.py +++ b/colossalai/inference/tensor_parallel/engine.py @@ -218,10 +218,10 @@ class TPInferEngine: ), "Discrepancy between the tp size of TPInferEngine and the tp size of shard config" model_name = model.__class__.__name__ assert model_name in self.supported_models, f"Unsupported model cls {model_name} for TP inference." - + model = model.model if self.shard_config.inference_gptq else model + policy = get_autopolicy(model, shard_config=self.shard_config) - policy = get_autopolicy(model, inference_only=True) self.model, _ = shardformer.optimize(model, policy) if self.shard_config.inference_gptq: diff --git a/colossalai/shardformer/README.md b/colossalai/shardformer/README.md index 63b28701e..cabd10bba 100644 --- a/colossalai/shardformer/README.md +++ b/colossalai/shardformer/README.md @@ -235,6 +235,14 @@ class SubModuleReplacementDescription: class Policy(ABC): + r""" + The base class for all the policies. For each different model, it should have a different policy class, + like BertPolicy for Bert Model or OPTPolicy for OPT model. + + Shardformer has provided many built-in sharding policies for the mainstream models. You can use the + built-in policies by setting `policy = None`, which is already the default argument for `Shardformer.optimize`. 
+ If you want to define your own policy, you can inherit from this class and overwrite the methods you want to modify. + """ def __init__(self) self.model = None @@ -245,6 +253,16 @@ class Policy(ABC): """ self.model = model + def set_shard_config(self, shard_config: ShardConfig) -> None: + r""" + Set shard config as an attribute of the Policy object. + Args: + shard_config (:class:`ShardConfig`): The shard config to be perform + """ + self.shard_config = shard_config + + self.config_sanity_check() + @abstractmethod def preprocess(self) -> nn.Module: """ diff --git a/colossalai/shardformer/layer/__init__.py b/colossalai/shardformer/layer/__init__.py index a134a2cbd..56e8b08c4 100644 --- a/colossalai/shardformer/layer/__init__.py +++ b/colossalai/shardformer/layer/__init__.py @@ -2,7 +2,7 @@ from .dropout import DropoutForParallelInput, DropoutForReplicatedInput from .embedding import Embedding1D, VocabParallelEmbedding1D from .linear import Linear1D_Col, Linear1D_Row from .loss import cross_entropy_1d -from .normalization import FusedLayerNorm, FusedRMSNorm +from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm from .parallel_module import ParallelModule from .qkv_fused_linear import FusedLinear1D_Col, GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row @@ -16,6 +16,9 @@ __all__ = [ "DropoutForParallelInput", "DropoutForReplicatedInput", "cross_entropy_1d", + "BaseLayerNorm", + "LayerNorm", + "RMSNorm", "FusedLayerNorm", "FusedRMSNorm", "FusedLinear1D_Col", diff --git a/colossalai/shardformer/layer/normalization.py b/colossalai/shardformer/layer/normalization.py index 19b973be8..413d07e87 100644 --- a/colossalai/shardformer/layer/normalization.py +++ b/colossalai/shardformer/layer/normalization.py @@ -1,11 +1,14 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- +from abc import ABC, abstractmethod import torch.nn as nn from colossalai.lazy import LazyInitContext -__all__ = ["FusedLayerNorm", "FusedRMSNorm"] +from .utils import SeqParallelUtils + +__all__ = ["FusedLayerNorm", "FusedRMSNorm", "LayerNorm", "RMSNorm", "BaseLayerNorm"] FAST_LAYERNORM_SUPPORTED_SIZE = [ 1024, @@ -35,7 +38,103 @@ FAST_LAYERNORM_SUPPORTED_SIZE = [ ] -class FusedLayerNorm: +class BaseLayerNorm(ABC): + @abstractmethod + def from_native_module(module: nn.Module, sp_partial_derived: bool = False): + """ + Convert a native PyTorch layer normalization module to a specific layer normalization module, + and optionally mark parameters for gradient aggregation. + + Args: + module (nn.Module): The native PyTorch layer normalization module to be converted. + sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism. + + Returns: + nn.Module: The specific layer normalization module. + + Raises: + AssertionError: If the provided module is not an instance of the supported layer normalization type. + """ + + +class RMSNorm(BaseLayerNorm): + r""" + This is a wrapper around the RMSNorm. It is meant to be used only with the from_native_module interface. + """ + + def __init__(self) -> None: + raise NotImplementedError( + "FusedLayerNorm is not implemented as a physical class. " + "It is meant to be used only with the from_native_module interface to convert a native RMSNorm module to colossalai layer norm module." 
+ ) + + @staticmethod + def from_native_module(module: nn.Module, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module: + """ + Convert a native RMSNorm module to colossalai layer norm module, + and optionally mark parameters for gradient aggregation. + + Args: + module (nn.Module): The native RMSNorm module to be converted. + sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism. + + Returns: + nn.Module: The RMSNorm module. + """ + + LazyInitContext.materialize(module) + + if sp_partial_derived: + # Since gradients are computed using only a subset of the data, + # aggregation of these gradients is necessary during backpropagation. + # Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation. + SeqParallelUtils.marked_as_sp_partial_derived_param(module.weight) + + return module + + +class LayerNorm(BaseLayerNorm): + r""" + This is a wrapper around the torch.nn.LayerNorm. It is meant to be used only with the from_native_module interface. + """ + + def __init__(self) -> None: + raise NotImplementedError( + "LayerNorm is not implemented as a physical class. " + "It is meant to be used only with the from_native_module interface to convert a native pytorch layer norm module to colossalai layer norm module." + ) + + @staticmethod + def from_native_module(module: nn.LayerNorm, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module: + r""" + Convert a native pytorch layer norm module to colossalai layer norm module, + and optionally marking parameters for gradient aggregation. + + Args: + module (nn.LayerNorm): The native PyTorch LayerNorm module to be converted. + sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism. + + Returns: + nn.Module: The LayerNorm module. + + Raises: + AssertionError: If the provided module is not an instance of nn.LayerNorm. + """ + assert isinstance(module, nn.LayerNorm), "Only support conversion from nn.LayerNorm." + + LazyInitContext.materialize(module) + + if sp_partial_derived: + # Since gradients are computed using only a subset of the data, + # aggregation of these gradients is necessary during backpropagation. + # Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation. + SeqParallelUtils.marked_as_sp_partial_derived_param(module.weight) + SeqParallelUtils.marked_as_sp_partial_derived_param(module.bias) + + return module + + +class FusedLayerNorm(BaseLayerNorm): r""" This is a wrapper around the apex fused layernorm implementation. It is meant to be used only with the from_native_module interface. """ @@ -43,15 +142,29 @@ class FusedLayerNorm: def __init__(self) -> None: raise NotImplementedError( "FusedLayerNorm is not implemented as a physical class. " - "It is meant to be used only with the from_native_module interface to wrap the fused layernorm implementation provided by apex." + "It is meant to be used only with the from_native_module interface convert a native pytorch layer norm module to FusedLayerNorm module provided by apex." 
) @staticmethod - def from_native_module(module: nn.LayerNorm, *args, **kwargs) -> nn.Module: + def from_native_module(module: nn.LayerNorm, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module: r""" - Convert a native pytorch layer norm module to colossalai layer norm module + Convert a native pytorch layer norm module to FusedLayerNorm module provided by apex, + and optionally mark parameters for gradient aggregation. + + Args: + module (nn.LayerNorm): The native PyTorch LayerNorm module to be converted. + sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism. + + Returns: + nn.Module: Union[FastLayerNorm, FusedLayerNorm]. + + Raises: + AssertionError: If the provided module is not an instance of nn.LayerNorm. """ # check if apex is installed + + assert isinstance(module, nn.LayerNorm), "Only support conversion from nn.LayerNorm." + try: pass except ImportError: @@ -85,10 +198,18 @@ class FusedLayerNorm: layernorm.weight = module.weight layernorm.bias = module.bias + + if sp_partial_derived: + # Since gradients are computed using only a subset of the data, + # aggregation of these gradients is necessary during backpropagation. + # Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation. + SeqParallelUtils.marked_as_sp_partial_derived_param(layernorm.weight) + SeqParallelUtils.marked_as_sp_partial_derived_param(layernorm.bias) + return layernorm -class FusedRMSNorm: +class FusedRMSNorm(BaseLayerNorm): """ This is a wrapper around the apex fused rms norm implementation. It is meant to be used only with the from_native_module interface. """ @@ -96,11 +217,22 @@ class FusedRMSNorm: def __init__(self) -> None: raise NotImplementedError( "FusedRMSNorm is not implemented as a physical class. " - "It is meant to be used only with the from_native_module interface to wrap the fused rms norm implementation provided by apex." + "It is meant to be used only with the from_native_module interface to convert a native RMSNorm module to FusedRMSNorm module provided by apex." ) @staticmethod - def from_native_module(module: nn.Module, *args, **kwargs) -> nn.Module: + def from_native_module(module: nn.Module, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module: + r""" + Convert a native RMSNorm module to FusedRMSNorm module provided by apex, + and optionally mark parameters for gradient aggregation. + + Args: + module (nn.Module): The native RMSNorm module to be converted. + sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism. + + Returns: + nn.Module: FusedRMSNorm module. + """ try: from apex.normalization import FusedRMSNorm as ApexFusedRMSNorm except ImportError: @@ -124,4 +256,10 @@ class FusedRMSNorm: rmsnorm.weight = module.weight + if sp_partial_derived: + # Since gradients are computed using only a subset of the data, + # aggregation of these gradients is necessary during backpropagation. + # Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation.
+ SeqParallelUtils.marked_as_sp_partial_derived_param(rmsnorm.weight) + return rmsnorm diff --git a/colossalai/shardformer/layer/utils.py b/colossalai/shardformer/layer/utils.py index c3d8501cd..7421f84bf 100644 --- a/colossalai/shardformer/layer/utils.py +++ b/colossalai/shardformer/layer/utils.py @@ -1,8 +1,82 @@ from contextlib import contextmanager +from typing import List import torch import torch.distributed as dist -from torch.distributed import ProcessGroup +from torch import nn +from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors +from torch.distributed import ProcessGroup, get_world_size + + +class SeqParallelUtils: + @staticmethod + def marked_as_sp_partial_derived_param(param): + """ + Mark a parameter as partially derived in sequence parallelism. + + Args: + param: The parameter to mark as partially derived. + """ + setattr(param, "partial_derived", True) + + @staticmethod + def is_sp_partial_derived_param(param): + """ + Check if a parameter is marked as partially derived in sequence parallelism. + + Args: + param: The parameter to check. + + Returns: + bool: True if the parameter is marked as partially derived, False otherwise. + """ + return getattr(param, "partial_derived", False) + + @staticmethod + def allreduce_partial_data_grad(tp_group: ProcessGroup, model: nn.Module = None, grads: List[torch.Tensor] = None): + """ + Allreduce partial derived gradients across the specified process group. + + This function performs gradient synchronization for parameters that are marked as partially derived in sequence parallelism. + + Args: + tp_group (ProcessGroup): The process group for gradient synchronization. + model (nn.Module): The model from which gradients will be synchronized. + grads (List[torch.Tensor]): The list of gradients to be synchronized. + + Raises: + AssertionError: If both `model` and `grads` are provided or neither is provided. + """ + # Ensure that exactly one of `model` and `grads` is provided for gradient synchronization. + assert (model is not None) ^ (grads is not None), "Exactly one of model and grads must be not None." + + # Get the size of the process group, which determines whether synchronization is needed. + tp_size = get_world_size(tp_group) if tp_group is not None else 1 + + if tp_size == 1: + # If the process group size is 1, no synchronization is required. + return + + if model is not None: + # If `model` is provided, extract partial derived gradients from the model's parameters. + grads = [] + for p in model.parameters(): + if p.grad is not None and SeqParallelUtils.is_sp_partial_derived_param(p): + grads.append(p.grad.data) + + # Flatten and reduce the gradients using the specified process group. + coalesced = _flatten_dense_tensors(grads) + dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=tp_group) + + # Unflatten the synchronized gradients and update the model's gradients. + for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)): + buf.copy_(synced) + else: + # If `grads` are provided explicitly, synchronize those gradients directly. 
+ coalesced = _flatten_dense_tensors(grads) + dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=tp_group) + for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)): + buf.copy_(synced) class Randomizer: diff --git a/colossalai/shardformer/policies/auto_policy.py b/colossalai/shardformer/policies/auto_policy.py index f3587de15..3014f1cf3 100644 --- a/colossalai/shardformer/policies/auto_policy.py +++ b/colossalai/shardformer/policies/auto_policy.py @@ -4,6 +4,7 @@ from typing import Optional import torch.nn as nn +from ..shard.shard_config import ShardConfig from .base_policy import Policy __all__ = ["PolicyLocation", "get_autopolicy", "import_policy"] @@ -197,7 +198,7 @@ def _fullname(obj): return module + "." + klass.__qualname__ -def get_autopolicy(model: nn.Module, inference_only: Optional[bool] = False) -> Policy: +def get_autopolicy(model: nn.Module, shard_config: ShardConfig = None) -> Policy: r""" Return the auto policy for the model @@ -208,7 +209,7 @@ def get_autopolicy(model: nn.Module, inference_only: Optional[bool] = False) -> :class:`Policy`: The auto policy for the model """ full_name = _fullname(model) - if inference_only: + if shard_config.inference_only: policy_location = _INFER_POLICY_LIST.get(full_name, None) else: policy_location = _POLICY_LIST.get(full_name, None) @@ -218,5 +219,5 @@ def get_autopolicy(model: nn.Module, inference_only: Optional[bool] = False) -> f"Auto policy for {model.__class__.__qualname__} is not implemented\n. Supported models are {list(_POLICY_LIST.keys())} and {list(_INFER_POLICY_LIST.keys())}" ) else: - policy = import_policy(policy_location, inference_only) + policy = import_policy(policy_location, shard_config.inference_only) return policy() diff --git a/colossalai/shardformer/policies/base_policy.py b/colossalai/shardformer/policies/base_policy.py index eb0350053..003c9322a 100644 --- a/colossalai/shardformer/policies/base_policy.py +++ b/colossalai/shardformer/policies/base_policy.py @@ -11,6 +11,7 @@ from torch.nn import Module from colossalai.pipeline.stage_manager import PipelineStageManager +from ..layer.normalization import BaseLayerNorm from ..layer.parallel_module import ParallelModule from ..shard.shard_config import ShardConfig @@ -29,7 +30,7 @@ class SubModuleReplacementDescription: ignore_if_not_exist (bool): if the submodule does not exist, ignore it or raise an exception """ suffix: str - target_module: ParallelModule + target_module: Union[ParallelModule, BaseLayerNorm] kwargs: Dict[str, Any] = None ignore_if_not_exist: bool = False @@ -77,7 +78,6 @@ class Policy(ABC): def set_model(self, model: nn.Module) -> None: r""" Set model as an attribute of the Policy object so that we can access the model's attributes. - Args: model (:class:`nn.Module`): The model to be perform """ @@ -86,11 +86,11 @@ class Policy(ABC): def set_shard_config(self, shard_config: ShardConfig) -> None: r""" Set shard config as an attribute of the Policy object. 
- Args: shard_config (:class:`ShardConfig`): The shard config to be perform """ self.shard_config = shard_config + self.config_sanity_check() @property diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index 14146de15..c31327a6c 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -60,6 +60,12 @@ class BertPolicy(Policy): ) policy = {} + + if self.shard_config.enable_fused_normalization: + norm_cls = col_nn.FusedLayerNorm + else: + norm_cls = col_nn.LayerNorm + use_sequence_parallel = self.shard_config.enable_sequence_parallelism overlap = self.shard_config.enable_sequence_overlap if self.shard_config.enable_tensor_parallelism: @@ -141,33 +147,34 @@ class BertPolicy(Policy): ) # optimization configuration - if self.shard_config.enable_fused_normalization: - # Handle bert layer - self.append_or_create_submodule_replacement( - description=[ - SubModuleReplacementDescription( - suffix="attention.output.LayerNorm", - target_module=col_nn.FusedLayerNorm, - ), - SubModuleReplacementDescription( - suffix="output.LayerNorm", - target_module=col_nn.FusedLayerNorm, - ), - ], - policy=policy, - target_key=BertLayer, - ) - # handle embedding layer - self.append_or_create_submodule_replacement( - description=[ - SubModuleReplacementDescription( - suffix="LayerNorm", - target_module=col_nn.FusedLayerNorm, - ) - ], - policy=policy, - target_key=BertEmbeddings, - ) + # Handle bert layer + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="attention.output.LayerNorm", + target_module=norm_cls, + kwargs={"sp_partial_derived": use_sequence_parallel}, + ), + SubModuleReplacementDescription( + suffix="output.LayerNorm", + target_module=norm_cls, + kwargs={"sp_partial_derived": use_sequence_parallel}, + ), + ], + policy=policy, + target_key=BertLayer, + ) + # handle embedding layer + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="LayerNorm", + target_module=norm_cls, + ) + ], + policy=policy, + target_key=BertEmbeddings, + ) # use flash attention if self.shard_config.enable_flash_attention: @@ -288,9 +295,6 @@ class BertPolicy(Policy): # BertModel class BertModelPolicy(BertPolicy): - def __init__(self) -> None: - super().__init__() - def module_policy(self): policy = super().module_policy() from transformers.models.bert.modeling_bert import BertModel @@ -313,9 +317,6 @@ class BertModelPolicy(BertPolicy): # BertForPreTraining class BertForPreTrainingPolicy(BertPolicy): - def __init__(self) -> None: - super().__init__() - def module_policy(self): policy = super().module_policy() policy = self.add_lm_head_policy(policy) @@ -355,9 +356,6 @@ class BertForPreTrainingPolicy(BertPolicy): # BertLMHeadModel class BertLMHeadModelPolicy(BertPolicy): - def __init__(self) -> None: - super().__init__() - def module_policy(self): policy = super().module_policy() policy = self.add_lm_head_policy(policy) @@ -396,9 +394,6 @@ class BertLMHeadModelPolicy(BertPolicy): # BertForMaskedLM class BertForMaskedLMPolicy(BertPolicy): - def __init__(self) -> None: - super().__init__() - def module_policy(self): policy = super().module_policy() policy = self.add_lm_head_policy(policy) @@ -437,9 +432,6 @@ class BertForMaskedLMPolicy(BertPolicy): # BertForSequenceClassification class BertForSequenceClassificationPolicy(BertPolicy): - def __init__(self) -> None: - super().__init__() - def module_policy(self): from 
transformers.models.bert.modeling_bert import BertForSequenceClassification @@ -484,9 +476,6 @@ class BertForSequenceClassificationPolicy(BertPolicy): # BertForTokenClassification class BertForTokenClassificationPolicy(BertPolicy): - def __init__(self) -> None: - super().__init__() - def module_policy(self): from transformers.models.bert.modeling_bert import BertForTokenClassification @@ -531,9 +520,6 @@ class BertForTokenClassificationPolicy(BertPolicy): # BertForNextSentencePrediction class BertForNextSentencePredictionPolicy(BertPolicy): - def __init__(self) -> None: - super().__init__() - def module_policy(self): policy = super().module_policy() from transformers.models.bert.modeling_bert import BertForNextSentencePrediction @@ -564,9 +550,6 @@ class BertForNextSentencePredictionPolicy(BertPolicy): # BertForMultipleChoice class BertForMultipleChoicePolicy(BertPolicy): - def __init__(self) -> None: - super().__init__() - def module_policy(self): from transformers.models.bert.modeling_bert import BertForMultipleChoice @@ -610,9 +593,6 @@ class BertForMultipleChoicePolicy(BertPolicy): class BertForQuestionAnsweringPolicy(BertPolicy): - def __init__(self) -> None: - super().__init__() - def module_policy(self): from transformers.models.bert.modeling_bert import BertForQuestionAnswering diff --git a/colossalai/shardformer/policies/blip2.py b/colossalai/shardformer/policies/blip2.py index 997643d1a..9be2a1e78 100644 --- a/colossalai/shardformer/policies/blip2.py +++ b/colossalai/shardformer/policies/blip2.py @@ -43,6 +43,11 @@ class BlipPolicy(Policy): policy = {} + if self.shard_config.enable_fused_normalization: + norm_cls = col_nn.FusedLayerNorm + else: + norm_cls = col_nn.LayerNorm + if self.shard_config.enable_tensor_parallelism: policy[Blip2EncoderLayer] = ModulePolicyDescription( attribute_replacement={ @@ -214,94 +219,93 @@ class BlipPolicy(Policy): policy[Blip2Attention] = ModulePolicyDescription(method_replacement={"forward": forward_fn()}) # optimization configuration - if self.shard_config.enable_fused_normalization: - # Handle Blip2EncoderLayer layer - self.append_or_create_submodule_replacement( - description=[ - SubModuleReplacementDescription( - suffix="layer_norm1", - target_module=col_nn.FusedLayerNorm, - ), - SubModuleReplacementDescription( - suffix="layer_norm2", - target_module=col_nn.FusedLayerNorm, - ), - ], - policy=policy, - target_key=Blip2EncoderLayer, - ) + # Handle Blip2EncoderLayer layer + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="layer_norm1", + target_module=norm_cls, + ), + SubModuleReplacementDescription( + suffix="layer_norm2", + target_module=norm_cls, + ), + ], + policy=policy, + target_key=Blip2EncoderLayer, + ) - # handle Blip2VisionModel layer - self.append_or_create_submodule_replacement( - description=[ - SubModuleReplacementDescription( - suffix="post_layernorm", - target_module=col_nn.FusedLayerNorm, - ) - ], - policy=policy, - target_key=Blip2VisionModel, - ) + # handle Blip2VisionModel layer + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="post_layernorm", + target_module=norm_cls, + ) + ], + policy=policy, + target_key=Blip2VisionModel, + ) - # handle Blip2VisionModel layer - self.append_or_create_submodule_replacement( - description=[ - SubModuleReplacementDescription( - suffix="layernorm", - target_module=col_nn.FusedLayerNorm, - ) - ], - policy=policy, - target_key=Blip2QFormerModel, - ) + # handle Blip2VisionModel layer 
+ self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="layernorm", + target_module=norm_cls, + ) + ], + policy=policy, + target_key=Blip2QFormerModel, + ) - # handle Blip2QFormerLayer layer - self.append_or_create_submodule_replacement( - description=[ - SubModuleReplacementDescription( - suffix="attention.output.LayerNorm", - target_module=col_nn.FusedLayerNorm, - ), - SubModuleReplacementDescription( - suffix="crossattention.output.LayerNorm", - target_module=col_nn.FusedLayerNorm, - ), - SubModuleReplacementDescription( - suffix="output_query.LayerNorm", - target_module=col_nn.FusedLayerNorm, - ), - ], - policy=policy, - target_key=Blip2QFormerLayer, - ) + # handle Blip2QFormerLayer layer + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="attention.output.LayerNorm", + target_module=norm_cls, + ), + SubModuleReplacementDescription( + suffix="crossattention.output.LayerNorm", + target_module=norm_cls, + ), + SubModuleReplacementDescription( + suffix="output_query.LayerNorm", + target_module=norm_cls, + ), + ], + policy=policy, + target_key=Blip2QFormerLayer, + ) - # handle OPTForCausalLM layer - self.append_or_create_submodule_replacement( - description=[ - SubModuleReplacementDescription( - suffix="model.decoder.final_layer_norm", - target_module=col_nn.FusedLayerNorm, - ) - ], - policy=policy, - target_key=OPTForCausalLM, - ) + # handle OPTForCausalLM layer + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="model.decoder.final_layer_norm", + target_module=norm_cls, + ) + ], + policy=policy, + target_key=OPTForCausalLM, + ) - # handle OPTDecoderLayer layer - self.append_or_create_submodule_replacement( - description=[ - SubModuleReplacementDescription( - suffix="self_attn_layer_norm", - target_module=col_nn.FusedLayerNorm, - ), - SubModuleReplacementDescription( - suffix="final_layer_norm", - target_module=col_nn.FusedLayerNorm, - ), - ], - policy=policy, - target_key=OPTDecoderLayer, - ) + # handle OPTDecoderLayer layer + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="self_attn_layer_norm", + target_module=norm_cls, + ), + SubModuleReplacementDescription( + suffix="final_layer_norm", + target_module=norm_cls, + ), + ], + policy=policy, + target_key=OPTDecoderLayer, + ) # use flash attention if self.shard_config.enable_flash_attention: diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py index 13b9dd313..c8687a1ac 100644 --- a/colossalai/shardformer/policies/bloom.py +++ b/colossalai/shardformer/policies/bloom.py @@ -42,6 +42,10 @@ class BloomPolicy(Policy): policy = {} + if self.shard_config.enable_fused_normalization: + norm_cls = col_nn.FusedLayerNorm + else: + norm_cls = col_nn.LayerNorm use_sequence_parallel = self.shard_config.enable_sequence_parallelism overlap = self.shard_config.enable_sequence_overlap if self.shard_config.enable_tensor_parallelism: @@ -97,38 +101,39 @@ class BloomPolicy(Policy): ) # optimization configuration - if self.shard_config.enable_fused_normalization: - # handle bloom model - self.append_or_create_submodule_replacement( - description=[ - SubModuleReplacementDescription( - suffix="ln_f", - target_module=col_nn.FusedLayerNorm, - ), - SubModuleReplacementDescription( - suffix="word_embeddings_layernorm", - target_module=col_nn.FusedLayerNorm, - ), - ], - policy=policy, - 
target_key=BloomModel, - ) + # handle bloom model + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="ln_f", + target_module=norm_cls, + ), + SubModuleReplacementDescription( + suffix="word_embeddings_layernorm", + target_module=norm_cls, + ), + ], + policy=policy, + target_key=BloomModel, + ) - # handle bloom block - self.append_or_create_submodule_replacement( - description=[ - SubModuleReplacementDescription( - suffix="input_layernorm", - target_module=col_nn.FusedLayerNorm, - ), - SubModuleReplacementDescription( - suffix="post_attention_layernorm", - target_module=col_nn.FusedLayerNorm, - ), - ], - policy=policy, - target_key=BloomBlock, - ) + # handle bloom block + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="input_layernorm", + target_module=norm_cls, + kwargs={"sp_partial_derived": use_sequence_parallel}, + ), + SubModuleReplacementDescription( + suffix="post_attention_layernorm", + target_module=norm_cls, + kwargs={"sp_partial_derived": use_sequence_parallel}, + ), + ], + policy=policy, + target_key=BloomBlock, + ) if use_sequence_parallel: self.append_or_create_method_replacement( @@ -225,9 +230,6 @@ class BloomPolicy(Policy): class BloomModelPolicy(BloomPolicy): - def __init__(self) -> None: - super().__init__() - def module_policy(self): policy = super().module_policy() from transformers.models.bloom.modeling_bloom import BloomModel diff --git a/colossalai/shardformer/policies/chatglm2.py b/colossalai/shardformer/policies/chatglm2.py index 3c27c848e..ab18d80b7 100644 --- a/colossalai/shardformer/policies/chatglm2.py +++ b/colossalai/shardformer/policies/chatglm2.py @@ -45,6 +45,16 @@ class ChatGLMPolicy(Policy): policy = {} + if self.shard_config.enable_fused_normalization: + if self.model.config.rmsnorm: + norm_cls = col_nn.FusedRMSNorm + else: + norm_cls = col_nn.FusedLayerNorm + else: + if self.model.config.rmsnorm: + norm_cls = col_nn.RMSNorm + else: + norm_cls = col_nn.LayerNorm use_sequence_parallel = self.shard_config.enable_sequence_parallelism overlap = self.shard_config.enable_sequence_overlap if self.shard_config.enable_tensor_parallelism: @@ -96,52 +106,34 @@ class ChatGLMPolicy(Policy): ) # optimization configuration - if self.shard_config.enable_fused_normalization: - if not self.model.config.rmsnorm: - self.append_or_create_submodule_replacement( - description=[ - SubModuleReplacementDescription(suffix="input_layernorm", target_module=col_nn.FusedLayerNorm), - SubModuleReplacementDescription( - suffix="post_attention_layernorm", target_module=col_nn.FusedLayerNorm - ), - ], - policy=policy, - target_key=GLMBlock, - ) + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="input_layernorm", + target_module=norm_cls, + kwargs={"sp_partial_derived": use_sequence_parallel}, + ), + SubModuleReplacementDescription( + suffix="post_attention_layernorm", + target_module=norm_cls, + kwargs={"sp_partial_derived": use_sequence_parallel}, + ), + ], + policy=policy, + target_key=GLMBlock, + ) - if self.model.config.post_layer_norm: - self.append_or_create_submodule_replacement( - description=[ - SubModuleReplacementDescription( - suffix="encoder.final_layernorm", target_module=col_nn.FusedLayerNorm - ) - ], - policy=policy, - target_key=ChatGLMModel, - ) - - else: - self.append_or_create_submodule_replacement( - description=[ - SubModuleReplacementDescription(suffix="input_layernorm", 
target_module=col_nn.FusedRMSNorm), - SubModuleReplacementDescription( - suffix="post_attention_layernorm", target_module=col_nn.FusedRMSNorm - ), - ], - policy=policy, - target_key=GLMBlock, - ) - - if self.model.config.post_layer_norm: - self.append_or_create_submodule_replacement( - description=[ - SubModuleReplacementDescription( - suffix="encoder.final_layernorm", target_module=col_nn.FusedRMSNorm - ) - ], - policy=policy, - target_key=ChatGLMModel, + if self.model.config.post_layer_norm: + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="encoder.final_layernorm", + target_module=norm_cls, ) + ], + policy=policy, + target_key=ChatGLMModel, + ) # use flash attention if self.shard_config.enable_flash_attention: @@ -224,9 +216,6 @@ class ChatGLMPolicy(Policy): class ChatGLMModelPolicy(ChatGLMPolicy): - def __init__(self) -> None: - super().__init__() - def module_policy(self): pass diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index 6f46bfc7e..022e6ff5b 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -39,6 +39,11 @@ class GPT2Policy(Policy): from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2Model policy = {} + + if self.shard_config.enable_fused_normalization: + norm_cls = col_nn.FusedLayerNorm + else: + norm_cls = col_nn.LayerNorm use_sequence_parallel = self.shard_config.enable_sequence_parallelism overlap = self.shard_config.enable_sequence_overlap if self.shard_config.enable_tensor_parallelism: @@ -102,33 +107,37 @@ class GPT2Policy(Policy): ) # optimization configuration - if self.shard_config.enable_fused_normalization: - self.append_or_create_submodule_replacement( - description=SubModuleReplacementDescription( - suffix="ln_f", - target_module=col_nn.FusedLayerNorm, - ), - policy=policy, - target_key=GPT2Model, - ) + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="ln_f", + target_module=norm_cls, + ), + policy=policy, + target_key=GPT2Model, + ) - self.append_or_create_submodule_replacement( - description=[ - SubModuleReplacementDescription( - suffix="ln_1", - target_module=col_nn.FusedLayerNorm, - ), - SubModuleReplacementDescription( - suffix="ln_2", - target_module=col_nn.FusedLayerNorm, - ), - SubModuleReplacementDescription( - suffix="ln_cross_attn", target_module=col_nn.FusedLayerNorm, ignore_if_not_exist=True - ), - ], - policy=policy, - target_key=GPT2Block, - ) + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="ln_1", + target_module=norm_cls, + kwargs={"sp_partial_derived": use_sequence_parallel}, + ), + SubModuleReplacementDescription( + suffix="ln_2", + target_module=norm_cls, + kwargs={"sp_partial_derived": use_sequence_parallel}, + ), + SubModuleReplacementDescription( + suffix="ln_cross_attn", + target_module=norm_cls, + ignore_if_not_exist=True, + kwargs={"sp_partial_derived": use_sequence_parallel}, + ), + ], + policy=policy, + target_key=GPT2Block, + ) if self.shard_config.enable_flash_attention: self.append_or_create_method_replacement( @@ -192,9 +201,6 @@ class GPT2Policy(Policy): # GPT2Model class GPT2ModelPolicy(GPT2Policy): - def __init__(self) -> None: - super().__init__() - def module_policy(self): from transformers.models.gpt2.modeling_gpt2 import GPT2Model @@ -216,9 +222,6 @@ class GPT2ModelPolicy(GPT2Policy): # GPT2LMHeadModel class 
GPT2LMHeadModelPolicy(GPT2Policy): - def __init__(self) -> None: - super().__init__() - def module_policy(self): from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel @@ -263,9 +266,6 @@ class GPT2LMHeadModelPolicy(GPT2Policy): # GPT2DoubleHeadsModel class GPT2DoubleHeadsModelPolicy(GPT2Policy): - def __init__(self) -> None: - super().__init__() - def module_policy(self): from transformers.models.gpt2.modeling_gpt2 import GPT2DoubleHeadsModel @@ -317,9 +317,6 @@ class GPT2DoubleHeadsModelPolicy(GPT2Policy): # GPT2ForQuestionAnswering class GPT2ForQuestionAnsweringPolicy(GPT2Policy): - def __init__(self) -> None: - super().__init__() - def module_policy(self): from transformers.models.gpt2.modeling_gpt2 import GPT2ForQuestionAnswering @@ -347,9 +344,6 @@ class GPT2ForQuestionAnsweringPolicy(GPT2Policy): # GPT2ForTokenClassification class GPT2ForTokenClassificationPolicy(GPT2Policy): - def __init__(self) -> None: - super().__init__() - def module_policy(self): from transformers.models.gpt2.modeling_gpt2 import GPT2ForTokenClassification @@ -387,9 +381,6 @@ class GPT2ForTokenClassificationPolicy(GPT2Policy): # GPT2ForSequenceClassification class GPT2ForSequenceClassificationPolicy(GPT2Policy): - def __init__(self) -> None: - super().__init__() - def module_policy(self): from transformers.models.gpt2.modeling_gpt2 import GPT2ForSequenceClassification diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index 099995acb..915f07d31 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -6,7 +6,7 @@ import torch.nn as nn from torch import Tensor from torch.nn import Module -from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D +from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, RMSNorm, VocabParallelEmbedding1D from ..modeling.llama import LlamaPipelineForwards, get_llama_flash_attention_forward from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription @@ -35,6 +35,11 @@ class LlamaPolicy(Policy): policy = {} + if self.shard_config.enable_fused_normalization: + norm_cls = FusedRMSNorm + else: + norm_cls = RMSNorm + if self.shard_config.enable_sequence_parallelism: self.shard_config.enable_sequence_parallelism = False warnings.warn("Llama dosen't support sequence parallelism now, will ignore the sequence parallelism flag.") @@ -93,31 +98,31 @@ class LlamaPolicy(Policy): ) # optimization configuration - if self.shard_config.enable_fused_normalization: - self.append_or_create_submodule_replacement( - description=[ - SubModuleReplacementDescription( - suffix="input_layernorm", - target_module=FusedRMSNorm, - ), - SubModuleReplacementDescription( - suffix="post_attention_layernorm", - target_module=FusedRMSNorm, - ), - ], - policy=policy, - target_key=LlamaDecoderLayer, - ) - - self.append_or_create_submodule_replacement( - description=SubModuleReplacementDescription( - suffix="norm", - target_module=FusedRMSNorm, + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="input_layernorm", + target_module=norm_cls, ), - policy=policy, - target_key=LlamaModel, - ) + SubModuleReplacementDescription( + suffix="post_attention_layernorm", + target_module=norm_cls, + ), + ], + policy=policy, + target_key=LlamaDecoderLayer, + ) + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="norm", + 
target_module=norm_cls, + ), + policy=policy, + target_key=LlamaModel, + ) + + # use flash attention if self.shard_config.enable_flash_attention: self.append_or_create_method_replacement( description={ @@ -174,9 +179,6 @@ class LlamaPolicy(Policy): class LlamaModelPolicy(LlamaPolicy): - def __init__(self) -> None: - super().__init__() - def module_policy(self): policy = super().module_policy() from transformers.models.llama.modeling_llama import LlamaModel diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py index 5739d21a3..0b5c767e1 100644 --- a/colossalai/shardformer/policies/opt.py +++ b/colossalai/shardformer/policies/opt.py @@ -5,7 +5,7 @@ from typing import Callable, Dict, List import torch.nn as nn from torch import Tensor, nn -from colossalai.shardformer.layer import FusedLayerNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D +from colossalai.shardformer.layer import FusedLayerNorm, LayerNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D from .._utils import getattr_ from ..modeling.jit import get_jit_fused_dropout_add_func @@ -42,6 +42,12 @@ class OPTPolicy(Policy): from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoder, OPTDecoderLayer policy = {} + + if self.shard_config.enable_fused_normalization: + norm_cls = FusedLayerNorm + else: + norm_cls = LayerNorm + if self.shard_config.enable_sequence_parallelism: self.shard_config.enable_sequence_parallelism = False warnings.warn("OPT dosen't support sequence parallelism now, will ignore the sequence parallelism flag.") @@ -94,26 +100,25 @@ class OPTPolicy(Policy): ) # optimization configuration - if self.shard_config.enable_fused_normalization: - self.append_or_create_submodule_replacement( - description=SubModuleReplacementDescription( - suffix="final_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="final_layer_norm", target_module=norm_cls, ignore_if_not_exist=True + ), + policy=policy, + target_key=OPTDecoder, + ) + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="self_attn_layer_norm", target_module=norm_cls, ignore_if_not_exist=True ), - policy=policy, - target_key=OPTDecoder, - ) - self.append_or_create_submodule_replacement( - description=[ - SubModuleReplacementDescription( - suffix="self_attn_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True - ), - SubModuleReplacementDescription( - suffix="final_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True - ), - ], - policy=policy, - target_key=OPTDecoderLayer, - ) + SubModuleReplacementDescription( + suffix="final_layer_norm", target_module=norm_cls, ignore_if_not_exist=True + ), + ], + policy=policy, + target_key=OPTDecoderLayer, + ) # use flash attention if self.shard_config.enable_flash_attention: @@ -183,9 +188,6 @@ class OPTPolicy(Policy): class OPTModelPolicy(OPTPolicy): - def __init__(self) -> None: - super().__init__() - def module_policy(self): from transformers.models.opt.modeling_opt import OPTModel @@ -253,9 +255,6 @@ class OPTForCausalLMPolicy(OPTPolicy): class OPTForSequenceClassificationPolicy(OPTPolicy): - def __init__(self) -> None: - super().__init__() - def module_policy(self): from transformers.models.opt.modeling_opt import OPTForSequenceClassification @@ -281,9 +280,6 @@ class OPTForSequenceClassificationPolicy(OPTPolicy): class OPTForQuestionAnsweringPolicy(OPTPolicy): - 
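RMSNorm-based models such as LLaMA and T5 use the same pattern with the RMS wrappers, and ChatGLM additionally consults config.rmsnorm to choose between the two families. A sketch of that combined decision, assuming the wrappers exported by colossalai.shardformer.layer in this patch:

from colossalai.shardformer.layer import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm

# Sketch: pick one of the four norm wrappers depending on whether fused kernels
# are enabled and whether the model uses RMSNorm (as ChatGLM's config reports).
def pick_rms_aware_norm_cls(enable_fused_normalization: bool, use_rmsnorm: bool):
    if enable_fused_normalization:
        return FusedRMSNorm if use_rmsnorm else FusedLayerNorm
    return RMSNorm if use_rmsnorm else LayerNorm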
def __init__(self) -> None: - super().__init__() - def module_policy(self): from transformers.models.opt.modeling_opt import OPTForQuestionAnswering diff --git a/colossalai/shardformer/policies/sam.py b/colossalai/shardformer/policies/sam.py index 58a8500e3..498e62164 100644 --- a/colossalai/shardformer/policies/sam.py +++ b/colossalai/shardformer/policies/sam.py @@ -24,6 +24,11 @@ class SamPolicy(Policy): policy = {} + if self.shard_config.enable_fused_normalization: + norm_cls = col_nn.FusedLayerNorm + else: + norm_cls = col_nn.LayerNorm + if self.shard_config.enable_tensor_parallelism: policy[SamVisionLayer] = ModulePolicyDescription( attribute_replacement={ @@ -151,58 +156,57 @@ class SamPolicy(Policy): ) # optimization configuration - if self.shard_config.enable_fused_normalization: - # Handle SamVisionLayer - self.append_or_create_submodule_replacement( - description=[ - SubModuleReplacementDescription( - suffix="layer_norm1", - target_module=col_nn.FusedLayerNorm, - ), - SubModuleReplacementDescription( - suffix="layer_norm2", - target_module=col_nn.FusedLayerNorm, - ), - ], - policy=policy, - target_key=SamVisionLayer, - ) + # Handle SamVisionLayer + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="layer_norm1", + target_module=norm_cls, + ), + SubModuleReplacementDescription( + suffix="layer_norm2", + target_module=norm_cls, + ), + ], + policy=policy, + target_key=SamVisionLayer, + ) - # Handle SamTwoWayAttentionBlock - self.append_or_create_submodule_replacement( - description=[ - SubModuleReplacementDescription( - suffix="layer_norm1", - target_module=col_nn.FusedLayerNorm, - ), - SubModuleReplacementDescription( - suffix="layer_norm2", - target_module=col_nn.FusedLayerNorm, - ), - SubModuleReplacementDescription( - suffix="layer_norm3", - target_module=col_nn.FusedLayerNorm, - ), - SubModuleReplacementDescription( - suffix="layer_norm4", - target_module=col_nn.FusedLayerNorm, - ), - ], - policy=policy, - target_key=SamTwoWayAttentionBlock, - ) + # Handle SamTwoWayAttentionBlock + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="layer_norm1", + target_module=norm_cls, + ), + SubModuleReplacementDescription( + suffix="layer_norm2", + target_module=norm_cls, + ), + SubModuleReplacementDescription( + suffix="layer_norm3", + target_module=norm_cls, + ), + SubModuleReplacementDescription( + suffix="layer_norm4", + target_module=norm_cls, + ), + ], + policy=policy, + target_key=SamTwoWayAttentionBlock, + ) - # Handle SamTwoWayTransformer - self.append_or_create_submodule_replacement( - description=[ - SubModuleReplacementDescription( - suffix="layer_norm_final_attn", - target_module=col_nn.FusedLayerNorm, - ) - ], - policy=policy, - target_key=SamTwoWayTransformer, - ) + # Handle SamTwoWayTransformer + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="layer_norm_final_attn", + target_module=norm_cls, + ) + ], + policy=policy, + target_key=SamTwoWayTransformer, + ) # use flash attention if self.shard_config.enable_flash_attention: diff --git a/colossalai/shardformer/policies/t5.py b/colossalai/shardformer/policies/t5.py index 74cc7337e..fc5021600 100644 --- a/colossalai/shardformer/policies/t5.py +++ b/colossalai/shardformer/policies/t5.py @@ -11,6 +11,7 @@ from colossalai.shardformer.layer import ( FusedRMSNorm, Linear1D_Col, Linear1D_Row, + RMSNorm, VocabParallelEmbedding1D, ) from 
colossalai.shardformer.policies.base_policy import ModulePolicyDescription @@ -58,6 +59,11 @@ class T5BasePolicy(Policy): policy = {} + if self.shard_config.enable_fused_normalization: + norm_cls = FusedRMSNorm + else: + norm_cls = RMSNorm + if self.shard_config.enable_sequence_parallelism: self.shard_config.enable_sequence_parallelism = False warnings.warn("T5 dosen't support sequence parallelism now, will ignore the sequence parallelism flag.") @@ -169,38 +175,37 @@ class T5BasePolicy(Policy): ) # optimization configuration - if self.shard_config.enable_fused_normalization: - self.append_or_create_submodule_replacement( - description=SubModuleReplacementDescription( - suffix="layer_norm", - target_module=FusedRMSNorm, - ), - policy=policy, - target_key=T5LayerFF, - ) - self.append_or_create_submodule_replacement( - description=SubModuleReplacementDescription( - suffix="layer_norm", - target_module=FusedRMSNorm, - ), - policy=policy, - target_key=T5LayerFF, - ) - self.append_or_create_submodule_replacement( - description=SubModuleReplacementDescription(suffix="layer_norm", target_module=FusedRMSNorm), - policy=policy, - target_key=T5LayerSelfAttention, - ) - self.append_or_create_submodule_replacement( - description=SubModuleReplacementDescription(suffix="layer_norm", target_module=FusedRMSNorm), - policy=policy, - target_key=T5LayerCrossAttention, - ) - self.append_or_create_submodule_replacement( - description=SubModuleReplacementDescription(suffix="final_layer_norm", target_module=FusedRMSNorm), - policy=policy, - target_key=T5Stack, - ) + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="layer_norm", + target_module=norm_cls, + ), + policy=policy, + target_key=T5LayerFF, + ) + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="layer_norm", + target_module=norm_cls, + ), + policy=policy, + target_key=T5LayerFF, + ) + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription(suffix="layer_norm", target_module=norm_cls), + policy=policy, + target_key=T5LayerSelfAttention, + ) + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription(suffix="layer_norm", target_module=norm_cls), + policy=policy, + target_key=T5LayerCrossAttention, + ) + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription(suffix="final_layer_norm", target_module=norm_cls), + policy=policy, + target_key=T5Stack, + ) # use flash attention if self.shard_config.enable_flash_attention: @@ -363,9 +368,6 @@ class T5BasePolicy(Policy): class T5ModelPolicy(T5BasePolicy): - def __init__(self) -> None: - super().__init__() - def module_policy(self): from transformers import T5Model @@ -402,9 +404,6 @@ class T5ModelPolicy(T5BasePolicy): class T5ForConditionalGenerationPolicy(T5BasePolicy): - def __init__(self) -> None: - super().__init__() - def module_policy(self): from transformers import T5ForConditionalGeneration @@ -466,9 +465,6 @@ class T5ForConditionalGenerationPolicy(T5BasePolicy): class T5EncoderPolicy(T5BasePolicy): - def __init__(self) -> None: - super().__init__() - def module_policy(self): from transformers import T5EncoderModel diff --git a/colossalai/shardformer/policies/vit.py b/colossalai/shardformer/policies/vit.py index 270cdce9b..6ef0e3b34 100644 --- a/colossalai/shardformer/policies/vit.py +++ b/colossalai/shardformer/policies/vit.py @@ -159,9 +159,6 @@ class ViTPolicy(Policy): # ViTModel class 
ViTModelPolicy(ViTPolicy): - def __init__(self) -> None: - super().__init__() - def module_policy(self): from transformers.models.vit.modeling_vit import ViTModel @@ -227,9 +224,6 @@ class ViTForImageClassificationPolicy(ViTPolicy): # ViTForMaskedImageModeling class ViTForMaskedImageModelingPolicy(ViTPolicy): - def __init__(self) -> None: - super().__init__() - def module_policy(self): from transformers.models.vit.modeling_vit import ViTForMaskedImageModeling, ViTModel diff --git a/colossalai/shardformer/policies/whisper.py b/colossalai/shardformer/policies/whisper.py index d9af2461c..3ce198e9e 100644 --- a/colossalai/shardformer/policies/whisper.py +++ b/colossalai/shardformer/policies/whisper.py @@ -52,6 +52,11 @@ class WhisperPolicy(Policy): policy = {} + if self.shard_config.enable_fused_normalization: + norm_cls = col_nn.FusedLayerNorm + else: + norm_cls = col_nn.LayerNorm + if self.shard_config.enable_sequence_parallelism: self.shard_config.enable_sequence_parallelism = False warnings.warn( @@ -161,62 +166,61 @@ class WhisperPolicy(Policy): ) # optimization configuration - if self.shard_config.enable_fused_normalization: - # Handle encoder layer - self.append_or_create_submodule_replacement( - description=[ - SubModuleReplacementDescription( - suffix="self_attn_layer_norm", - target_module=col_nn.FusedLayerNorm, - ), - SubModuleReplacementDescription( - suffix="final_layer_norm", - target_module=col_nn.FusedLayerNorm, - ), - ], - policy=policy, - target_key=WhisperEncoderLayer, - ) + # Handle encoder layer + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="self_attn_layer_norm", + target_module=norm_cls, + ), + SubModuleReplacementDescription( + suffix="final_layer_norm", + target_module=norm_cls, + ), + ], + policy=policy, + target_key=WhisperEncoderLayer, + ) - # Handle decoder layer - self.append_or_create_submodule_replacement( - description=[ - SubModuleReplacementDescription( - suffix="self_attn_layer_norm", - target_module=col_nn.FusedLayerNorm, - ), - SubModuleReplacementDescription( - suffix="final_layer_norm", - target_module=col_nn.FusedLayerNorm, - ), - ], - policy=policy, - target_key=WhisperDecoderLayer, - ) + # Handle decoder layer + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="self_attn_layer_norm", + target_module=norm_cls, + ), + SubModuleReplacementDescription( + suffix="final_layer_norm", + target_module=norm_cls, + ), + ], + policy=policy, + target_key=WhisperDecoderLayer, + ) - # handle encoder layer - self.append_or_create_submodule_replacement( - description=[ - SubModuleReplacementDescription( - suffix="layer_norm", - target_module=col_nn.FusedLayerNorm, - ) - ], - policy=policy, - target_key=WhisperEncoder, - ) + # handle encoder layer + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="layer_norm", + target_module=norm_cls, + ) + ], + policy=policy, + target_key=WhisperEncoder, + ) - # handle decoder layer - self.append_or_create_submodule_replacement( - description=[ - SubModuleReplacementDescription( - suffix="layer_norm", - target_module=col_nn.FusedLayerNorm, - ) - ], - policy=policy, - target_key=WhisperDecoder, - ) + # handle decoder layer + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="layer_norm", + target_module=norm_cls, + ) + ], + policy=policy, + target_key=WhisperDecoder, + ) # enable flash attention if 
self.shard_config.enable_flash_attention: @@ -416,9 +420,6 @@ class WhisperPolicy(Policy): # WhisperModel class WhisperModelPolicy(WhisperPolicy): - def __init__(self) -> None: - super().__init__() - def module_policy(self): from transformers import WhisperModel @@ -441,9 +442,6 @@ class WhisperModelPolicy(WhisperPolicy): # WhisperForConditionalGeneration class WhisperForConditionalGenerationPolicy(WhisperPolicy): - def __init__(self) -> None: - super().__init__() - def module_policy(self): from transformers import WhisperForConditionalGeneration @@ -502,9 +500,6 @@ class WhisperForConditionalGenerationPolicy(WhisperPolicy): # WhisperForAudioClassification class WhisperForAudioClassificationPolicy(WhisperPolicy): - def __init__(self) -> None: - super().__init__() - def preprocess(self): return self.model diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py index 1bed850c6..e3c0aa93d 100644 --- a/colossalai/shardformer/shard/sharder.py +++ b/colossalai/shardformer/shard/sharder.py @@ -27,8 +27,8 @@ class ModelSharder(object): def __init__(self, model: nn.Module, policy: Policy, shard_config: ShardConfig = None) -> None: self.model = model - self.policy = get_autopolicy(self.model, shard_config.inference_only) if policy is None else policy self.shard_config = shard_config + self.policy = get_autopolicy(self.model, shard_config) if policy is None else policy def shard(self) -> List[Dict[int, Tensor]]: r""" @@ -196,7 +196,7 @@ class ModelSharder(object): try: replace_layer = target_module.from_native_module( - native_sub_module, self.shard_config.tensor_parallel_process_group, **kwargs + native_sub_module, process_group=self.shard_config.tensor_parallel_process_group, **kwargs ) except Exception as e: raise RuntimeError( diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index 932053dd1..d61082bed 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -700,7 +700,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper): ############################ # this method is used to sync gradient manually - def sync_grad(self): + def _sync_grad(self): for group_id in range(self.num_param_groups): param_group = self._working_param_groups[group_id] for param in param_group: @@ -713,7 +713,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper): # if not overlapping communication (no reduction hook is attached) when zero1 # we need to manually reduce these gradients if not partition_grad and not self._overlap_communication: - self.sync_grad() + self._sync_grad() else: self._run_reduction() diff --git a/tests/test_infer/test_pipeline_infer.py b/tests/test_infer/test_pipeline_infer.py index 3544153da..1cf38c1ec 100644 --- a/tests/test_infer/test_pipeline_infer.py +++ b/tests/test_infer/test_pipeline_infer.py @@ -24,7 +24,6 @@ for k, v in inputs.items(): new_shape[0] = 16 inputs[k] = v.to("cuda").repeat(*new_shape) - def pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size): model = transformers.LlamaForCausalLM( transformers.LlamaConfig( @@ -59,6 +58,7 @@ def run_pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_si @parameterize("pp_size", [2]) @parameterize("max_output_len", [4]) @parameterize("micro_batch_size", [1]) + @clear_cache_before_run() def run_tp_pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size): pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size) @@ -76,6 +76,7 @@ 
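Passing process_group by keyword in ModelSharder matters because the replacement descriptions now carry extra keyword arguments, such as sp_partial_derived, that are forwarded verbatim to each wrapper's from_native_module. The class below is hypothetical and only illustrates the call shape the sharder expects; the method body is an assumption, not the library's implementation:

import torch.nn as nn
from torch.distributed import ProcessGroup

# Hypothetical wrapper mirroring the call made by ModelSharder:
#   target_module.from_native_module(native_sub_module, process_group=..., **kwargs)
class NormWrapperSketch:
    @classmethod
    def from_native_module(
        cls, module: nn.Module, process_group: ProcessGroup = None, sp_partial_derived: bool = False, **kwargs
    ) -> nn.Module:
        # A real wrapper would convert `module` and, when sp_partial_derived is
        # set, arrange for its gradients to be all-reduced across the TP group.
        # This sketch simply hands the module back unchanged.
        return module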
def check_tp_pipeline_inference(rank, world_size, port): @pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5") + @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() diff --git a/tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_amp_optimizer.py b/tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_amp_optimizer.py index 0192afc99..9e7336b93 100644 --- a/tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_amp_optimizer.py +++ b/tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_amp_optimizer.py @@ -128,7 +128,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "tp_size": 1, "pp_size": 2, "num_microbatches": 4, - "enable_all_optimization": False, + "enable_all_optimization": True, "use_lazy_init": True, "precision": "fp16", "max_norm": 5, @@ -137,7 +137,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, { "tp_size": 2, "pp_size": 1, - "enable_all_optimization": False, + "enable_all_optimization": True, "use_lazy_init": False, "precision": "fp16", "max_norm": 5, @@ -147,7 +147,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "tp_size": 2, "pp_size": 2, "num_microbatches": 4, - "enable_all_optimization": False, + "enable_all_optimization": True, "use_lazy_init": False, "precision": "fp16", "max_norm": 5, @@ -157,7 +157,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "tp_size": 1, "pp_size": 2, "num_microbatches": 4, - "enable_all_optimization": False, + "enable_all_optimization": True, "use_lazy_init": True, "precision": "bf16", "max_norm": 5, @@ -165,7 +165,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, { "tp_size": 2, "pp_size": 1, - "enable_all_optimization": False, + "enable_all_optimization": True, "use_lazy_init": False, "precision": "bf16", "max_norm": 5, @@ -174,7 +174,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "tp_size": 2, "pp_size": 2, "num_microbatches": 4, - "enable_all_optimization": False, + "enable_all_optimization": True, "use_lazy_init": False, "precision": "bf16", "max_norm": 5, @@ -199,7 +199,7 @@ def run_test(test_config): "tp_size": 2, "pp_size": 2, "num_microbatches": 4, - "enable_all_optimization": False, + "enable_all_optimization": True, "use_lazy_init": False, "precision": "bf16", "max_norm": 5, @@ -208,7 +208,7 @@ def run_test(test_config): "tp_size": 2, "pp_size": 2, "num_microbatches": 4, - "enable_all_optimization": False, + "enable_all_optimization": True, "use_lazy_init": False, "precision": "fp16", "max_norm": 5, diff --git a/tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_naive_optimizer.py b/tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_naive_optimizer.py index da298f5c0..b8ead795d 100644 --- a/tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_naive_optimizer.py +++ b/tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_naive_optimizer.py @@ -106,7 +106,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "tp_size": 1, "pp_size": 2, "num_microbatches": 4, - "enable_all_optimization": False, + "enable_all_optimization": True, "use_lazy_init": True, "precision": "fp32", "max_norm": 5, @@ -114,7 +114,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, { "tp_size": 2, "pp_size": 1, - 
"enable_all_optimization": False, + "enable_all_optimization": True, "use_lazy_init": False, "precision": "fp32", "max_norm": 5, @@ -123,7 +123,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "tp_size": 2, "pp_size": 2, "num_microbatches": 4, - "enable_all_optimization": False, + "enable_all_optimization": True, "use_lazy_init": False, "precision": "fp32", "max_norm": 5, @@ -148,7 +148,7 @@ def run_test(test_config): "tp_size": 2, "pp_size": 2, "num_microbatches": 4, - "enable_all_optimization": False, + "enable_all_optimization": True, "use_lazy_init": False, "precision": "fp32", "max_norm": 5, diff --git a/tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_zero_optimizer.py b/tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_zero_optimizer.py index f1ac1de1a..061c70255 100644 --- a/tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_zero_optimizer.py +++ b/tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_zero_optimizer.py @@ -106,7 +106,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "pp_size": 2, "num_microbatches": 4, "zero_stage": 1, - "enable_all_optimization": False, + "enable_all_optimization": True, "use_lazy_init": True, "precision": "fp16", "max_norm": 5, @@ -116,7 +116,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "tp_size": 2, "pp_size": 1, "zero_stage": 1, - "enable_all_optimization": False, + "enable_all_optimization": True, "use_lazy_init": False, "precision": "fp16", "max_norm": 5, @@ -126,7 +126,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "tp_size": 2, "pp_size": 1, "zero_stage": 2, - "enable_all_optimization": False, + "enable_all_optimization": True, "use_lazy_init": False, "precision": "fp16", "max_norm": 5, @@ -137,7 +137,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "pp_size": 2, "num_microbatches": 4, "zero_stage": 1, - "enable_all_optimization": False, + "enable_all_optimization": True, "use_lazy_init": True, "precision": "bf16", "max_norm": 5, @@ -146,7 +146,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "tp_size": 2, "pp_size": 1, "zero_stage": 1, - "enable_all_optimization": False, + "enable_all_optimization": True, "use_lazy_init": False, "precision": "bf16", "max_norm": 5, @@ -155,7 +155,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "tp_size": 2, "pp_size": 1, "zero_stage": 2, - "enable_all_optimization": False, + "enable_all_optimization": True, "use_lazy_init": False, "precision": "bf16", "max_norm": 5, @@ -181,7 +181,7 @@ def run_test(test_config): "pp_size": 2, "num_microbatches": 4, "zero_stage": 1, - "enable_all_optimization": False, + "enable_all_optimization": True, "use_lazy_init": False, "precision": "bf16", "max_norm": 5, @@ -191,7 +191,7 @@ def run_test(test_config): "pp_size": 2, "num_microbatches": 4, "zero_stage": 1, - "enable_all_optimization": False, + "enable_all_optimization": True, "use_lazy_init": False, "precision": "fp16", "max_norm": 5, diff --git a/tests/test_shardformer/test_model/test_shard_bert.py b/tests/test_shardformer/test_model/test_shard_bert.py index 31fd58d06..b38793b7c 100644 --- a/tests/test_shardformer/test_model/test_shard_bert.py +++ b/tests/test_shardformer/test_model/test_shard_bert.py @@ -34,6 +34,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, bert = 
unwrap_model(org_model, "BertModel", "bert") sharded_bert = unwrap_model(sharded_model, "BertModel", "bert") + norm_layer_for_check = ["encoder.layer[0].attention.output.LayerNorm", "embeddings.LayerNorm"] col_layer_for_check = ["encoder.layer[0].output.dense"] row_layer_for_check = ["embeddings.word_embeddings", "encoder.layer[0].intermediate.dense"] @@ -50,8 +51,21 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, row_layer_grads = get_grad_tensors_for_check( bert, sharded_bert, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False ) + + norm_layer_grads = get_grad_tensors_for_check( + bert, + sharded_bert, + norm_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=1, + verbose=False, + ) + grads_to_check.update(col_layer_grads) grads_to_check.update(row_layer_grads) + grads_to_check.update(norm_layer_grads) # optimizer executes step org_optimizer.step() @@ -85,6 +99,13 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, @parameterize( "test_config", [ + { + "tp_size": 2, + "pp_size": 1, + "enable_all_optimization": True, + "use_lazy_init": True, + "precision": "fp32", + }, { "tp_size": 1, "pp_size": 2, diff --git a/tests/test_shardformer/test_model/test_shard_bloom.py b/tests/test_shardformer/test_model/test_shard_bloom.py index 7fe791db6..b70cba8b4 100644 --- a/tests/test_shardformer/test_model/test_shard_bloom.py +++ b/tests/test_shardformer/test_model/test_shard_bloom.py @@ -35,6 +35,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, bloom = unwrap_model(org_model, "BloomModel", "transformer") sharded_bloom = unwrap_model(sharded_model, "BloomModel", "transformer") + norm_layer_for_check = ["word_embeddings_layernorm", "h[0].input_layernorm"] row_layer_for_check = ["h[0].self_attention.query_key_value", "word_embeddings"] col_layer_for_check = ["h[0].self_attention.dense"] @@ -51,8 +52,21 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, col_layer_grads = get_grad_tensors_for_check( bloom, sharded_bloom, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False ) + + norm_layer_grads = get_grad_tensors_for_check( + bloom, + sharded_bloom, + norm_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=1, + verbose=False, + ) + grads_to_check.update(col_layer_grads) grads_to_check.update(row_layer_grads) + grads_to_check.update(norm_layer_grads) # optimizer executes step org_optimizer.step() diff --git a/tests/test_shardformer/test_model/test_shard_chatglm2.py b/tests/test_shardformer/test_model/test_shard_chatglm2.py index bdf5b79fc..29d3592bf 100644 --- a/tests/test_shardformer/test_model/test_shard_chatglm2.py +++ b/tests/test_shardformer/test_model/test_shard_chatglm2.py @@ -35,6 +35,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, chatglm_model = unwrap_model(org_model, "ChatGLMModel", "transformer") shard_chatglm_model = unwrap_model(sharded_model, "ChatGLMModel", "transformer") + norm_layer_for_check = ["encoder.layers[0].input_layernorm"] row_layer_for_check = ["encoder.layers[0].self_attention.query_key_value", "embedding.word_embeddings"] col_layer_for_check = ["encoder.layers[0].self_attention.dense"] @@ -66,8 +67,21 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, dim=1, verbose=False, ) + + norm_layer_grads = get_grad_tensors_for_check( + chatglm_model, + shard_chatglm_model, + norm_layer_for_check, + tp_group, + atol=atol, + 
rtol=rtol, + dim=1, + verbose=False, + ) + grads_to_check.update(col_layer_grads) grads_to_check.update(row_layer_grads) + grads_to_check.update(norm_layer_grads) # optimizer executes step org_optimizer.step() diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py index 69a15166a..66b30641a 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2.py @@ -35,6 +35,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, gpt2 = unwrap_model(org_model, "GPT2Model", "transformer") sharded_gpt2 = unwrap_model(sharded_model, "GPT2Model", "transformer") + norm_layer_for_check = ["h[0].ln_1", "h[0].ln_2"] col_layer_for_check = ["h[0].mlp.c_fc"] row_layer_for_check = ["wte", "h[0].mlp.c_proj"] @@ -51,8 +52,21 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, row_layer_grads = get_grad_tensors_for_check( gpt2, sharded_gpt2, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False ) + + norm_layer_grads = get_grad_tensors_for_check( + gpt2, + sharded_gpt2, + norm_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=1, + verbose=False, + ) + grads_to_check.update(col_layer_grads) grads_to_check.update(row_layer_grads) + grads_to_check.update(norm_layer_grads) # optimizer executes step org_optimizer.step()
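The norm_layer_for_check entries added to these tests give direct coverage of the layer-norm gradient all-reduce: after the backward pass, the layer norm gradients of the sharded model must match the unsharded reference within the usual tolerances. A simplified stand-in for that comparison, assuming dotted parameter names (the real tests collect the tensors through get_grad_tensors_for_check and assert them together after optimizer.step()):

import torch

# Simplified check of a single layer norm's weight gradient between the
# reference model and the sharded model; illustrative only.
def check_norm_grad(org_model, sharded_model, name: str, atol: float = 1e-5, rtol: float = 1e-3) -> None:
    org_grad = dict(org_model.named_parameters())[name + ".weight"].grad
    shard_grad = dict(sharded_model.named_parameters())[name + ".weight"].grad
    assert torch.allclose(org_grad, shard_grad, rtol=rtol, atol=atol), f"layer norm grad mismatch at {name}"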