mirror of https://github.com/hpcaitech/ColossalAI
244 lines
9.3 KiB
Python
244 lines
9.3 KiB
Python
import math
|
|
import warnings
|
|
|
|
from torch import Tensor
|
|
import torch.nn as nn
|
|
|
|
|
|
def zeros_():
    """Return an initializer that sets every element of the input Tensor to zero.

    Returns:
        callable: ``initializer(tensor, fan_in=None, fan_out=None)``, which fills ``tensor``
        in place via :func:`torch.nn.init.zeros_` and returns the result. ``fan_in`` and
        ``fan_out`` are accepted only so that every initializer shares one call signature;
        they are ignored here.
    """

    def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
        filled = nn.init.zeros_(tensor)
        return filled

    return initializer
|
|
|
|
|
|
def ones_():
    """Return an initializer that sets every element of the input Tensor to one.

    Returns:
        callable: ``initializer(tensor, fan_in=None, fan_out=None)``, which fills ``tensor``
        in place via :func:`torch.nn.init.ones_` and returns the result. ``fan_in`` and
        ``fan_out`` are accepted only for signature compatibility and are ignored.
    """

    def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
        filled = nn.init.ones_(tensor)
        return filled

    return initializer
|
|
|
|
|
|
def uniform_(a: float = 0., b: float = 1.):
    r"""Build an initializer that samples the input Tensor from :math:`\mathcal{U}(a, b)`.

    Args:
        a (float): lower bound of the uniform distribution. Defaults to 0.0.
        b (float): upper bound of the uniform distribution. Defaults to 1.0.

    Returns:
        callable: ``initializer(tensor, fan_in=None, fan_out=None)`` that fills ``tensor`` in
        place using :func:`torch.nn.init.uniform_` and returns it; the fan arguments are ignored.
    """
    low, high = a, b

    def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
        return nn.init.uniform_(tensor, low, high)

    return initializer
|
|
|
|
|
|
def normal_(mean: float = 0., std: float = 1.):
    r"""Build an initializer that samples the input Tensor from a Gaussian.

    Values are drawn from

    .. math::
        \mathcal{N}(\text{mean}, \text{std}^2)

    Args:
        mean (float): mean of the normal distribution. Defaults to 0.0.
        std (float): standard deviation of the normal distribution. Defaults to 1.0.

    Returns:
        callable: ``initializer(tensor, fan_in=None, fan_out=None)`` that fills ``tensor`` in
        place using :func:`torch.nn.init.normal_` and returns it; the fan arguments are ignored.
    """
    mu, sigma = mean, std

    def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
        return nn.init.normal_(tensor, mu, sigma)

    return initializer
|
|
|
|
|
|
def trunc_normal_(mean: float = 0., std: float = 1., a: float = -2., b: float = 2.):
    r"""Build an initializer that samples the input Tensor from a truncated normal distribution.

    Samples follow :math:`\mathcal{N}(\text{mean}, \text{std}^2)` restricted to the interval
    :math:`[a, b]`. Note that ``a`` and ``b`` are absolute values, not multiples of ``std``.
    The underlying sampler works best when :math:`a \leq \text{mean} \leq b`.

    Args:
        mean (float): mean of the normal distribution. Defaults to 0.0.
        std (float): standard deviation of the normal distribution. Defaults to 1.0.
        a (float): lower cutoff value. Defaults to -2.0.
        b (float): upper cutoff value. Defaults to 2.0.

    Returns:
        callable: ``initializer(tensor, fan_in=None, fan_out=None)`` that fills ``tensor`` in
        place via :func:`torch.nn.init.trunc_normal_` and returns it; the fan arguments are ignored.
    """

    def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
        return nn.init.trunc_normal_(tensor, mean=mean, std=std, a=a, b=b)

    return initializer
|
|
|
|
|
|
def kaiming_uniform_(a=0, mode='fan_in', nonlinearity='leaky_relu'):
    r"""Return an initializer implementing He (Kaiming) initialization with a uniform distribution.

    The method is described in `Delving deep into rectifiers: Surpassing human-level
    performance on ImageNet classification` - He, K. et al. (2015). Values are sampled from
    :math:`\mathcal{U}(-\text{bound}, \text{bound})` with

    .. math::
        \text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}}

    Here ``gain`` is ``torch.nn.init.calculate_gain(nonlinearity, a)``. ``fan_mode`` is the
    ``fan_in`` or ``fan_out`` value passed to the returned initializer, as selected by ``mode``.

    Also known as 'He initialization'.

    Args:
        a (float, optional): negative slope of the rectifier used after this layer (only used
            with ``'leaky_relu'``). Defaults to 0.
        mode (str, optional): either ``'fan_in'`` (default) or ``'fan_out'``. ``'fan_in'``
            preserves the magnitude of the variance of the weights in the forward pass;
            ``'fan_out'`` preserves the magnitudes in the backwards pass.
        nonlinearity (str, optional): the non-linear function (`nn.functional` name),
            recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).

    Returns:
        callable: ``initializer(tensor, fan_in=None, fan_out=None)`` that fills ``tensor`` in
        place and returns it. If the tensor has a zero-sized dimension, it is returned untouched
        after a warning. Otherwise the initializer raises ``ValueError`` for an unknown ``mode``
        and ``AssertionError`` when the fan selected by ``mode`` is ``None``.
    """

    # adapted from torch.nn.init
    def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
        # Nothing to fill: warn and hand the tensor back unchanged, as torch.nn.init does.
        if 0 in tensor.shape:
            warnings.warn("Initializing zero-element tensors is a no-op")
            return tensor

        if mode == 'fan_in':
            fan, missing_msg = fan_in, 'Fan_in is not provided.'
        elif mode == 'fan_out':
            fan, missing_msg = fan_out, 'Fan_out is not provided.'
        else:
            raise ValueError(f'Invalid initialization mode \'{mode}\'')
        assert fan is not None, missing_msg

        gain = nn.init.calculate_gain(nonlinearity, a)
        std = gain / math.sqrt(fan)
        # Var[U(-b, b)] = b^2 / 3, so b = sqrt(3) * std yields the target standard deviation.
        bound = math.sqrt(3.) * std
        return nn.init.uniform_(tensor, -bound, bound)

    return initializer
|
|
|
|
|
|
def kaiming_normal_(a=0, mode='fan_in', nonlinearity='leaky_relu'):
    r"""Return an initializer implementing He (Kaiming) initialization with a normal distribution.

    The method is described in `Delving deep into rectifiers: Surpassing human-level
    performance on ImageNet classification` - He, K. et al. (2015). Values are sampled from
    :math:`\mathcal{N}(0, \text{std}^2)` with

    .. math::
        \text{std} = \frac{\text{gain}}{\sqrt{\text{fan\_mode}}}

    Here ``gain`` is ``torch.nn.init.calculate_gain(nonlinearity, a)``. ``fan_mode`` is the
    ``fan_in`` or ``fan_out`` value passed to the returned initializer, as selected by ``mode``.

    Also known as 'He initialization'.

    Args:
        a (float, optional): negative slope of the rectifier used after this layer (only used
            with ``'leaky_relu'``). Defaults to 0.
        mode (str, optional): either ``'fan_in'`` (default) or ``'fan_out'``. ``'fan_in'``
            preserves the magnitude of the variance of the weights in the forward pass;
            ``'fan_out'`` preserves the magnitudes in the backwards pass.
        nonlinearity (str, optional): the non-linear function (`nn.functional` name),
            recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).

    Returns:
        callable: ``initializer(tensor, fan_in=None, fan_out=None)`` that fills ``tensor`` in
        place and returns it. If the tensor has a zero-sized dimension, it is returned untouched
        after a warning. Otherwise the initializer raises ``ValueError`` for an unknown ``mode``
        and ``AssertionError`` when the fan selected by ``mode`` is ``None``.
    """

    # adapted from torch.nn.init
    def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
        # Nothing to fill: warn and hand the tensor back unchanged, as torch.nn.init does.
        if 0 in tensor.shape:
            warnings.warn("Initializing zero-element tensors is a no-op")
            return tensor

        if mode == 'fan_in':
            fan, missing_msg = fan_in, 'Fan_in is not provided.'
        elif mode == 'fan_out':
            fan, missing_msg = fan_out, 'Fan_out is not provided.'
        else:
            raise ValueError(f'Invalid initialization mode \'{mode}\'')
        assert fan is not None, missing_msg

        gain = nn.init.calculate_gain(nonlinearity, a)
        std = gain / math.sqrt(fan)
        return nn.init.normal_(tensor, 0, std)

    return initializer
|
|
|
|
|
|
def xavier_uniform_(a: float = math.sqrt(3.), scale: float = 2., gain: float = 1.):
    r"""Return an initializer implementing Glorot (Xavier) initialization with a uniform distribution.

    The method is described in `Understanding the difficulty of training deep feedforward
    neural networks` - Glorot, X. & Bengio, Y. (2010). Values are sampled from
    :math:`\mathcal{U}(-\text{bound}, \text{bound})` with

    .. math::
        \text{bound} = a \times \text{gain} \times \sqrt{\frac{\text{scale}}{\text{fan\_in} + \text{fan\_out}}}

    With the defaults (``a = sqrt(3)``, ``scale = 2``) this reduces to the classic
    :math:`\text{gain} \times \sqrt{6 / (\text{fan\_in} + \text{fan\_out})}`.

    Also known as 'Glorot initialization'.

    Args:
        a (float, optional): factor converting the standard deviation into the uniform bound.
            Defaults to ``math.sqrt(3.)``.
        scale (float, optional): numerator used when computing the standard deviation. Defaults to 2.0.
        gain (float, optional): an optional scaling factor. Defaults to 1.0.

    Returns:
        callable: ``initializer(tensor, fan_in=None, fan_out=None)`` that fills ``tensor`` in place
        and returns it. ``fan_in`` is required; it raises ``AssertionError`` if it is ``None``.
        When ``fan_out`` is ``None``, only ``fan_in`` is used as the denominator.
    """

    # adapted from torch.nn.init
    def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
        assert fan_in is not None, 'Fan_in is not provided.'

        total_fan = fan_in if fan_out is None else fan_in + fan_out
        std = gain * math.sqrt(scale / float(total_fan))
        bound = a * std
        return nn.init.uniform_(tensor, -bound, bound)

    return initializer
|
|
|
|
|
|
def xavier_normal_(scale: float = 2., gain: float = 1.):
    r"""Return an initializer implementing Glorot (Xavier) initialization with a normal distribution.

    The method is described in `Understanding the difficulty of training deep feedforward
    neural networks` - Glorot, X. & Bengio, Y. (2010). Values are sampled from
    :math:`\mathcal{N}(0, \text{std}^2)` with

    .. math::
        \text{std} = \text{gain} \times \sqrt{\frac{\text{scale}}{\text{fan\_in} + \text{fan\_out}}}

    With the default ``scale = 2`` this is the classic
    :math:`\text{gain} \times \sqrt{2 / (\text{fan\_in} + \text{fan\_out})}`.

    Also known as 'Glorot initialization'.

    Args:
        scale (float, optional): numerator used when computing the standard deviation. Defaults to 2.0.
        gain (float, optional): an optional scaling factor. Defaults to 1.0.

    Returns:
        callable: ``initializer(tensor, fan_in=None, fan_out=None)`` that fills ``tensor`` in place
        and returns it. ``fan_in`` is required; it raises ``AssertionError`` if it is ``None``.
        When ``fan_out`` is ``None``, only ``fan_in`` is used as the denominator.
    """

    # adapted from torch.nn.init
    def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
        assert fan_in is not None, 'Fan_in is not provided.'

        total_fan = fan_in if fan_out is None else fan_in + fan_out
        sigma = gain * math.sqrt(scale / float(total_fan))
        return nn.init.normal_(tensor, 0., sigma)

    return initializer
|
|
|
|
|
|
def lecun_uniform_():
    r"""Return an initializer implementing LeCun initialization with a uniform distribution.

    Values are sampled from :math:`\mathcal{U}(-\text{bound}, \text{bound})` with
    :math:`\text{bound} = \sqrt{3 / \text{fan\_in}}`, so the samples have variance
    :math:`1 / \text{fan\_in}`.

    Returns:
        callable: ``initializer(tensor, fan_in=None, fan_out=None)`` that fills ``tensor`` in place
        and returns it. ``fan_in`` is required; it raises ``AssertionError`` if it is ``None``.
        ``fan_out`` is ignored.
    """

    # adapted from jax.nn.initializers
    def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
        assert fan_in is not None, 'Fan_in is not provided.'

        # Var[U(-r, r)] = r^2 / 3, so r = sqrt(3 * var) gives the target variance 1 / fan_in.
        bound = math.sqrt(3 * (1.0 / fan_in))
        return nn.init.uniform_(tensor, -bound, bound)

    return initializer
|
|
|
|
|
|
def lecun_normal_():
    r"""Return an initializer implementing LeCun initialization with a truncated normal distribution.

    This mirrors ``jax.nn.initializers.lecun_normal``. Samples come from
    :math:`\mathcal{N}(0, \sigma^2)` truncated to :math:`[-2\sigma, 2\sigma]`, where
    :math:`\sigma = \sqrt{1 / \text{fan\_in}} / 0.8796...`.

    The constant 0.8796... is the standard deviation of a standard normal truncated to
    :math:`[-2, 2]`. Dividing by it makes the variance of the *truncated* samples equal to
    :math:`1 / \text{fan\_in}`.

    Returns:
        callable: ``initializer(tensor, fan_in=None, fan_out=None)`` that fills ``tensor`` in place
        and returns it. ``fan_in`` is required; it raises ``AssertionError`` if it is ``None``.
        ``fan_out`` is ignored.
    """
    # adapted from jax.nn.initializers
    # Standard deviation of a standard normal distribution truncated to [-2, 2].
    truncated_std = .87962566103423978

    def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
        assert fan_in is not None, 'Fan_in is not provided.'

        std = math.sqrt(1.0 / fan_in) / truncated_std
        # torch's trunc_normal_ takes ABSOLUTE cutoffs (default +/-2). JAX truncates at +/-2 standard
        # units before scaling, so the cutoffs must be expressed in units of std. Otherwise no
        # truncation happens in practice and the variance is ~1.29x too large.
        return nn.init.trunc_normal_(tensor, mean=0., std=std, a=-2. * std, b=2. * std)

    return initializer