mirror of https://github.com/hpcaitech/ColossalAI
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
167 lines
6.9 KiB
167 lines
6.9 KiB
3 years ago
|
# Copyright 2019 The Microsoft DeepSpeed Team
|
||
|
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
# Taken and modified for DeepSpeed from:
|
||
|
# https://github.com/NVIDIA/Megatron-LM/blob/master/fp16/loss_scaler.py
|
||
|
# Commit: 93ab4bea59dc5cbf97c079d313741866af4deac9
|
||
|
|
||
|
|
||
|
# String keys naming the tunable loss-scaling options.
# NOTE(review): presumably these index a dynamic-loss-scale argument dict
# built by the caller; nothing in this module reads them - confirm usage.
INITIAL_LOSS_SCALE = 'init_scale'
SCALE_WINDOW = 'scale_window'
DELAYED_SHIFT = 'delayed_shift'
MIN_LOSS_SCALE = 'min_scale'
|
||
|
|
||
|
|
||
|
# item() is a recent addition, so this helps with backward compatibility.
|
||
|
def to_python_float(t):
    """Return the single value held by ``t``.

    Objects exposing an ``item`` attribute are unwrapped through
    ``t.item()``; anything else is assumed to be indexable and its first
    element ``t[0]`` is returned instead.
    """
    return t.item() if hasattr(t, 'item') else t[0]
|
||
|
|
||
|
|
||
|
class LossScalerBase:
    """Common base class for loss scalers.

    Holds the current loss scale and provides the helpers shared by the
    concrete scalers: gradient scaling and a scaled backward pass.
    """

    def __init__(self, cur_scale):
        # Multiplier currently applied to the loss before back-propagation.
        self.cur_scale = cur_scale

    @property
    def loss_scale(self):
        """Current loss scale (read-only view of ``cur_scale``)."""
        return self.cur_scale

    def scale_gradient(self, module, grad_in, grad_out):
        """Return a tuple with each entry of ``grad_in`` times the loss scale.

        ``module`` and ``grad_out`` are accepted but not used here.
        NOTE(review): the signature resembles a module backward hook -
        confirm against the caller.
        """
        factor = self.loss_scale
        return tuple(factor * grad for grad in grad_in)

    def update_scale(self, overflow):
        """Hook for adjusting the scale after a step; the base does nothing."""

    def backward(self, loss, retain_graph=False):
        """Multiply ``loss`` by the loss scale and back-propagate the result."""
        (loss * self.loss_scale).backward(retain_graph=retain_graph)
|
||
|
|
||
|
|
||
|
class LossScaler(LossScalerBase):
    """
    Class that manages a static loss scale. This class is intended to interact with
    :class:`FP16_Optimizer`, and should not be directly manipulated by the user.

    Use of :class:`LossScaler` is enabled via the ``static_loss_scale`` argument to
    :class:`FP16_Optimizer`'s constructor.

    Args:
        scale (float, optional, default=1): The loss scale.
    """

    def __init__(self, scale=1):
        super().__init__(scale)

    # `params` is a list / generator of torch.Variable
    def has_overflow(self, params):
        """Always report no overflow: a static scale never reacts to one."""
        return False

    # `x` is a torch.Tensor
    @staticmethod
    def _has_inf_or_nan(x):
        """Always return ``False``; the static scaler does not inspect ``x``.

        Declared as a ``staticmethod`` so that it can be called both on the
        class and on an instance, mirroring
        ``DynamicLossScaler._has_inf_or_nan``. Without the decorator an
        instance call would pass ``self`` as ``x`` plus one extra argument
        and fail with ``TypeError``.
        """
        return False
|
||
|
|
||
|
|
||
|
class DynamicLossScaler(LossScalerBase):
    """
    Class that manages dynamic loss scaling. It is recommended to use :class:`DynamicLossScaler`
    indirectly, by supplying ``dynamic_loss_scale=True`` to the constructor of
    :class:`FP16_Optimizer`. However, it's important to understand how :class:`DynamicLossScaler`
    operates, because the default options can be changed using the
    ``dynamic_loss_args`` argument to :class:`FP16_Optimizer`'s constructor.

    Loss scaling is designed to combat the problem of underflowing gradients that can be
    encountered when training fp16 networks for a long time. Dynamic loss scaling begins by
    attempting a very high loss scale. Ironically, this may result in OVERflowing gradients.
    If overflowing gradients are encountered, :class:`DynamicLossScaler` informs
    :class:`FP16_Optimizer` that an overflow has occurred.
    :class:`FP16_Optimizer` then skips the update step for this particular iteration/minibatch,
    and :class:`DynamicLossScaler` adjusts the loss scale to a lower value.
    If a certain number of iterations occur without overflowing gradients detected,
    :class:`DynamicLossScaler` increases the loss scale once more.
    In this way :class:`DynamicLossScaler` attempts to "ride the edge" of
    always using the highest loss scale possible without incurring overflow.

    Args:
        init_scale (float, optional, default=2**32): Initial loss scale attempted by
            :class:`DynamicLossScaler`.
        scale_factor (float, optional, default=2.0): Factor used when adjusting the loss scale.
            If an overflow is encountered, the loss scale is readjusted to
            loss scale/``scale_factor``. If ``scale_window`` consecutive iterations take place
            without an overflow, the loss scale is readjusted to loss_scale*``scale_factor``.
        scale_window (int, optional, default=1000): Number of consecutive iterations without an
            overflow to wait before increasing the loss scale.
        min_scale (float, optional, default=1): Lower bound; the loss scale is never reduced
            below this value.
        delayed_shift (int, optional, default=1): Hysteresis budget. With a value greater
            than 1, overflows first count the budget down and the scale is only reduced once
            the remaining budget has reached 1.
        consecutive_hysteresis (bool, optional, default=False): If ``True`` the hysteresis
            budget is refilled on every iteration without overflow; otherwise it is refilled
            only when the loss scale is increased.
    """

    def __init__(self,
                 init_scale=2 ** 32,
                 scale_factor=2.,
                 scale_window=1000,
                 min_scale=1,
                 delayed_shift=1,
                 consecutive_hysteresis=False):
        super().__init__(init_scale)
        # Number of update_scale() calls seen so far.
        self.cur_iter = 0
        # Iteration index of the most recent overflow (-1: none yet).
        self.last_overflow_iter = -1
        self.scale_factor = scale_factor
        self.scale_window = scale_window
        self.min_scale = min_scale
        self.delayed_shift = delayed_shift
        # Remaining hysteresis budget; starts full.
        self.cur_hysteresis = delayed_shift
        self.consecutive_hysteresis = consecutive_hysteresis

    # `params` is a list / generator of torch.Variable
    def has_overflow_serial(self, params):
        """Return ``True`` if any parameter gradient contains inf or NaN.

        Parameters whose ``grad`` is ``None`` are skipped. The scan stops at
        the first offending gradient.
        """
        for p in params:
            if p.grad is not None and self._has_inf_or_nan(p.grad.data):
                return True

        return False

    # `x` is a torch.Tensor
    @staticmethod
    def _has_inf_or_nan(x):
        """Return ``True`` if the sum of ``x`` is +/-inf or NaN.

        A ``RuntimeError`` whose message contains "value cannot be converted"
        is also treated as an overflow; any other ``RuntimeError`` propagates
        unchanged.
        """
        try:
            # if x is half, the .float() incurs an additional deep copy, but it's necessary if
            # Pytorch's .sum() creates a one-element tensor of the same type as x
            # (which is true for some recent version of pytorch).
            cpu_sum = float(x.float().sum())
            # More efficient version that can be used if .sum() returns a Python scalar
            # cpu_sum = float(x.sum())
        except RuntimeError as instance:
            # We want to check if instance is actually an overflow exception.
            # RuntimeError could come from a different error.
            # If so, we still want the exception to propagate.
            # str(instance) is used rather than instance.args[0] so that an
            # exception with empty or non-string args cannot raise a secondary
            # IndexError/TypeError that would hide the real error.
            if "value cannot be converted" not in str(instance):
                raise
            return True
        else:
            # `cpu_sum != cpu_sum` is only true for NaN.
            return cpu_sum in (float('inf'), -float('inf')) or cpu_sum != cpu_sum

    # `overflow` is boolean indicating whether the gradient overflowed
    def update_scale(self, overflow):
        """Adjust the loss scale after one iteration and advance the counter.

        On overflow the scale is divided by ``scale_factor`` (clamped at
        ``min_scale``) once the hysteresis budget is exhausted; otherwise the
        budget is decremented. Without overflow the scale is multiplied by
        ``scale_factor`` whenever the number of iterations since the last
        overflow is a multiple of ``scale_window``.
        """
        if overflow:
            if self.delayed_shift == 1 or self.cur_hysteresis == 1:
                self.cur_scale = max(
                    self.cur_scale / self.scale_factor, self.min_scale)
            else:
                # Still within the tolerated number of overflows: spend one.
                self.cur_hysteresis -= 1
            self.last_overflow_iter = self.cur_iter
        else:
            if self.consecutive_hysteresis:
                self.cur_hysteresis = self.delayed_shift
            if (self.cur_iter - self.last_overflow_iter) % self.scale_window == 0:
                if not self.consecutive_hysteresis:
                    self.cur_hysteresis = self.delayed_shift
                self.cur_scale *= self.scale_factor
        self.cur_iter += 1
|