diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py
index dd82afc..7f858a8 100644
--- a/internlm/solver/optimizer/hybrid_zero_optim.py
+++ b/internlm/solver/optimizer/hybrid_zero_optim.py
@@ -503,8 +503,8 @@ class HybridZeroOptimizer(BaseOptimizer):
             tensors = []
             grads = []
             for _t, _g in zip(tensor, grad):
-                # scale the latent loss whose grad is None and tensor is in computing graph
-                if _t is not None and _g is None:
+                # scale the latent loss for the MoE pipeline
+                if self._is_latent_loss(_t, _g):
                     _t = self.loss_scale * _t
                 tensors.append(_t)
                 grads.append(_g)
@@ -512,6 +512,11 @@ class HybridZeroOptimizer(BaseOptimizer):
         else:
             torch.autograd.backward(tensors=tensor, grad_tensors=grad)
 
+    def _is_latent_loss(self, tensor, grad=None):
+        if tensor is not None and grad is None:
+            return tensor.numel() == 1
+        return False
+
     def _compute_norm_with_stage(
         self,
         group_id: int = 0,
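
For context, a minimal runnable sketch of how the refactored check behaves during the backward pass. A latent loss (e.g. an MoE balance loss) is a scalar tensor still in the computation graph that arrives with `grad=None`, so it must be multiplied by the loss scale before backward. The enclosing method name (`backward_by_grad`), the `_Sketch` class, and the concrete `loss_scale` value are assumptions for illustration; only `_is_latent_loss`, the scaling loop, and the `torch.autograd.backward` call mirror the diff.

```python
import torch


class _Sketch:
    """Illustrative stand-in for HybridZeroOptimizer's backward path (assumed shape)."""

    def __init__(self, loss_scale=1024.0):
        self.loss_scale = loss_scale  # hypothetical value; the real scale is managed elsewhere

    def _is_latent_loss(self, tensor, grad=None):
        # A latent loss is a scalar still in the graph with no incoming grad.
        if tensor is not None and grad is None:
            return tensor.numel() == 1
        return False

    def backward_by_grad(self, tensor, grad):
        # Mirrors the loop in the diff: loss-scale latent losses, pass the
        # rest through unchanged, then run a single backward over all pairs.
        tensors, grads = [], []
        for _t, _g in zip(tensor, grad):
            if self._is_latent_loss(_t, _g):
                _t = self.loss_scale * _t
            tensors.append(_t)
            grads.append(_g)
        torch.autograd.backward(tensors=tensors, grad_tensors=grads)


# Usage: a pipeline-stage output with an explicit grad, plus a scalar MoE
# balance loss whose grad is None and which therefore gets loss-scaled
# (None is valid in grad_tensors for scalar outputs; it implies a grad of 1).
x = torch.randn(4, requires_grad=True)
stage_out = x * 2.0
moe_loss = (x ** 2).sum()  # scalar, still in the graph
_Sketch().backward_by_grad(
    tensor=[stage_out, moe_loss],
    grad=[torch.ones_like(stage_out), None],
)
print(x.grad)
```

Note the `numel() == 1` guard this diff adds: before, any graph tensor passed with a `None` grad was scaled, whereas a true latent loss is necessarily a scalar, so non-scalar stage outputs can no longer be scaled by mistake.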