fix the bug of not scaling the latent MoE loss

pull/182/head
zhanglei 2023-08-23 10:53:36 +08:00
parent 12b739e83b
commit 3f32ee31bb
1 changed file with 14 additions and 0 deletions

@@ -498,6 +498,20 @@ class HybridZeroOptimizer(BaseOptimizer):
    # Gradients may not be fully synchronized here.
    def backward_by_grad(self, tensor, grad):
        if isinstance(tensor, list) and isinstance(grad, list):
            tensors = []
            grads = []
            for _t, _g in zip(tensor, grad):
                # Scale the latent loss: its grad is None but the tensor is still
                # part of the computation graph, so it must be multiplied by the
                # loss scale before backward, just like the main scaled loss.
                if _t is not None and _g is None:
                    _t = self.loss_scale * _t
                tensors.append(_t)
                grads.append(_g)
            torch.autograd.backward(tensors=tensors, grad_tensors=grads)
        else:
            torch.autograd.backward(tensors=tensor, grad_tensors=grad)
    def _compute_norm_with_stage(
        self,
        group_id: int = 0,
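
A minimal, self-contained sketch of the idea behind the fix (toy tensors and a plain float loss_scale; these names are made up for illustration and are not the project's code): a tensor paired with a None gradient is a latent loss that is still attached to the computation graph, so it must be multiplied by the loss scale before backward, mirroring what backward_by_grad now does.

    import torch

    # Assumed stand-in for self.loss_scale on the optimizer.
    loss_scale = 1024.0

    x = torch.randn(4, requires_grad=True)
    main_out = (x * 2).sum()            # main output, upstream grad already known
    moe_latent_loss = (x ** 2).mean()   # latent MoE loss, no precomputed grad

    main_grad = torch.ones_like(main_out)

    tensors, grads = [], []
    for _t, _g in zip([main_out, moe_latent_loss], [main_grad, None]):
        if _t is not None and _g is None:
            _t = loss_scale * _t        # only the latent loss gets scaled
        tensors.append(_t)
        grads.append(_g)

    torch.autograd.backward(tensors=tensors, grad_tensors=grads)
    print(x.grad)  # gradients from both the main output and the scaled latent loss

Without the scaling, the latent MoE loss would contribute an unscaled gradient while the main loss contributes a scaled one, and the mismatch would only be corrected for the main-loss portion when the optimizer later divides by the loss scale.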