diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py
index 0d5ce4b..dfe0a4a 100644
--- a/internlm/solver/optimizer/hybrid_zero_optim.py
+++ b/internlm/solver/optimizer/hybrid_zero_optim.py
@@ -522,6 +522,19 @@ class HybridZeroOptimizer(BaseOptimizer):
 
         return norm
 
+    def _compute_norm_with_moe_group(self, group_id):
+        parameters = self._param_store.get_fp16_params_by_rank_group(group_id=group_id, rank=self._zero_local_rank)
+        # We do not get the averaged gradient for MoE parameters, so we have to construct
+        # the gradients list here. Maybe this can be optimized.
+        gradients = [p.grad for p in parameters]
+        norm = compute_norm(
+            gradients=gradients,
+            parameters=parameters,
+            last_stage=True,
+        )
+
+        return norm
+
     def step(self, closure=None):
         """Performs a single optimization step.
 
@@ -559,12 +572,14 @@ class HybridZeroOptimizer(BaseOptimizer):
         # compute norm for gradients in the last bucket
         total_norms = []
         for group_id in range(self.num_param_groups):
-            total_norms.append(
-                self._compute_norm_with_stage(
-                    group_id=group_id, last_bucket=True, last_stage=True, previous_norm=groups_norms[group_id]
+            if self._is_moe_group(self.optim.param_groups[group_id]):
+                total_norms.append(self._compute_norm_with_moe_group(group_id=group_id))
+            else:
+                total_norms.append(
+                    self._compute_norm_with_stage(
+                        group_id=group_id, last_bucket=True, last_stage=True, previous_norm=groups_norms[group_id]
+                    )
                 )
-            )
-
         timer("sync_grad").start()
         self._sync_grad()
         timer("sync_grad").stop()
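
For reviewers: the MoE branch builds the gradient list directly from `p.grad` because those gradients are not part of the averaged flat buckets handled by `_compute_norm_with_stage`. The snippet below is a minimal, self-contained sketch of the kind of work delegated to `compute_norm` here, i.e. a global 2-norm over per-parameter gradients with a reduction across the ranks sharding the group. The function name `sketch_grad_norm`, the `process_group` argument, and the use of `torch.distributed.all_reduce` are illustrative assumptions, not InternLM's actual `compute_norm` implementation.

```python
# Hypothetical sketch only; not InternLM's compute_norm.
import torch
import torch.distributed as dist


def sketch_grad_norm(parameters, process_group=None):
    # Take gradients straight from the parameters, as the MoE branch does,
    # since they were never accumulated into averaged flat buckets.
    grads = [p.grad.detach() for p in parameters if p.grad is not None]
    device = grads[0].device if grads else torch.device("cpu")

    # Sum of squared 2-norms of each gradient tensor on this rank.
    local_sq = torch.zeros(1, dtype=torch.float32, device=device)
    for g in grads:
        local_sq += g.float().norm(2) ** 2

    # Combine partial sums from every rank holding a shard of this group.
    if process_group is not None and dist.is_initialized():
        dist.all_reduce(local_sq, op=dist.ReduceOp.SUM, group=process_group)

    return local_sq.sqrt().item()
```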