From 739a308c825c028e9922ecc2cca044fbc1c4e5d7 Mon Sep 17 00:00:00 2001 From: Qu Wenwen Date: Fri, 27 Oct 2023 15:10:16 +0800 Subject: [PATCH] fix merged error --- internlm/initialize/launch.py | 10 +++--- .../solver/optimizer/hybrid_zero_optim.py | 27 ---------------- internlm/solver/optimizer/utils.py | 32 +++++++++++++------ internlm/train/utils.py | 2 +- 4 files changed, 28 insertions(+), 43 deletions(-) diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py index b7c9199..ad404f2 100644 --- a/internlm/initialize/launch.py +++ b/internlm/initialize/launch.py @@ -293,10 +293,6 @@ def args_sanity_check(): model._add_item("moe_use_residual", False) if "moe_gate_k" not in model: model._add_item("moe_gate_k", 2) - assert not ( - gpc.config.model.num_experts > 1 and gpc.config.parallel.zero1.fsdp - ), "FSDP does not support num_experts > 1" - # process the parallel config if "sequence_parallel" not in gpc.config.parallel: gpc.config.parallel._add_item("sequence_parallel", False) @@ -345,11 +341,13 @@ def args_sanity_check(): gpc.config.loss._add_item("moe_loss_coeff", 1.0) # moe not support overlap and zero1.5 for now - if hasattr(gpc.config.model, "num_experts"): + if gpc.config.model.get("num_experts", 1) > 1: + assert not gpc.config.parallel.zero1.fsdp, "FSDP does not support num_experts > 1" assert ( not optim_ckpt.overlap_sync_grad & optim_ckpt.overlap_sync_param ), "not support overlap and moe at the same time" assert gpc.config.parallel.zero1.size == -1, "moe only support zero1, set zero1=dict(size=-1,...) can fix this" + assert not gpc.config.parallel.sequence_parallel, "moe not support sequence parallel for now" def launch( @@ -413,7 +411,7 @@ def launch( f"data parallel size: {gpc.data_parallel_size}, pipeline parallel size: {gpc.pipeline_parallel_size}, " f"tensor parallel size: {gpc.tensor_parallel_size}", ) - if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1: + if gpc.config.model.get("num_experts", 1) > 1: logger.info( f"Creating MoE with num_experts: {gpc.config.model.num_experts} | " f"expert parallel size: {gpc.expert_parallel_size} | " diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py index 2dc2983..ce87566 100644 --- a/internlm/solver/optimizer/hybrid_zero_optim.py +++ b/internlm/solver/optimizer/hybrid_zero_optim.py @@ -560,7 +560,6 @@ class HybridZeroOptimizer(BaseOptimizer): last_stage=last_stage, previous_param_norms=previous_param_norms, zero_mode=self._broadcast_parallel_mode[group_id], - is_moe_group=self._is_moe_group(self.optim.param_groups[group_id]), ) return total_param_norms @@ -630,16 +629,6 @@ class HybridZeroOptimizer(BaseOptimizer): param_norms=param_norms, loss_scale=self.loss_scale.item() ) - # Need to allreduce(avg) the norms across different ranks because moe params will not be synced - # during allreduce - if self._is_moe_group(self.optim.param_groups[group_id]): - # model and zero have been reduced!!! - pg = gpc.get_group(ParallelMode.EXPERT) - scaled_norm = total_norms[group_name] * 1.0 / float(gpc.get_world_size(ParallelMode.EXPERT)) - scaled_norm_tensor = torch.tensor(scaled_norm, device=get_current_device(), dtype=torch.float) - dist.all_reduce(scaled_norm_tensor, group=pg) - total_norms[group_name] = scaled_norm_tensor.item() - timer("sync_grad").start() self._sync_grad() timer("sync_grad").stop() @@ -719,19 +708,6 @@ class HybridZeroOptimizer(BaseOptimizer): param_shape == flat_fp32_avg_grads.shape ), f"fp32 param and grad have different shape {param_shape} vs {flat_fp32_avg_grads.shape}" - # Parameters shared within a TP group, such as norm and moe gate, have precision inconsistency in gradients. - # Therefore, it is recommended to synchronize gradients within the TP group to eliminate accumulated errors. - # is_tp_sync_groups = ( - # self._is_norm_group(self.optim.param_groups[group_id]), - # self._is_gate_group(self.optim.param_groups[group_id]), - # ) - # if any(is_tp_sync_groups): - # dist.all_reduce( - # flat_fp32_avg_grads, - # op=dist.ReduceOp.AVG, - # group=gpc.get_group(ParallelMode.TENSOR), - # ) - single_grad_partition_groups.append(flat_fp32_avg_grads) device = self._fp32_flat_param_groups_of_current_rank[group_id].device self._fp32_flat_param_groups_of_current_rank[group_id].grad = flat_fp32_avg_grads.to(device) @@ -774,9 +750,6 @@ class HybridZeroOptimizer(BaseOptimizer): with torch.cuda.stream(self._comm_bcast_stream): self.broadcast_params() - if not self._overlap_sync_param: - torch.cuda.synchronize() - timer("step").stop() # update gradients may not be needed here, because the sync_params function is used in initialization, diff --git a/internlm/solver/optimizer/utils.py b/internlm/solver/optimizer/utils.py index 982a246..2f7e21f 100644 --- a/internlm/solver/optimizer/utils.py +++ b/internlm/solver/optimizer/utils.py @@ -252,8 +252,23 @@ def reduce_grads(gradients, parameters, fine_grained=False): return parallel_grads +def reduce_moe_norm(total_norm): + pg = gpc.get_group(ParallelMode.EXPERT) + scaled_norm = total_norm * 1.0 / float(gpc.get_world_size(ParallelMode.DATA)) + scaled_norm_tensor = torch.tensor(scaled_norm, device=get_current_device(), dtype=torch.float) + dist.all_reduce(scaled_norm_tensor, group=pg) + total_norm = scaled_norm_tensor.item() + + return total_norm + + def compute_norm( - gradients, parameters, last_stage=False, previous_norm=None, norm_type=2, zero_mode=ParallelMode.ZERO1 + gradients, + parameters, + last_stage=False, + previous_norm=None, + norm_type=2, + zero_mode=ParallelMode.ZERO1, ): """Get the norm Arguments: @@ -326,6 +341,11 @@ def compute_norm( if torch.is_tensor(total_norm): total_norm = total_norm.item() + # Need to allreduce(avg) the norms across different ranks because moe params will not be synced during allreduce + # model and zero have been reduced!!! + if zero_mode == ParallelMode.EXPERT_DATA: + total_norm = reduce_moe_norm(total_norm) + # Scale. if total_norm == float("inf") or total_norm == -float("inf"): total_norm = -1 @@ -343,7 +363,6 @@ def compute_param_norm( previous_param_norms=None, norm_type=2, zero_mode=ParallelMode.ZERO1, - is_moe_group=False, ): """Get the norm of params Arguments: @@ -413,13 +432,8 @@ def compute_param_norm( total_param_norms[param_name] += param_norm # moe - if is_moe_group: - pg = gpc.get_group(ParallelMode.EXPERT) - scaled_param_norm = torch.cuda.FloatTensor(list(total_param_norms.values()), device=get_current_device()) - scaled_param_norm = scaled_param_norm / float(gpc.get_world_size(ParallelMode.EXPERT)) - dist.all_reduce(scaled_param_norm, group=pg) - for i, param_name in enumerate(total_param_norms.keys()): - total_param_norms[param_name] = scaled_param_norm[i].item() + if zero_mode == ParallelMode.EXPERT_DATA: + total_param_norms = reduce_moe_norm(total_param_norms) # scale for param_name, param_norm in total_param_norms.items(): diff --git a/internlm/train/utils.py b/internlm/train/utils.py index ff59597..d4c2203 100644 --- a/internlm/train/utils.py +++ b/internlm/train/utils.py @@ -39,7 +39,7 @@ def split_params_into_different_groups_for_optimizer(param_groups: Tuple[Dict]) # create new groups for fp32, norm, moe gate and moe expert new_groups = {} new_groups["fp32"] = {"name": "fp32", "params": [], "dp_mode": ParallelMode.DATA} - if gpc.config.model.get("num_experts", 0) > 1: + if gpc.config.model.get("num_experts", 1) > 1: for key in gpc.expert_parallel_group_names: new_groups[key] = {"name": key, "moe": True, "params": [], "dp_mode": ParallelMode.EXPERT_DATA}