From 08532dc20becec34ddc9b95c2645f5211ac45682 Mon Sep 17 00:00:00 2001
From: Wenwen Qu
Date: Thu, 17 Aug 2023 17:07:31 +0800
Subject: [PATCH] fix bugs with merge

---
 .gitignore                                |  1 +
 .../solver/optimizer/hybrid_zero_optim.py | 24 ++++++++++---------
 2 files changed, 14 insertions(+), 11 deletions(-)

diff --git a/.gitignore b/.gitignore
index 8992a0f..055e7ad 100644
--- a/.gitignore
+++ b/.gitignore
@@ -132,6 +132,7 @@ runs_bak/
 LLM_ALERT
 small_demo/
 7b_llama_nopp/
+test/
 
 # Pytorch
 *.pth
diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py
index 1b1995c..0d5ce4b 100644
--- a/internlm/solver/optimizer/hybrid_zero_optim.py
+++ b/internlm/solver/optimizer/hybrid_zero_optim.py
@@ -276,12 +276,12 @@ class HybridZeroOptimizer(BaseOptimizer):
         self.params_per_rank_id_dict.append([[] for _ in range(self._zero_world_size)])
         param_list = param_group["params"]
 
-        if "moe" in param_group.keys() and param_group["moe"]:
+        if self._is_moe_group(param_group):
             # just add current params to params_per_rank[_zero_local_rank]
             params_per_rank[self._zero_local_rank] = list(param_list)
             self.params_per_rank_id_dict[-1][self._zero_local_rank].append(None)
             no_params_ranks = list(range(self._zero_world_size))
-            no_params_ranks.pop(self._zero_world_size)
+            no_params_ranks.pop(self._zero_local_rank)
 
         else:
             sorted_params = sorted(param_list, key=lambda x: x.numel(), reverse=True)
@@ -307,6 +307,9 @@ class HybridZeroOptimizer(BaseOptimizer):
 
         return params_per_rank, set(no_params_ranks)
 
+    def _is_moe_group(self, param_group):
+        return "moe" in param_group.keys() and param_group["moe"]
+
     def _attach_reduction_hook(self):
         # we iterate over the fp16 params
         # on each param, we register a hook to its AccumulateGrad object
@@ -568,11 +571,11 @@ class HybridZeroOptimizer(BaseOptimizer):
 
         return self._step(closure=closure, norms=total_norms)
 
-    def _get_norm_with_moe_layers(self, norm_groups):
+    def _get_norm_with_moe_layers(self, norm):
         # all_groups_norm_old = all_groups_norm
         # Need to allreduce(avg) the norms across different ranks because moe params will not be synced during allreduce
         pg = gpc.get_group(ParallelMode.DATA)
-        scaled_norm = norm_groups * 1.0 / float(gpc.get_world_size(ParallelMode.DATA))
+        scaled_norm = norm * 1.0 / float(gpc.get_world_size(ParallelMode.DATA))
         scaled_norm_tensor = torch.tensor(
             scaled_norm, device=self._fp32_flat_param_groups_of_current_rank[0].device, dtype=torch.float
         )
@@ -593,7 +596,6 @@ class HybridZeroOptimizer(BaseOptimizer):
 
         if -1 in norms:
             found_inf = True
-        loss_scale = float(self.loss_scale.item())  # backup
 
         if gpc.config.model.dtype is not torch.float32:
             self.grad_scaler.update(found_inf)
@@ -638,17 +640,15 @@ class HybridZeroOptimizer(BaseOptimizer):
         # get the global norm
         global_norm_groups = []
         if self._clip_grad_norm > 0:
-            for norm in norms:
-                global_norm_groups.append(norm**0.5)
-
-            if self.has_moe:
-                global_norm = self._get_norm_with_moe_layers(global_norm)
+            for group_id in range(self.num_param_groups):
+                if self._is_moe_group(self.optim.param_groups[group_id]):
+                    self._get_norm_with_moe_layers(norms[group_id])
+                global_norm_groups.append(norms[group_id] ** 0.5)
 
         # the following operations are performed only on the rank to which parameters are assigned.
         if gpc.config.model.dtype is not torch.float32:
             if len(single_grad_partition_groups) != 0:
                 self._unscale_and_clip_grads(single_grad_partition_groups, global_norm_groups, loss_scale)
-
         # update the parameters
         timer("step").start()
 
@@ -679,6 +679,8 @@ class HybridZeroOptimizer(BaseOptimizer):
 
         handles = []
         for group_id in range(self.num_param_groups):
+            if self._is_moe_group(self.optim.param_groups[group_id]):
+                continue
             for rank in range(self._zero_world_size):
                 # The following operations are performed only on the rank to which parameters are assigned.
                 if rank not in self.param_group_no_params_ranks[group_id]:
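Note (reviewer sketch, not part of the patch): the reworked loop in `_step` handles each parameter group separately and only averages the norm of MoE groups across the data-parallel ranks, since expert gradients are not synchronized by the data-parallel all-reduce. Below is a minimal standalone sketch of that intent; the helper names `is_moe_group`, `average_over_data_parallel`, and `global_norms_per_group` are hypothetical stand-ins, and the sketch assumes the averaged MoE norm feeds back into the group norm before the square root.

# Illustrative sketch only; names below are hypothetical, not the optimizer's API.
def is_moe_group(param_group):
    # Mirrors the patch's helper: a group is an MoE group if it carries a truthy "moe" flag.
    return "moe" in param_group.keys() and param_group["moe"]

def average_over_data_parallel(norm, data_parallel_world_size):
    # Stand-in for the all-reduce(avg) across the data-parallel group:
    # expert gradients are not synced there, so their norms are averaged explicitly.
    return norm / float(data_parallel_world_size)

def global_norms_per_group(param_groups, squared_norms, data_parallel_world_size):
    # squared_norms[i] is the sum of squared gradient elements of group i.
    global_norm_groups = []
    for group_id, group in enumerate(param_groups):
        norm = squared_norms[group_id]
        if is_moe_group(group):
            norm = average_over_data_parallel(norm, data_parallel_world_size)
        global_norm_groups.append(norm ** 0.5)
    return global_norm_groups

if __name__ == "__main__":
    groups = [{"params": [], "moe": False}, {"params": [], "moe": True}]
    # Dense group keeps its norm; MoE group's squared norm is averaged over 4 ranks.
    print(global_norms_per_group(groups, [4.0, 16.0], data_parallel_world_size=4))  # -> [2.0, 2.0]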