@@ -39,106 +39,6 @@ def _multi_tensor_copy_this_to_that(this, that, overflow_buf=None):
             that_.copy_(this_)
 
 
-class DynamicGradScaler:
-
-    def __init__(self,
-                 initial_scale,
-                 min_scale,
-                 growth_factor,
-                 backoff_factor,
-                 growth_interval,
-                 hysteresis,
-                 max_scale: int = None,
-                 verbose: bool = False):
-        """Grad scaler with dynamic scale that gets adjusted
-        during training."""
-        assert initial_scale > 0.0
-        self._scale = torch.cuda.FloatTensor([initial_scale])
-
-        # Lower bound on the scale.
-        assert min_scale > 0.0
-        assert min_scale <= initial_scale
-        self.min_scale = torch.cuda.FloatTensor([min_scale])
-        # Growth and backoff factors for the scale.
-        assert growth_factor > 1.0
-        self.growth_factor = torch.cuda.FloatTensor([growth_factor])
-        assert backoff_factor < 1.0
-        assert backoff_factor > 0.0
-        self.backoff_factor = torch.cuda.FloatTensor([backoff_factor])
-        # Interval over which if we don't see any inf/nan,
-        # we will scale the grad scale by the growth factor.
-        assert growth_interval > 0
-        self.growth_interval = growth_interval
-        # Number of inf/nans we should see before scaling down
-        # the grad scale by the backoff factor.
-        assert hysteresis > 0
-        self.hysteresis = hysteresis
-        if max_scale is not None:
-            assert max_scale > 1 and initial_scale <= max_scale
-        self._max_scale = max_scale
-
-        # Trackers.
-        self._growth_tracker = 0
-        self._hysteresis_tracker = self.hysteresis
-
-        self._logger = get_dist_logger()
-        self.verbose = verbose
-
-    @property
-    def scale(self):
-        return self._scale
-
-    @property
-    def inv_scale(self):
-        return self._scale.double().reciprocal().float()
-
-    def update(self, found_inf):
-
-        # If we have an inf/nan, growth tracker is set to 0
-        # and hysteresis tracker is reduced by 1.
-        if found_inf:
-            self._growth_tracker = 0
-            self._hysteresis_tracker -= 1
-            # Now if we are out of hysteresis count, scale down the loss.
-            if self._hysteresis_tracker <= 0:
-                self._scale = torch.max(self._scale * self.backoff_factor, self.min_scale)
-            if self.verbose:
-                self._logger.info(f'overflow occurs, loss scale is adjusted to {self._scale}', ranks=[0])
-        else:
-            # If there is no nan/inf, increment the growth tracker.
-            self._growth_tracker += 1
-            # If we have had enough consecutive intervals with no nan/inf:
-            if self._growth_tracker == self.growth_interval:
-                # Reset the tracker and hysteresis trackers,
-                self._growth_tracker = 0
-                self._hysteresis_tracker = self.hysteresis
-                # and scale up the loss scale.
-                if self._max_scale is not None and self._scale >= self._max_scale:
-                    if self.verbose:
-                        self._logger.info(
-                            f'Current loss scale {self._scale} has reached the max scale {self._max_scale} allowed',
-                            ranks=[0])
-                else:
-                    self._scale = self._scale * self.growth_factor
-                    if self.verbose:
-                        self._logger.info(f'no consecutive overflow, loss scale is adjusted to {self._scale}',
-                                          ranks=[0])
-
-    def state_dict(self):
-        state_dict = {}
-        state_dict['max_scale'] = self._max_scale
-        state_dict['scale'] = self._scale
-        state_dict['growth_tracker'] = self._growth_tracker
-        state_dict['hysteresis_tracker'] = self._hysteresis_tracker
-        return state_dict
-
-    def load_state_dict(self, state_dict):
-        self._scale = state_dict['scale'].cuda(torch.cuda.current_device())
-        self._growth_tracker = state_dict['growth_tracker']
-        self._hysteresis_tracker = state_dict['hysteresis_tracker']
-        self._max_scale = state_dict['max_scale']
-
-
 class FP16Optimizer(Optimizer):
     """Float16 optimizer for fp16 and bf16 data types.
 
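A minimal usage sketch for a scaler with the interface removed above (assumptions: CUDA is available, torch is imported as in this module, and model, optimizer, loss_fn, inputs and targets are placeholders; FP16Optimizer below integrates the scaler differently, unscaling fp32 master gradients rather than the fp16 gradients shown here):

    scaler = DynamicGradScaler(initial_scale=2**16,
                               min_scale=1,
                               growth_factor=2.0,
                               backoff_factor=0.5,
                               growth_interval=1000,
                               hysteresis=2)

    # One training iteration driven by the scaler.
    loss = loss_fn(model(inputs), targets)
    scaled_loss = loss * scaler.scale.item()      # scale the loss before backward
    scaled_loss.backward()
    grads = [p.grad for p in model.parameters() if p.grad is not None]
    found_inf = any(not torch.isfinite(g).all() for g in grads)
    if not found_inf:
        for g in grads:
            g.mul_(scaler.inv_scale)              # unscale gradients before the step
        optimizer.step()
    optimizer.zero_grad()
    scaler.update(found_inf)                      # grow the scale, or back it off after overflows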
@@ -284,7 +184,7 @@ class FP16Optimizer(Optimizer):
         # check for overflow
         for group in self._optimizer.param_groups:
             for p in group['params']:
-                if has_inf_or_nan(p.grad):
+                if p.grad is not None and has_inf_or_nan(p.grad):
                     self._found_overflow.fill_(1.0)
                     break
 
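The added p.grad is not None guard matters because parameters that received no gradient in this iteration (for example frozen or unused ones) have p.grad set to None, and an inf/nan check on None would raise. As a point of reference only, has_inf_or_nan can be read as behaving roughly like this hypothetical reimplementation (the project's actual helper may be implemented differently):

    def has_inf_or_nan(t):
        # True if any element of t is NaN or +/- inf.
        return not torch.isfinite(t).all().item()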
@@ -316,9 +216,10 @@ class FP16Optimizer(Optimizer):
         # This only needs to be done for the float16 group.
         for fp16_param_group, fp32_master_param_group in zip(self._fp16_param_groups, self._fp32_master_param_groups):
             for fp16_param, fp32_param in zip(fp16_param_group, fp32_master_param_group):
-                fp32_param.grad = fp16_param.grad.float()
-                # clear unneeded grad on fp16 param
-                fp16_param.grad = None
+                if fp16_param.grad is not None:
+                    fp32_param.grad = fp16_param.grad.float()
+                    # clear unneeded grad on fp16 param
+                    fp16_param.grad = None
 
     def _update_fp16_param_from_fp32_param(self):
         fp16_param_data = []