From dddacd2d2c4d2416563fa4160d715d11a9a2a691 Mon Sep 17 00:00:00 2001 From: HELSON Date: Tue, 10 Jan 2023 15:43:06 +0800 Subject: [PATCH] [hotfix] add norm clearing for the overflow step (#2416) --- colossalai/nn/optimizer/zero_optimizer.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/colossalai/nn/optimizer/zero_optimizer.py b/colossalai/nn/optimizer/zero_optimizer.py index 2786d4496..7f9d2fe8f 100644 --- a/colossalai/nn/optimizer/zero_optimizer.py +++ b/colossalai/nn/optimizer/zero_optimizer.py @@ -140,6 +140,10 @@ class ZeroOptimizer(ColossalaiOptimizer): return self._found_overflow.item() > 0 + def _clear_global_norm(self) -> None: + for c16 in self.chunk16_set: + c16.l2_norm = None + def _calc_global_norm(self) -> float: norm_sqr: float = 0.0 group_to_norm = dict() @@ -201,6 +205,7 @@ class ZeroOptimizer(ColossalaiOptimizer): self.optim_state = OptimState.UNSCALED # no need to unscale grad self.grad_scaler.update(found_inf) # update gradient scaler self._logger.info(f'Found overflow. Skip step') + self._clear_global_norm() # clear recorded norm self.zero_grad() # reset all gradients self._update_fp16_params() return