change the scale position for latent moe_loss

pull/182/head
zhanglei 2023-08-23 13:25:20 +08:00
parent 3a3ca71459
commit 72e3b1afd5
2 changed files with 2 additions and 19 deletions


@@ -335,6 +335,8 @@ class PipelineScheduler(BaseScheduler):
         if output_obj_grad is None:
             engine.backward(output_obj + moe_loss)
         else:
+            # scale the latent loss
+            moe_loss = moe_loss * engine.optimizer.loss_scale
             engine.backward_by_grad([output_obj, moe_loss], [output_obj_grad, None])
         # Collect the grad of the input_obj.
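For context, a minimal self-contained sketch (not InternLM code; the toy tensors and the literal 1024.0 scale are made up) of why the scalar moe_loss is pre-multiplied by engine.optimizer.loss_scale here: the gradient arriving from the next pipeline stage is already scaled, while a scalar loss passed to backward_by_grad with a None grad would otherwise enter the graph unscaled.

# Sketch only: illustrates the scaling mismatch that the added lines address.
import torch

loss_scale = 1024.0                     # stand-in for engine.optimizer.loss_scale
x = torch.randn(4, requires_grad=True)
output_obj = x * 2.0                    # pipeline-stage output
moe_loss = (x ** 2).sum()               # scalar latent (MoE auxiliary) loss

# The grad handed back by the next stage is already multiplied by loss_scale.
output_obj_grad = torch.ones_like(output_obj) * loss_scale

# Scale the latent loss so both gradient paths carry the same scale.
scaled_moe_loss = moe_loss * loss_scale

# backward_by_grad ultimately reduces to one torch.autograd.backward call;
# a None grad_tensor means the tensor is treated as a scalar loss.
torch.autograd.backward(
    tensors=[output_obj, scaled_moe_loss],
    grad_tensors=[output_obj_grad, None],
)
print(x.grad)  # consistently scaled; the optimizer divides by loss_scale later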


@@ -498,25 +498,6 @@ class HybridZeroOptimizer(BaseOptimizer):
         # Gradients may not be fully synchronized here.
 
-    def backward_by_grad(self, tensor, grad):
-        if isinstance(tensor, list) and isinstance(grad, list):
-            tensors = []
-            grads = []
-            for _t, _g in zip(tensor, grad):
-                # scale the latent loss for moe pipeline
-                if self._is_latent_loss(_t, _g):
-                    _t = self.loss_scale * _t
-                tensors.append(_t)
-                grads.append(_g)
-            torch.autograd.backward(tensors=tensors, grad_tensors=grads)
-        else:
-            torch.autograd.backward(tensors=tensor, grad_tensors=grad)
-
-    def _is_latent_loss(self, tensor, grad=None):
-        if tensor is not None and grad is None:
-            return tensor.numel() == 1
-        return False
-
     def _compute_norm_with_stage(
         self,
         group_id: int = 0,
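With the override and _is_latent_loss removed, the optimizer-side backward_by_grad presumably falls back to the plain call that the deleted else-branch already used, with no latent-loss special case, since the scheduler now scales moe_loss before the call. Roughly (a sketch of the remaining behavior, not the file's verbatim code):

import torch

def backward_by_grad(self, tensor, grad):
    # No latent-loss branching: moe_loss arrives pre-scaled from the scheduler,
    # so a single autograd call covers both the single-tensor and list inputs
    # that torch.autograd.backward already accepts.
    torch.autograd.backward(tensors=tensor, grad_tensors=grad)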