diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py index c43611f..3e43b3a 100644 --- a/internlm/solver/optimizer/hybrid_zero_optim.py +++ b/internlm/solver/optimizer/hybrid_zero_optim.py @@ -574,7 +574,10 @@ class HybridZeroOptimizer(BaseOptimizer): # compute norm for gradients in the before bucket groups_norms = [] for group_id in range(self.num_param_groups): - groups_norms.append(self._compute_norm_with_stage(group_id=group_id)) + if self._is_moe_group(self.optim.param_groups[group_id]): + groups_norms.append([]) + else: + groups_norms.append(self._compute_norm_with_stage(group_id=group_id)) # clear reduced grads if self._overlap_sync_grad: