@@ -140,6 +140,10 @@ class ZeroOptimizer(ColossalaiOptimizer):
 
         return self._found_overflow.item() > 0
 
+    def _clear_global_norm(self) -> None:
+        for c16 in self.chunk16_set:
+            c16.l2_norm = None
+
     def _calc_global_norm(self) -> float:
         norm_sqr: float = 0.0
         group_to_norm = dict()
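The hunk above shows only the first two lines of _calc_global_norm. As a sketch of where group_to_norm plausibly leads, assuming each chunk exposes is_gathered, torch_pg, and a cached squared l2_norm (the helper below is illustrative, not the file's actual body):

import math

import torch
import torch.distributed as dist

def calc_global_norm_sketch(chunk16_set, device):
    # Sum squared norms locally, bucketing sharded chunks by process group.
    norm_sqr = 0.0
    group_to_norm = dict()    # process group -> partial squared norm
    for c16 in chunk16_set:
        if c16.is_gathered:
            norm_sqr += c16.l2_norm    # replicated chunk: counts once locally
        else:
            group_to_norm[c16.torch_pg] = group_to_norm.get(c16.torch_pg, 0.0) + c16.l2_norm
        c16.l2_norm = None    # consume the cached value
    # One all-reduce per process group collects the sharded contributions.
    buf = torch.zeros(1, dtype=torch.float, device=device)
    for group, part_norm in group_to_norm.items():
        buf.fill_(part_norm)
        dist.all_reduce(buf, group=group)
        norm_sqr += buf.item()
    return math.sqrt(norm_sqr)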
@@ -201,6 +205,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
             self.optim_state = OptimState.UNSCALED    # no need to unscale grad
             self.grad_scaler.update(found_inf)    # update gradient scaler
             self._logger.info(f'Found overflow. Skip step')
+            self._clear_global_norm()    # clear recorded norm
             self.zero_grad()    # reset all gradients
             self._update_fp16_params()
             return
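Why the new _clear_global_norm() call matters on the skipped-step path: the per-chunk norms recorded during backward are normally consumed by _calc_global_norm, and a skipped step would otherwise leave them behind for the next iteration. A minimal sketch of that failure mode, using a hypothetical Chunk stand-in (record_backward_norms and the other names are illustrative, not ColossalAI APIs):

import math

class Chunk:
    def __init__(self):
        self.l2_norm = None    # squared L2 norm cached during gradient reduction

chunks = [Chunk(), Chunk()]

def record_backward_norms(values):
    # Stand-in for the bookkeeping done while gradients are reduced.
    for c, v in zip(chunks, values):
        assert c.l2_norm is None, 'norm from a previous step was never consumed'
        c.l2_norm = v * v

def clear_global_norm():
    # Mirrors the new ZeroOptimizer._clear_global_norm in the diff above.
    for c in chunks:
        c.l2_norm = None

def calc_global_norm():
    # Mirrors the gathered-chunk branch of _calc_global_norm.
    norm_sqr = sum(c.l2_norm for c in chunks)
    for c in chunks:
        c.l2_norm = None
    return math.sqrt(norm_sqr)

record_backward_norms([3.0, 4.0])    # this step overflows and is skipped
clear_global_norm()                  # the fix: discard the unconsumed norms
record_backward_norms([0.6, 0.8])    # next, finite step proceeds cleanly
print(calc_global_norm())            # 1.0

Dropping the clear_global_norm() line makes the second record_backward_norms trip the assertion, which is the kind of stale-state bug the extra call on the overflow path guards against.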