mirror of https://github.com/InternLM/InternLM
refactor code
parent 3f32ee31bb
commit d1d21546d9
@@ -503,8 +503,8 @@ class HybridZeroOptimizer(BaseOptimizer):
        tensors = []
        grads = []
        for _t, _g in zip(tensor, grad):
            # scale the latent loss whose grad is None but whose tensor is in the computing graph
            if _t is not None and _g is None:
                # scale the latent loss for the MoE pipeline
                if self._is_latent_loss(_t, _g):
                    _t = self.loss_scale * _t
            tensors.append(_t)
            grads.append(_g)
@@ -512,6 +512,11 @@ class HybridZeroOptimizer(BaseOptimizer):
        else:
            torch.autograd.backward(tensors=tensor, grad_tensors=grad)

    def _is_latent_loss(self, tensor, grad=None):
        if tensor is not None and grad is None:
            return tensor.numel() == 1
        return False

    def _compute_norm_with_stage(
        self,
        group_id: int = 0,
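For context, the change treats a scalar tensor whose incoming grad is None as a latent loss (e.g. an MoE auxiliary loss produced inside a pipeline stage), pre-multiplies it by the loss scale, and then passes the whole list to torch.autograd.backward. Below is a minimal, standalone sketch of that pattern; the function name backward_with_latent_loss, the hard-coded loss_scale value, and the toy tensors are illustrative assumptions for this note, not InternLM code.

    import torch

    def backward_with_latent_loss(tensor, grad, loss_scale):
        # Mirror of the diff's logic: a scalar tensor paired with a None grad is
        # treated as a latent loss and pre-scaled before backward. The numel() == 1
        # test corresponds to the _is_latent_loss check in the diff.
        tensors, grads = [], []
        for _t, _g in zip(tensor, grad):
            if _t is not None and _g is None and _t.numel() == 1:
                _t = loss_scale * _t  # same idea as self.loss_scale * _t
            tensors.append(_t)
            grads.append(_g)
        # None entries in grad_tensors are allowed for scalar tensors; PyTorch
        # implicitly uses a gradient of 1.0 for them.
        torch.autograd.backward(tensors=tensors, grad_tensors=grads)

    # Usage: one ordinary output with an explicit grad, plus one scalar latent loss.
    x = torch.randn(3, requires_grad=True)
    out = x * 2                  # ordinary stage output with a supplied grad
    latent = (x ** 2).sum()      # scalar auxiliary loss, no incoming grad
    backward_with_latent_loss([out, latent], [torch.ones_like(out), None], loss_scale=1024.0)
    print(x.grad)

The point of the pre-scaling is that the latent loss never receives an upstream gradient that has already been multiplied by the loss scale, so scaling the scalar itself keeps its contribution consistent with the other (already scaled) gradients before the optimizer later unscales them.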