diff --git a/colossalai/amp/naive_amp/_fp16_optimizer.py b/colossalai/amp/naive_amp/_fp16_optimizer.py
index 98bb1e639..7add4bc98 100644
--- a/colossalai/amp/naive_amp/_fp16_optimizer.py
+++ b/colossalai/amp/naive_amp/_fp16_optimizer.py
@@ -39,106 +39,6 @@ def _multi_tensor_copy_this_to_that(this, that, overflow_buf=None):
             that_.copy_(this_)
 
 
-class DynamicGradScaler:
-
-    def __init__(self,
-                 initial_scale,
-                 min_scale,
-                 growth_factor,
-                 backoff_factor,
-                 growth_interval,
-                 hysteresis,
-                 max_scale: int = None,
-                 verbose: bool = False):
-        """"Grad scaler with dynamic scale that gets adjusted
-        during training."""
-        assert initial_scale > 0.0
-        self._scale = torch.cuda.FloatTensor([initial_scale])
-
-        # Lower bound on the scale.
-        assert min_scale > 0.0
-        assert min_scale <= initial_scale
-        self.min_scale = torch.cuda.FloatTensor([min_scale])
-        # Growth and backoff factors for the scale.
-        assert growth_factor > 1.0
-        self.growth_factor = torch.cuda.FloatTensor([growth_factor])
-        assert backoff_factor < 1.0
-        assert backoff_factor > 0.0
-        self.backoff_factor = torch.cuda.FloatTensor([backoff_factor])
-        # Interval over which if we don't see any inf/nan,
-        # we will scale the grad scale by the growth factor.
-        assert growth_interval > 0
-        self.growth_interval = growth_interval
-        # Number of inf/nans we should see before scaling down
-        # the grad scale by the backoff factor.
-        assert hysteresis > 0
-        self.hysteresis = hysteresis
-        if max_scale is not None:
-            assert max_scale > 1 and initial_scale <= max_scale
-        self._max_scale = max_scale
-
-        # Trackers.
-        self._growth_tracker = 0
-        self._hysteresis_tracker = self.hysteresis
-
-        self._logger = get_dist_logger()
-        self.verbose = verbose
-
-    @property
-    def scale(self):
-        return self._scale
-
-    @property
-    def inv_scale(self):
-        return self._scale.double().reciprocal().float()
-
-    def update(self, found_inf):
-
-        # If we have an inf/nan, growth tracker is set to 0
-        # and hysterisis tracker is reduced by 1.
-        if found_inf:
-            self._growth_tracker = 0
-            self._hysteresis_tracker -= 1
-            # Now if we are out of hysteresis count, scale down the loss.
-            if self._hysteresis_tracker <= 0:
-                self._scale = torch.max(self._scale * self.backoff_factor, self.min_scale)
-            if self.verbose:
-                self._logger.info(f'overflow occurs, loss scale is adjusted to {self._scale}', ranks=[0])
-        else:
-            # If there is no nan/inf, increment the growth tracker.
-            self._growth_tracker += 1
-            # If we have had enough consequitive intervals with no nan/inf:
-            if self._growth_tracker == self.growth_interval:
-                # Reset the tracker and hysteresis trackers,
-                self._growth_tracker = 0
-                self._hysteresis_tracker = self.hysteresis
-                # and scale up the loss scale.
-                if self._max_scale is not None and self._scale >= self._max_scale:
-                    if self.verbose:
-                        self._logger.info(
-                            f'Current loss scale {self._scale} has reached the max scale {self._max_scale} allowed',
-                            ranks=[0])
-                else:
-                    self._scale = self._scale * self.growth_factor
-                    if self.verbose:
-                        self._logger.info(f'no consecutive overflow, loss scale is adjusted to {self._scale}',
-                                          ranks=[0])
-
-    def state_dict(self):
-        state_dict = {}
-        state_dict['max_scale'] = self._max_scale
-        state_dict['scale'] = self._scale
-        state_dict['growth_tracker'] = self._growth_tracker
-        state_dict['hysteresis_tracker'] = self._hysteresis_tracker
-        return state_dict
-
-    def load_state_dict(self, state_dict):
-        self._scale = state_dict['scale'].cuda(torch.cuda.current_device())
-        self._growth_tracker = state_dict['growth_tracker']
-        self._hysteresis_tracker = state_dict['hysteresis_tracker']
-        self._max_scale = state_dict['max_scale']
-
-
 class FP16Optimizer(Optimizer):
     """Float16 optimizer for fp16 and bf16 data types.
 
@@ -284,7 +184,7 @@ class FP16Optimizer(Optimizer):
         # check for overflow
         for group in self._optimizer.param_groups:
             for p in group['params']:
-                if has_inf_or_nan(p.grad):
+                if p.grad is not None and has_inf_or_nan(p.grad):
                     self._found_overflow.fill_(1.0)
                     break
 
@@ -316,9 +216,10 @@ class FP16Optimizer(Optimizer):
         # This only needs to be done for the float16 group.
         for fp16_param_group, fp32_master_param_group in zip(self._fp16_param_groups, self._fp32_master_param_groups):
             for fp16_param, fp32_param in zip(fp16_param_group, fp32_master_param_group):
-                fp32_param.grad = fp16_param.grad.float()
-                # clear unneeded grad on fp16 param
-                fp16_param.grad = None
+                if fp16_param.grad is not None:
+                    fp32_param.grad = fp16_param.grad.float()
+                    # clear unneeded grad on fp16 param
+                    fp16_param.grad = None
 
     def _update_fp16_param_from_fp32_param(self):
         fp16_param_data = []
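The behavioral change in this patch is the `p.grad is not None` guard: parameters that receive no gradient in a step (frozen layers, or parameters unused in the forward pass) have `grad == None`, and calling `has_inf_or_nan` or `.float()` on `None` raises. The standalone sketch below is not part of the patch and does not use the library's actual classes; `has_inf_or_nan` and `copy_grads_to_master` are illustrative re-implementations of the guarded logic, shown on CPU tensors.

```python
# Minimal sketch of the None-guarded fp16 -> fp32 grad copy (illustrative only).
import torch


def has_inf_or_nan(tensor: torch.Tensor) -> bool:
    """Return True if the tensor contains any inf or NaN values."""
    return bool(torch.isinf(tensor).any() or torch.isnan(tensor).any())


def copy_grads_to_master(fp16_params, fp32_master_params):
    """Copy fp16 grads into fp32 master params, skipping params that have no grad."""
    for fp16_p, fp32_p in zip(fp16_params, fp32_master_params):
        if fp16_p.grad is not None:            # guard: unused/frozen params carry no grad
            fp32_p.grad = fp16_p.grad.float()  # upcast so the optimizer steps in fp32
            fp16_p.grad = None                 # release the fp16 grad buffer


if __name__ == "__main__":
    fp16 = [torch.zeros(2, dtype=torch.float16, requires_grad=True) for _ in range(2)]
    fp32 = [p.detach().float().requires_grad_() for p in fp16]
    fp16[0].grad = torch.ones(2, dtype=torch.float16)  # only the first param got a grad
    copy_grads_to_master(fp16, fp32)
    assert fp32[0].grad is not None and fp32[1].grad is None  # second param safely skipped
```

The same reasoning applies to the overflow check: only gradients that actually exist are inspected for inf/NaN, so a partially frozen model no longer crashes the scaler update.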