"""Factories for weight initializers.

Every public function in this module returns a closure with the signature
``initializer(tensor, fan_in=None, fan_out=None)`` that fills ``tensor``
in place and returns it. ``fan_in`` / ``fan_out`` are supplied by the caller;
initializers that do not depend on them simply ignore them.
"""

import math
import warnings
from typing import Optional

import torch.nn as nn
from torch import Tensor


def zeros_():
    """Return the initializer filling the input Tensor with the scalar zeros"""

    def initializer(tensor: Tensor, fan_in: Optional[int] = None, fan_out: Optional[int] = None):
        return nn.init.zeros_(tensor)

    return initializer


def ones_():
    """Return the initializer filling the input Tensor with the scalar ones"""

    def initializer(tensor: Tensor, fan_in: Optional[int] = None, fan_out: Optional[int] = None):
        return nn.init.ones_(tensor)

    return initializer


def uniform_(a: float = 0.0, b: float = 1.0):
    r"""Return the initializer filling the input Tensor with values drawn from the uniform
    distribution :math:`\mathcal{U}(a, b)`.

    Args:
        a (float): the lower bound of the uniform distribution. Defaults to 0.0.
        b (float): the upper bound of the uniform distribution. Defaults to 1.0.
    """

    def initializer(tensor: Tensor, fan_in: Optional[int] = None, fan_out: Optional[int] = None):
        return nn.init.uniform_(tensor, a, b)

    return initializer


def normal_(mean: float = 0.0, std: float = 1.0):
    r"""Return the initializer filling the input Tensor with values drawn from the normal distribution

    .. math::
        \mathcal{N}(\text{mean}, \text{std}^2)

    Args:
        mean (float): the mean of the normal distribution. Defaults to 0.0.
        std (float): the standard deviation of the normal distribution. Defaults to 1.0.
    """

    def initializer(tensor: Tensor, fan_in: Optional[int] = None, fan_out: Optional[int] = None):
        return nn.init.normal_(tensor, mean, std)

    return initializer


def trunc_normal_(mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0):
    r"""Return the initializer filling the input Tensor with values drawn from a truncated
    normal distribution. The values are effectively drawn from the
    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
    with values outside :math:`[a, b]` redrawn until they are within
    the bounds. The method used for generating the random values works
    best when :math:`a \leq \text{mean} \leq b`.

    Args:
        mean (float): the mean of the normal distribution. Defaults to 0.0.
        std (float): the standard deviation of the normal distribution. Defaults to 1.0.
        a (float): the minimum cutoff value. Defaults to -2.0.
        b (float): the maximum cutoff value. Defaults to 2.0.
    """

    def initializer(tensor: Tensor, fan_in: Optional[int] = None, fan_out: Optional[int] = None):
        return nn.init.trunc_normal_(tensor, mean, std, a, b)

    return initializer


def _select_fan(mode: str, fan_in: Optional[int], fan_out: Optional[int]) -> int:
    """Return the fan value selected by ``mode`` (``'fan_in'`` or ``'fan_out'``).

    Raises:
        AssertionError: if the fan selected by ``mode`` was not provided.
        ValueError: if ``mode`` is neither ``'fan_in'`` nor ``'fan_out'``.
    """
    if mode == "fan_in":
        assert fan_in is not None, "Fan_in is not provided."
        return fan_in
    if mode == "fan_out":
        assert fan_out is not None, "Fan_out is not provided."
        return fan_out
    raise ValueError(f"Invalid initialization mode '{mode}'")


def kaiming_uniform_(a: float = 0, mode: str = "fan_in", nonlinearity: str = "leaky_relu"):
    r"""Return the initializer filling the input `Tensor` with values according to the method
    described in `Delving deep into rectifiers: Surpassing human-level
    performance on ImageNet classification` - He, K. et al. (2015), using a
    uniform distribution. The resulting tensor will have values sampled from
    :math:`\mathcal{U}(-\text{bound}, \text{bound})` where

    .. math::
        \text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan_mode}}}

    Also known as 'He initialization'.

    Args:
        a (float): the negative slope of the rectifier used after this layer (only used with ``'leaky_relu'``).
        mode (str, optional): either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
                preserves the magnitude of the variance of the weights in the
                forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
                backwards pass.
        nonlinearity (str, optional): the non-linear function (`nn.functional` name),
                        recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).

    The returned initializer warns and returns the tensor untouched when it has
    zero elements, asserts that the fan selected by ``mode`` was provided, and
    raises ``ValueError`` for an unknown ``mode``.
    """

    # adapted from torch.nn.init
    def initializer(tensor: Tensor, fan_in: Optional[int] = None, fan_out: Optional[int] = None):
        if 0 in tensor.shape:
            warnings.warn("Initializing zero-element tensors is a no-op")
            return tensor

        fan = _select_fan(mode, fan_in, fan_out)
        std = nn.init.calculate_gain(nonlinearity, a) / math.sqrt(fan)
        # A uniform distribution on [-bound, bound] has std bound / sqrt(3).
        bound = math.sqrt(3.0) * std
        return nn.init.uniform_(tensor, -bound, bound)

    return initializer


def kaiming_normal_(a: float = 0, mode: str = "fan_in", nonlinearity: str = "leaky_relu"):
    r"""Return the initializer filling the input `Tensor` with values according to the method
    described in `Delving deep into rectifiers: Surpassing human-level
    performance on ImageNet classification` - He, K. et al. (2015), using a
    normal distribution. The resulting tensor will have values sampled from
    :math:`\mathcal{N}(0, \text{std}^2)` where

    .. math::
        \text{std} = \frac{\text{gain}}{\sqrt{\text{fan_mode}}}

    Also known as 'He initialization'.

    Args:
        a (float): the negative slope of the rectifier used after this layer (only used with ``'leaky_relu'``).
        mode (str, optional): either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
                preserves the magnitude of the variance of the weights in the
                forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
                backwards pass.
        nonlinearity (str, optional): the non-linear function (`nn.functional` name),
                        recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).

    The returned initializer warns and returns the tensor untouched when it has
    zero elements, asserts that the fan selected by ``mode`` was provided, and
    raises ``ValueError`` for an unknown ``mode``.
    """

    # adapted from torch.nn.init
    def initializer(tensor: Tensor, fan_in: Optional[int] = None, fan_out: Optional[int] = None):
        if 0 in tensor.shape:
            warnings.warn("Initializing zero-element tensors is a no-op")
            return tensor

        fan = _select_fan(mode, fan_in, fan_out)
        std = nn.init.calculate_gain(nonlinearity, a) / math.sqrt(fan)
        return nn.init.normal_(tensor, 0, std)

    return initializer


def xavier_uniform_(a: float = math.sqrt(3.0), scale: float = 2.0, gain: float = 1.0):
    r"""Return the initializer filling the input `Tensor` with values according to the method
    described in `Understanding the difficulty of training deep feedforward
    neural networks` - Glorot, X. & Bengio, Y. (2010), using a uniform
    distribution. The resulting tensor will have values sampled from
    :math:`\mathcal{U}(-a, a)` where

    .. math::
        a = \text{gain} \times \sqrt{\frac{6}{\text{fan_in} + \text{fan_out}}}

    Also known as 'Glorot initialization'.

    Args:
        a (float, optional): an optional scaling factor used to calculate uniform
            bounds from standard deviation. Defaults to ``math.sqrt(3.)``.
        scale (float, optional): an optional scaling factor used to calculate standard deviation. Defaults to 2.0.
        gain (float, optional): an optional scaling factor. Defaults to 1.0.

    The returned initializer asserts that ``fan_in`` was provided; when
    ``fan_out`` is omitted only ``fan_in`` contributes to the denominator.
    """

    # adapted from torch.nn.init
    def initializer(tensor: Tensor, fan_in: Optional[int] = None, fan_out: Optional[int] = None):
        assert fan_in is not None, "Fan_in is not provided."

        fan = fan_in
        if fan_out is not None:
            fan += fan_out

        std = gain * math.sqrt(scale / float(fan))
        bound = a * std
        return nn.init.uniform_(tensor, -bound, bound)

    return initializer


def xavier_normal_(scale: float = 2.0, gain: float = 1.0):
    r"""Return the initializer filling the input `Tensor` with values according to the method
    described in `Understanding the difficulty of training deep feedforward
    neural networks` - Glorot, X. & Bengio, Y. (2010), using a normal
    distribution. The resulting tensor will have values sampled from
    :math:`\mathcal{N}(0, \text{std}^2)` where

    .. math::
        \text{std} = \text{gain} \times \sqrt{\frac{2}{\text{fan_in} + \text{fan_out}}}

    Also known as 'Glorot initialization'.

    Args:
        scale (float, optional): an optional scaling factor used to calculate standard deviation. Defaults to 2.0.
        gain (float, optional): an optional scaling factor. Defaults to 1.0.

    The returned initializer asserts that ``fan_in`` was provided; when
    ``fan_out`` is omitted only ``fan_in`` contributes to the denominator.
    """

    # adapted from torch.nn.init
    def initializer(tensor: Tensor, fan_in: Optional[int] = None, fan_out: Optional[int] = None):
        assert fan_in is not None, "Fan_in is not provided."

        fan = fan_in
        if fan_out is not None:
            fan += fan_out

        std = gain * math.sqrt(scale / float(fan))
        return nn.init.normal_(tensor, 0.0, std)

    return initializer


def lecun_uniform_():
    r"""Return the initializer filling the input `Tensor` with values sampled from
    :math:`\mathcal{U}(-\text{bound}, \text{bound})` where

    .. math::
        \text{bound} = \sqrt{\frac{3}{\text{fan_in}}}

    so that the sampled values have variance :math:`1 / \text{fan_in}`.

    The returned initializer asserts that ``fan_in`` was provided and ignores
    ``fan_out``.
    """

    # adapted from jax.nn.initializers
    def initializer(tensor: Tensor, fan_in: Optional[int] = None, fan_out: Optional[int] = None):
        assert fan_in is not None, "Fan_in is not provided."

        var = 1.0 / fan_in
        bound = math.sqrt(3 * var)
        return nn.init.uniform_(tensor, -bound, bound)

    return initializer


def lecun_normal_():
    r"""Return the initializer filling the input `Tensor` with values drawn from a zero-mean
    normal distribution truncated at two standard deviations, rescaled so that
    the resulting values have standard deviation

    .. math::
        \text{std} = \sqrt{\frac{1}{\text{fan_in}}}

    The returned initializer asserts that ``fan_in`` was provided and ignores
    ``fan_out``.
    """

    # adapted from jax.nn.initializers
    def initializer(tensor: Tensor, fan_in: Optional[int] = None, fan_out: Optional[int] = None):
        assert fan_in is not None, "Fan_in is not provided."

        std = math.sqrt(1.0 / fan_in)
        # 0.8796... is the standard deviation of a standard normal truncated to
        # [-2, 2]; dividing by it makes the truncated samples end up with `std`.
        stddev = std / 0.87962566103423978
        # torch's trunc_normal_ takes ABSOLUTE cutoffs, so the +/-2 sigma
        # truncation has to be scaled explicitly. With the default a=-2, b=2
        # almost nothing would be cut off for realistic fan_in and the result
        # would be too wide.
        return nn.init.trunc_normal_(tensor, mean=0.0, std=stddev, a=-2.0 * stddev, b=2.0 * stddev)

    return initializer