diff --git a/colossalai/nn/optimizer/zero_optimizer.py b/colossalai/nn/optimizer/zero_optimizer.py
index 2786d4496..7f9d2fe8f 100644
--- a/colossalai/nn/optimizer/zero_optimizer.py
+++ b/colossalai/nn/optimizer/zero_optimizer.py
@@ -140,6 +140,10 @@ class ZeroOptimizer(ColossalaiOptimizer):
 
         return self._found_overflow.item() > 0
 
+    def _clear_global_norm(self) -> None:
+        for c16 in self.chunk16_set:
+            c16.l2_norm = None
+
     def _calc_global_norm(self) -> float:
         norm_sqr: float = 0.0
         group_to_norm = dict()
@@ -201,6 +205,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
             self.optim_state = OptimState.UNSCALED  # no need to unscale grad
             self.grad_scaler.update(found_inf)  # update gradient scaler
            self._logger.info(f'Found overflow. Skip step')
+            self._clear_global_norm()  # clear recorded norm
             self.zero_grad()  # reset all gradients
             self._update_fp16_params()
             return
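
The sketch below is a minimal, self-contained illustration (not the ColossalAI implementation) of why the patch clears per-chunk L2 norms when an overflow skips the step: norms recorded during the skipped step would otherwise leak into the next global-norm computation. The names `chunk16_set`, `l2_norm`, `_clear_global_norm`, and `_calc_global_norm` come from the diff; the `Chunk` and `TinyZeroOptimizer` classes and the simplified norm calculation are hypothetical stand-ins.

    # Hypothetical sketch: stale-norm problem that _clear_global_norm() avoids.
    from typing import Optional, Set


    class Chunk:
        """Hypothetical stand-in for a chunk of fp16 parameters."""

        def __init__(self) -> None:
            # Squared L2 norm recorded for the chunk's gradients during backward.
            self.l2_norm: Optional[float] = None


    class TinyZeroOptimizer:
        """Hypothetical optimizer holding a set of fp16 chunks."""

        def __init__(self, chunk16_set: Set[Chunk]) -> None:
            self.chunk16_set = chunk16_set

        def _clear_global_norm(self) -> None:
            # Drop norms recorded for the skipped step so they cannot
            # carry over into the next step's global-norm computation.
            for c16 in self.chunk16_set:
                c16.l2_norm = None

        def _calc_global_norm(self) -> float:
            # Simplified: sum only the norms recorded for the current step.
            norm_sqr = sum(c.l2_norm for c in self.chunk16_set if c.l2_norm is not None)
            return norm_sqr ** 0.5


    if __name__ == "__main__":
        chunks = {Chunk(), Chunk()}
        opt = TinyZeroOptimizer(chunks)
        for c in chunks:
            c.l2_norm = 4.0                 # pretend backward recorded a squared norm of 4
        print(opt._calc_global_norm())      # sqrt(8) ~= 2.83
        opt._clear_global_norm()            # overflow detected: skip step, clear norms
        print(opt._calc_global_norm())      # 0.0, no stale norms carried over

Without the clearing step, the norms from the skipped iteration would be treated as if they belonged to the next iteration's gradients, distorting gradient clipping after an overflow.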