refactor code

pull/182/head
zhanglei 2023-08-23 11:03:08 +08:00
parent 3f32ee31bb
commit d1d21546d9
1 changed files with 7 additions and 2 deletions

View File

@ -503,8 +503,8 @@ class HybridZeroOptimizer(BaseOptimizer):
tensors = [] tensors = []
grads = [] grads = []
for _t, _g in zip(tensor, grad): for _t, _g in zip(tensor, grad):
# scale the latent loss whose grad is None and tensor is in computing graph # scale the latent loss for moe pipeline
if _t is not None and _g is None: if self._is_latent_loss(_t, _g):
_t = self.loss_scale * _t _t = self.loss_scale * _t
tensors.append(_t) tensors.append(_t)
grads.append(_g) grads.append(_g)
@ -512,6 +512,11 @@ class HybridZeroOptimizer(BaseOptimizer):
else: else:
torch.autograd.backward(tensors=tensor, grad_tensors=grad) torch.autograd.backward(tensors=tensor, grad_tensors=grad)
def _is_latent_loss(self, tensor, grad=None):
if tensor is not None and grad is None:
return tensor.numel() == 1
return False
def _compute_norm_with_stage( def _compute_norm_with_stage(
self, self,
group_id: int = 0, group_id: int = 0,