fixed fp16 optimizer none grad bug (#432)

pull/433/head
Frank Lee 3 years ago committed by GitHub
parent fce9432f08
commit 14a7094243

@@ -39,106 +39,6 @@ def _multi_tensor_copy_this_to_that(this, that, overflow_buf=None):
             that_.copy_(this_)
 
 
-class DynamicGradScaler:
-
-    def __init__(self,
-                 initial_scale,
-                 min_scale,
-                 growth_factor,
-                 backoff_factor,
-                 growth_interval,
-                 hysteresis,
-                 max_scale: int = None,
-                 verbose: bool = False):
-        """Grad scaler with dynamic scale that gets adjusted
-        during training."""
-        assert initial_scale > 0.0
-        self._scale = torch.cuda.FloatTensor([initial_scale])
-
-        # Lower bound on the scale.
-        assert min_scale > 0.0
-        assert min_scale <= initial_scale
-        self.min_scale = torch.cuda.FloatTensor([min_scale])
-
-        # Growth and backoff factors for the scale.
-        assert growth_factor > 1.0
-        self.growth_factor = torch.cuda.FloatTensor([growth_factor])
-        assert backoff_factor < 1.0
-        assert backoff_factor > 0.0
-        self.backoff_factor = torch.cuda.FloatTensor([backoff_factor])
-
-        # Interval over which, if we don't see any inf/nan,
-        # we will scale the grad scale by the growth factor.
-        assert growth_interval > 0
-        self.growth_interval = growth_interval
-
-        # Number of inf/nans we should see before scaling down
-        # the grad scale by the backoff factor.
-        assert hysteresis > 0
-        self.hysteresis = hysteresis
-
-        if max_scale is not None:
-            assert max_scale > 1 and initial_scale <= max_scale
-        self._max_scale = max_scale
-
-        # Trackers.
-        self._growth_tracker = 0
-        self._hysteresis_tracker = self.hysteresis
-
-        self._logger = get_dist_logger()
-        self.verbose = verbose
-
-    @property
-    def scale(self):
-        return self._scale
-
-    @property
-    def inv_scale(self):
-        return self._scale.double().reciprocal().float()
-
-    def update(self, found_inf):
-        # If we have an inf/nan, the growth tracker is set to 0
-        # and the hysteresis tracker is reduced by 1.
-        if found_inf:
-            self._growth_tracker = 0
-            self._hysteresis_tracker -= 1
-            # Now, if we are out of hysteresis count, scale down the loss scale.
-            if self._hysteresis_tracker <= 0:
-                self._scale = torch.max(self._scale * self.backoff_factor, self.min_scale)
-                if self.verbose:
-                    self._logger.info(f'overflow occurs, loss scale is adjusted to {self._scale}', ranks=[0])
-        else:
-            # If there is no nan/inf, increment the growth tracker.
-            self._growth_tracker += 1
-            # If we have had enough consecutive intervals with no nan/inf:
-            if self._growth_tracker == self.growth_interval:
-                # Reset the growth and hysteresis trackers,
-                self._growth_tracker = 0
-                self._hysteresis_tracker = self.hysteresis
-                # and scale up the loss scale.
-                if self._max_scale is not None and self._scale >= self._max_scale:
-                    if self.verbose:
-                        self._logger.info(
-                            f'Current loss scale {self._scale} has reached the max scale {self._max_scale} allowed',
-                            ranks=[0])
-                else:
-                    self._scale = self._scale * self.growth_factor
-                    if self.verbose:
-                        self._logger.info(f'no consecutive overflow, loss scale is adjusted to {self._scale}',
-                                          ranks=[0])
-
-    def state_dict(self):
-        state_dict = {}
-        state_dict['max_scale'] = self._max_scale
-        state_dict['scale'] = self._scale
-        state_dict['growth_tracker'] = self._growth_tracker
-        state_dict['hysteresis_tracker'] = self._hysteresis_tracker
-        return state_dict
-
-    def load_state_dict(self, state_dict):
-        self._scale = state_dict['scale'].cuda(torch.cuda.current_device())
-        self._growth_tracker = state_dict['growth_tracker']
-        self._hysteresis_tracker = state_dict['hysteresis_tracker']
-        self._max_scale = state_dict['max_scale']
-
-
 class FP16Optimizer(Optimizer):
     """Float16 optimizer for fp16 and bf16 data types.
@@ -284,7 +184,7 @@ class FP16Optimizer(Optimizer):
         # check for overflow
         for group in self._optimizer.param_groups:
             for p in group['params']:
-                if has_inf_or_nan(p.grad):
+                if p.grad is not None and has_inf_or_nan(p.grad):
                     self._found_overflow.fill_(1.0)
                     break
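
The guard matters because a parameter's .grad can legitimately be None, for example for frozen parameters or parameters that did not take part in the forward pass; without the check, has_inf_or_nan would be called on None and fail. A minimal sketch (with a hypothetical two-layer model) showing how such None grads arise:

    import torch
    import torch.nn as nn

    # Hypothetical model where only one of the two layers is used in the forward pass.
    model = nn.ModuleDict({'used': nn.Linear(4, 4), 'unused': nn.Linear(4, 4)})
    model['used'](torch.randn(2, 4)).sum().backward()

    # Parameters of the unused layer never receive a gradient, so .grad stays None.
    for name, p in model.named_parameters():
        print(name, p.grad is None)  # used.*: False, unused.*: True

The fixed overflow check simply skips such parameters instead of crashing.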
@@ -316,9 +216,10 @@ class FP16Optimizer(Optimizer):
         # This only needs to be done for the float16 group.
         for fp16_param_group, fp32_master_param_group in zip(self._fp16_param_groups, self._fp32_master_param_groups):
             for fp16_param, fp32_param in zip(fp16_param_group, fp32_master_param_group):
-                fp32_param.grad = fp16_param.grad.float()
-                # clear unneeded grad on fp16 param
-                fp16_param.grad = None
+                if fp16_param.grad is not None:
+                    fp32_param.grad = fp16_param.grad.float()
+                    # clear unneeded grad on fp16 param
+                    fp16_param.grad = None
 
     def _update_fp16_param_from_fp32_param(self):
         fp16_param_data = []
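
For reference, a condensed sketch of the fp16/fp32 master-parameter pattern this optimizer follows, with the same None-grad guard applied (hypothetical helper and argument names; loss scaling, unscaling, and gradient clipping are omitted):

    import torch

    def step_with_master_params(fp16_params, fp32_master_params, fp32_optimizer):
        # Copy fp16 grads onto the fp32 master copies, skipping params without a grad.
        for fp16_p, fp32_p in zip(fp16_params, fp32_master_params):
            if fp16_p.grad is not None:
                fp32_p.grad = fp16_p.grad.float()
                fp16_p.grad = None  # the fp16 grad is no longer needed
        # Run the update in fp32, then copy the result back into the fp16 params.
        fp32_optimizer.step()
        with torch.no_grad():
            for fp16_p, fp32_p in zip(fp16_params, fp32_master_params):
                fp16_p.copy_(fp32_p)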
