From 21624f6f811f73069ecde86116e19e9ad161616d Mon Sep 17 00:00:00 2001 From: Wenwen Qu Date: Wed, 1 Nov 2023 11:29:55 +0800 Subject: [PATCH] fix(moe): remove norm&gate force sync (#448) * add zero broadcast_sync * delete old sync logic * fix merged error * refactor code * remove some unused function (is norm/gate group) --- internlm/initialize/launch.py | 10 +++--- internlm/model/modeling_moe.py | 5 --- internlm/model/utils.py | 12 ------- internlm/moe/sharded_moe.py | 3 -- .../solver/optimizer/hybrid_zero_optim.py | 31 ------------------- internlm/solver/optimizer/utils.py | 23 +++++++++----- internlm/train/utils.py | 14 ++------- 7 files changed, 23 insertions(+), 75 deletions(-) diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py index b7c9199..ad404f2 100644 --- a/internlm/initialize/launch.py +++ b/internlm/initialize/launch.py @@ -293,10 +293,6 @@ def args_sanity_check(): model._add_item("moe_use_residual", False) if "moe_gate_k" not in model: model._add_item("moe_gate_k", 2) - assert not ( - gpc.config.model.num_experts > 1 and gpc.config.parallel.zero1.fsdp - ), "FSDP does not support num_experts > 1" - # process the parallel config if "sequence_parallel" not in gpc.config.parallel: gpc.config.parallel._add_item("sequence_parallel", False) @@ -345,11 +341,13 @@ def args_sanity_check(): gpc.config.loss._add_item("moe_loss_coeff", 1.0) # moe not support overlap and zero1.5 for now - if hasattr(gpc.config.model, "num_experts"): + if gpc.config.model.get("num_experts", 1) > 1: + assert not gpc.config.parallel.zero1.fsdp, "FSDP does not support num_experts > 1" assert ( not optim_ckpt.overlap_sync_grad & optim_ckpt.overlap_sync_param ), "not support overlap and moe at the same time" assert gpc.config.parallel.zero1.size == -1, "moe only support zero1, set zero1=dict(size=-1,...) 
can fix this" + assert not gpc.config.parallel.sequence_parallel, "moe not support sequence parallel for now" def launch( @@ -413,7 +411,7 @@ def launch( f"data parallel size: {gpc.data_parallel_size}, pipeline parallel size: {gpc.pipeline_parallel_size}, " f"tensor parallel size: {gpc.tensor_parallel_size}", ) - if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1: + if gpc.config.model.get("num_experts", 1) > 1: logger.info( f"Creating MoE with num_experts: {gpc.config.model.num_experts} | " f"expert parallel size: {gpc.expert_parallel_size} | " diff --git a/internlm/model/modeling_moe.py b/internlm/model/modeling_moe.py index 43489bc..df6c7a8 100644 --- a/internlm/model/modeling_moe.py +++ b/internlm/model/modeling_moe.py @@ -127,11 +127,6 @@ class PackedFlashBaseLayer1D(nn.Module): else: self.norm1 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon) self.norm2 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon) - - for param in self.norm1.parameters(): - param.is_norm = True - for param in self.norm2.parameters(): - param.is_norm = True set_fp32_attr_to_module(self.norm1) set_fp32_attr_to_module(self.norm2) diff --git a/internlm/model/utils.py b/internlm/model/utils.py index 570a86f..46fba59 100644 --- a/internlm/model/utils.py +++ b/internlm/model/utils.py @@ -215,18 +215,6 @@ def is_moe_param(param: torch.Tensor) -> bool: return False -def is_gate_param(param: torch.Tensor) -> bool: - if hasattr(param, "is_gate") and param.is_gate: - return True - return False - - -def is_norm_param(param: torch.Tensor) -> bool: - if hasattr(param, "is_norm") and param.is_norm: - return True - return False - - def Silu(w1_o, w2_o): return F.silu(w1_o) * w2_o diff --git a/internlm/moe/sharded_moe.py b/internlm/moe/sharded_moe.py index dbee2a4..5d695ac 100644 --- a/internlm/moe/sharded_moe.py +++ b/internlm/moe/sharded_moe.py @@ -352,9 +352,6 @@ class TopKGate(Module): self.drop_tokens = drop_tokens self.use_rts = use_rts - for param in self.wg.parameters(): - param.is_gate = True - def forward( self, inputs: torch.Tensor, used_token: torch.Tensor = None ) -> Tuple[Tensor, Tensor, Tensor]: # type: ignore diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py index 2901f81..d1edb4f 100644 --- a/internlm/solver/optimizer/hybrid_zero_optim.py +++ b/internlm/solver/optimizer/hybrid_zero_optim.py @@ -274,12 +274,6 @@ class HybridZeroOptimizer(BaseOptimizer): def _is_moe_group(self, param_group): return "moe" in param_group.keys() and param_group["moe"] - def _is_norm_group(self, param_group): - return "norm" in param_group.keys() and param_group["norm"] - - def _is_gate_group(self, param_group): - return "gate" in param_group.keys() and param_group["gate"] - # TODO check expert dp is correct when enable moe and overlap both def _attach_reduction_hook(self): # we iterate over the fp16 params @@ -566,7 +560,6 @@ class HybridZeroOptimizer(BaseOptimizer): last_stage=last_stage, previous_param_norms=previous_param_norms, zero_mode=self._broadcast_parallel_mode[group_id], - is_moe_group=self._is_moe_group(self.optim.param_groups[group_id]), ) return total_param_norms @@ -589,7 +582,6 @@ class HybridZeroOptimizer(BaseOptimizer): last_stage=last_stage, previous_zero_grad_count=previous_zero_grad_count, zero_mode=self._broadcast_parallel_mode[group_id], - is_moe_group=self._is_moe_group(self.optim.param_groups[group_id]), ) return total_zero_grad_count @@ -675,16 +667,6 @@ class HybridZeroOptimizer(BaseOptimizer): 
total_param_zero_grad_count[group_name], ) = compute_layer_zero_grad_count(zero_grad_count) - # Need to allreduce(avg) the norms across different ranks because moe params will not be synced - # during allreduce - if self._is_moe_group(self.optim.param_groups[group_id]): - # model and zero have been reduced!!! - pg = gpc.get_group(ParallelMode.EXPERT) - scaled_norm = total_norms[group_name] * 1.0 / float(gpc.get_world_size(ParallelMode.EXPERT)) - scaled_norm_tensor = torch.tensor(scaled_norm, device=get_current_device(), dtype=torch.float) - dist.all_reduce(scaled_norm_tensor, group=pg) - total_norms[group_name] = scaled_norm_tensor.item() - timer("sync_grad").start() self._sync_grad() timer("sync_grad").stop() @@ -767,19 +749,6 @@ class HybridZeroOptimizer(BaseOptimizer): param_shape == flat_fp32_avg_grads.shape ), f"fp32 param and grad have different shape {param_shape} vs {flat_fp32_avg_grads.shape}" - # Parameters shared within a TP group, such as norm and moe gate, have precision inconsistency in gradients. - # Therefore, it is recommended to synchronize gradients within the TP group to eliminate accumulated errors. - is_tp_sync_groups = ( - self._is_norm_group(self.optim.param_groups[group_id]), - self._is_gate_group(self.optim.param_groups[group_id]), - ) - if any(is_tp_sync_groups): - dist.all_reduce( - flat_fp32_avg_grads, - op=dist.ReduceOp.AVG, - group=gpc.get_group(ParallelMode.TENSOR), - ) - single_grad_partition_groups.append(flat_fp32_avg_grads) device = self._fp32_flat_param_groups_of_current_rank[group_id].device self._fp32_flat_param_groups_of_current_rank[group_id].grad = flat_fp32_avg_grads.to(device) diff --git a/internlm/solver/optimizer/utils.py b/internlm/solver/optimizer/utils.py index 2fb8f57..0cc7451 100644 --- a/internlm/solver/optimizer/utils.py +++ b/internlm/solver/optimizer/utils.py @@ -262,7 +262,12 @@ def reduce_grads(gradients, parameters, fine_grained=False): def compute_norm( - gradients, parameters, last_stage=False, previous_norm=None, norm_type=2, zero_mode=ParallelMode.ZERO1 + gradients, + parameters, + last_stage=False, + previous_norm=None, + norm_type=2, + zero_mode=ParallelMode.ZERO1, ): """Get the norm Arguments: @@ -335,6 +340,15 @@ def compute_norm( if torch.is_tensor(total_norm): total_norm = total_norm.item() + # Need to allreduce(avg) the norms across different ranks because moe params will not be synced during allreduce + # model and zero have been reduced!!! + if zero_mode == ParallelMode.EXPERT_DATA: + pg = gpc.get_group(ParallelMode.EXPERT) + scaled_norm = total_norm * 1.0 / float(gpc.get_world_size(ParallelMode.DATA)) + scaled_norm_tensor = torch.tensor(scaled_norm, device=get_current_device(), dtype=torch.float) + dist.all_reduce(scaled_norm_tensor, group=pg) + total_norm = scaled_norm_tensor.item() + # Scale. 
if total_norm == float("inf") or total_norm == -float("inf"): total_norm = -1 @@ -353,7 +367,6 @@ def compute_param_metric( previous_param_metrics=None, norm_type=2, zero_mode=ParallelMode.ZERO1, - is_moe_group=False, ): """Get the metrics of params Argumemts: @@ -424,7 +437,7 @@ def compute_param_metric( total_metrics[param_name] += param_metric # moe - if is_moe_group: + if zero_mode == ParallelMode.EXPERT_DATA: pg = gpc.get_group(ParallelMode.EXPERT) total_metric_values = list(total_metrics.values()) if isinstance(total_metric_values[0], torch.Tensor): @@ -463,7 +476,6 @@ def compute_param_norm( previous_param_norms=None, norm_type=2, zero_mode=ParallelMode.ZERO1, - is_moe_group=False, ): """Get the norm of params Arguments: @@ -484,7 +496,6 @@ def compute_param_norm( previous_param_metrics=previous_param_norms, norm_type=norm_type, zero_mode=zero_mode, - is_moe_group=is_moe_group, ) @@ -494,7 +505,6 @@ def compute_zero_grad_count( last_stage=False, previous_zero_grad_count=None, zero_mode=ParallelMode.ZERO1, - is_moe_group=False, ): """Get the count of zero gradient for each parameters Arguments: @@ -512,7 +522,6 @@ def compute_zero_grad_count( last_stage=last_stage, previous_param_metrics=previous_zero_grad_count, zero_mode=zero_mode, - is_moe_group=is_moe_group, ) diff --git a/internlm/train/utils.py b/internlm/train/utils.py index 9096a2a..d4c2203 100644 --- a/internlm/train/utils.py +++ b/internlm/train/utils.py @@ -4,7 +4,7 @@ import torch from internlm.core.context.parallel_context import ParallelMode from internlm.core.context.parallel_context import global_context as gpc -from internlm.model.utils import is_gate_param, is_moe_param, is_norm_param +from internlm.model.utils import is_moe_param def split_params_into_different_groups_for_optimizer(param_groups: Tuple[Dict]) -> Tuple[Dict]: @@ -39,10 +39,7 @@ def split_params_into_different_groups_for_optimizer(param_groups: Tuple[Dict]) # create new groups for fp32, norm, moe gate and moe expert new_groups = {} new_groups["fp32"] = {"name": "fp32", "params": [], "dp_mode": ParallelMode.DATA} - if gpc.config.model.get("num_experts", 0) > 1: - # norm and gate are special group to force sync (when enable MoE). - for key in ["gate", "norm"]: - new_groups[key] = {"name": key, key: True, "params": [], "dp_mode": ParallelMode.DATA} + if gpc.config.model.get("num_experts", 1) > 1: for key in gpc.expert_parallel_group_names: new_groups[key] = {"name": key, "moe": True, "params": [], "dp_mode": ParallelMode.EXPERT_DATA} @@ -58,12 +55,7 @@ def split_params_into_different_groups_for_optimizer(param_groups: Tuple[Dict]) # first split the norm and gate groups, which are special case to force sync (when enable MoE), # then fp32 group and the moe group. for param in pgroup["params"]: - if gpc.config.model.get("num_experts", 0) > 1 and is_norm_param(param): - new_groups["norm"]["params"].append(param) - # gate param means MoE is enabled - elif is_gate_param(param): - new_groups["gate"]["params"].append(param) - elif param.dtype == torch.float32: + if param.dtype == torch.float32: new_groups["fp32"]["params"].append(param) # moe param means MoE is enabled elif is_moe_param(param):
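
For readers of this change, the behavior that replaces the removed force-sync logic is the averaging now done inside compute_norm() for expert-data parameter groups. The sketch below is illustrative only and assumes an already initialized torch.distributed setup; allreduce_avg_norm, expert_group, and scale_world_size are made-up names standing in for the real InternLM group handles, and the divisor simply mirrors the hunk above (the DATA-mode world size). Pre-dividing and then summing is equivalent to an AVG all-reduce of the scalar norm whenever the divisor equals the size of the group being reduced over.

    import torch
    import torch.distributed as dist

    def allreduce_avg_norm(total_norm: float, expert_group, scale_world_size: int) -> float:
        # Every rank holds a locally reduced norm; divide first, then SUM
        # across the expert process group, which yields the group average
        # when scale_world_size matches the group size.
        scaled = torch.tensor(total_norm / float(scale_world_size), dtype=torch.float32)
        dist.all_reduce(scaled, op=dist.ReduceOp.SUM, group=expert_group)
        return scaled.item()

The other half of the patch simplifies the optimizer param grouping: norm and gate weights are no longer split into dedicated force-synced groups, so only the fp32 group and the per-expert groups remain special. Below is a rough sketch of that rule under stated assumptions, not a copy of split_params_into_different_groups_for_optimizer; the is_expert and group_name attributes are hypothetical stand-ins for is_moe_param() and the expert group naming that this hunk does not show.

    import torch

    def split_groups(params, expert_group_names, num_experts):
        # fp32 params get their own group, expert params go to per-expert buckets,
        # and everything else (including norm and gate weights) stays in the
        # original data-parallel group instead of being forced into a sync group.
        groups = {"origin": [], "fp32": []}
        if num_experts > 1:
            groups.update({name: [] for name in expert_group_names})
        for p in params:
            if p.dtype == torch.float32:
                groups["fp32"].append(p)
            elif num_experts > 1 and getattr(p, "is_expert", False):  # stand-in for is_moe_param()
                groups[p.group_name].append(p)
            else:
                groups["origin"].append(p)
        return groups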