diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py
index dfe0a4a..dd82afc 100644
--- a/internlm/solver/optimizer/hybrid_zero_optim.py
+++ b/internlm/solver/optimizer/hybrid_zero_optim.py
@@ -498,6 +498,20 @@ class HybridZeroOptimizer(BaseOptimizer):
         # Gradients may not be fully synchronized here.
 
+    def backward_by_grad(self, tensor, grad):
+        if isinstance(tensor, list) and isinstance(grad, list):
+            tensors = []
+            grads = []
+            for _t, _g in zip(tensor, grad):
+                # scale the latent loss whose grad is None but whose tensor is in the computation graph
+                if _t is not None and _g is None:
+                    _t = self.loss_scale * _t
+                tensors.append(_t)
+                grads.append(_g)
+            torch.autograd.backward(tensors=tensors, grad_tensors=grads)
+        else:
+            torch.autograd.backward(tensors=tensor, grad_tensors=grad)
+
     def _compute_norm_with_stage(
         self,
         group_id: int = 0,
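
For context, the new list branch lets a single `torch.autograd.backward` call mix tensors that carry an explicit upstream gradient (e.g. a pipeline stage's output) with scalar losses whose `grad_tensors` entry is `None`; the latter are multiplied by the loss scale so their contribution matches the already-scaled gradients flowing in from downstream. A minimal standalone sketch of that behavior follows; the tensor names and the `loss_scale` value are illustrative placeholders, not taken from InternLM:

```python
import torch

loss_scale = 1024.0  # stands in for self.loss_scale; the value here is hypothetical

x = torch.randn(4, requires_grad=True)
z = x * 3                    # intermediate output: an explicit grad must be supplied
latent = (x * 2).sum()       # scalar latent loss: its paired grad entry is None
z_grad = torch.ones_like(z)  # gradient arriving from a downstream pipeline stage

tensors, grads = [], []
for _t, _g in zip([z, latent], [z_grad, None]):
    # mirror the patch: scale the scalar loss whose incoming grad is None
    if _t is not None and _g is None:
        _t = loss_scale * _t
    tensors.append(_t)
    grads.append(_g)

# one backward pass over both paths; None is valid for scalar tensors
torch.autograd.backward(tensors=tensors, grad_tensors=grads)
print(x.grad)  # each element gets 3 * z_grad + 2 * loss_scale
```

Keeping everything in one `torch.autograd.backward` call, rather than calling `backward()` per tensor, traverses the shared computation graph once and avoids needing `retain_graph=True` between the passes.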