[hotfix] add norm clearing for the overflow step (#2416)

pull/2424/head
HELSON 2 years ago committed by GitHub
parent 57b6157b6c
commit dddacd2d2c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -140,6 +140,10 @@ class ZeroOptimizer(ColossalaiOptimizer):
return self._found_overflow.item() > 0
def _clear_global_norm(self) -> None:
for c16 in self.chunk16_set:
c16.l2_norm = None
def _calc_global_norm(self) -> float:
norm_sqr: float = 0.0
group_to_norm = dict()
@ -201,6 +205,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
self.optim_state = OptimState.UNSCALED # no need to unscale grad
self.grad_scaler.update(found_inf) # update gradient scaler
self._logger.info(f'Found overflow. Skip step')
self._clear_global_norm() # clear recorded norm
self.zero_grad() # reset all gradients
self._update_fp16_params()
return

Loading…
Cancel
Save