From 3f32ee31bb1d5747ef88f074c8581229b75dd732 Mon Sep 17 00:00:00 2001
From: zhanglei
Date: Wed, 23 Aug 2023 10:53:36 +0800
Subject: [PATCH] fix missing loss scaling for the latent MoE loss

---
 internlm/solver/optimizer/hybrid_zero_optim.py | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py
index dfe0a4a..dd82afc 100644
--- a/internlm/solver/optimizer/hybrid_zero_optim.py
+++ b/internlm/solver/optimizer/hybrid_zero_optim.py
@@ -498,6 +498,20 @@ class HybridZeroOptimizer(BaseOptimizer):
 
         # Gradients may not be fully synchronized here.
 
+    def backward_by_grad(self, tensor, grad):
+        if isinstance(tensor, list) and isinstance(grad, list):
+            tensors = []
+            grads = []
+            for _t, _g in zip(tensor, grad):
+                # scale the latent (MoE) loss whose grad is None but whose tensor is in the computation graph
+                if _t is not None and _g is None:
+                    _t = self.loss_scale * _t
+                tensors.append(_t)
+                grads.append(_g)
+            torch.autograd.backward(tensors=tensors, grad_tensors=grads)
+        else:
+            torch.autograd.backward(tensors=tensor, grad_tensors=grad)
+
     def _compute_norm_with_stage(
         self,
         group_id: int = 0,
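
A minimal sketch of the behaviour this patch adds, for reviewers. _ToyZeroOptimizer, the fixed loss_scale of 1024.0, and the toy losses below are illustrative assumptions, not InternLM code; in HybridZeroOptimizer the scale comes from the optimizer's loss scaler and the trainer drives the call. The point is that a latent (MoE auxiliary) loss passed with grad=None is multiplied by loss_scale before torch.autograd.backward, so its gradient carries the same scaling as the main loss.

import torch


class _ToyZeroOptimizer:
    """Stand-in for HybridZeroOptimizer, reduced to the patched backward_by_grad().

    Entries whose grad is None are latent (e.g. MoE auxiliary) losses that are
    still attached to the autograd graph; they are multiplied by loss_scale so
    their gradients carry the same scaling as the already-scaled main loss.
    """

    def __init__(self, loss_scale=1024.0):
        self.loss_scale = loss_scale  # hypothetical fixed scale, for illustration only

    def backward_by_grad(self, tensor, grad):
        if isinstance(tensor, list) and isinstance(grad, list):
            tensors, grads = [], []
            for _t, _g in zip(tensor, grad):
                # scale the latent loss whose grad is None but whose tensor is in the graph
                if _t is not None and _g is None:
                    _t = self.loss_scale * _t
                tensors.append(_t)
                grads.append(_g)
            torch.autograd.backward(tensors=tensors, grad_tensors=grads)
        else:
            torch.autograd.backward(tensors=tensor, grad_tensors=grad)


if __name__ == "__main__":
    w = torch.ones(2, requires_grad=True)
    main_loss = (3.0 * w).sum()    # main loss: its upstream grad is supplied explicitly
    latent_loss = (2.0 * w).sum()  # latent MoE-style loss: grad is None, so it gets scaled

    opt = _ToyZeroOptimizer(loss_scale=1024.0)
    opt.backward_by_grad(
        tensor=[main_loss, latent_loss],
        grad=[torch.ones_like(main_loss), None],
    )
    # Each element of w.grad is 3.0 (main) + 2.0 * 1024.0 (scaled latent) = 2051.0
    print(w.grad)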