import math
import warnings
from typing import Optional

import torch.nn as nn
from torch import Tensor


def zeros_():
    """Return the initializer filling the input Tensor with the scalar zeros"""

    def initializer(tensor: Tensor, fan_in: Optional[int] = None, fan_out: Optional[int] = None):
        return nn.init.zeros_(tensor)

    return initializer


def ones_():
    """Return the initializer filling the input Tensor with the scalar ones"""

    def initializer(tensor: Tensor, fan_in: Optional[int] = None, fan_out: Optional[int] = None):
        return nn.init.ones_(tensor)

    return initializer


def uniform_(a: float = 0.0, b: float = 1.0):
    r"""Return the initializer filling the input Tensor with values drawn from the uniform
    distribution :math:`\mathcal{U}(a, b)`.

    Args:
        a (float): the lower bound of the uniform distribution. Defaults to 0.0.
        b (float): the upper bound of the uniform distribution. Defaults to 1.0.
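
    Example (illustrative; the tensor shape is arbitrary)::

        >>> import torch
        >>> init = uniform_(a=-0.1, b=0.1)
        >>> w = init(torch.empty(3, 5))  # fan_in/fan_out are ignored here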
    """

    def initializer(tensor: Tensor, fan_in: Optional[int] = None, fan_out: Optional[int] = None):
        return nn.init.uniform_(tensor, a, b)

    return initializer


def normal_(mean: float = 0.0, std: float = 1.0):
    r"""Return the initializer filling the input Tensor with values drawn from the normal distribution

    .. math::
        \mathcal{N}(\text{mean}, \text{std}^2)

    Args:
        mean (float): the mean of the normal distribution. Defaults to 0.0.
        std (float): the standard deviation of the normal distribution. Defaults to 1.0.
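
    Example (illustrative; the tensor shape is arbitrary)::

        >>> import torch
        >>> init = normal_(mean=0.0, std=0.02)
        >>> w = init(torch.empty(3, 5))  # fan_in/fan_out are ignored here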
    """

    def initializer(tensor: Tensor, fan_in: Optional[int] = None, fan_out: Optional[int] = None):
        return nn.init.normal_(tensor, mean, std)

    return initializer


def trunc_normal_(mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0):
    r"""Return the initializer filling the input Tensor with values drawn from a truncated
    normal distribution. The values are effectively drawn from the
    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
    with values outside :math:`[a, b]` redrawn until they are within
    the bounds. The method used for generating the random values works
    best when :math:`a \leq \text{mean} \leq b`.

    Args:
        mean (float): the mean of the normal distribution. Defaults to 0.0.
        std (float): the standard deviation of the normal distribution. Defaults to 1.0.
        a (float): the minimum cutoff value. Defaults to -2.0.
        b (float): the maximum cutoff value. Defaults to 2.0.
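
    Example (illustrative; the tensor shape is arbitrary)::

        >>> import torch
        >>> init = trunc_normal_(std=0.02)
        >>> w = init(torch.empty(3, 5))  # values outside [a, b] are redrawn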
    """

    def initializer(tensor: Tensor, fan_in: Optional[int] = None, fan_out: Optional[int] = None):
        return nn.init.trunc_normal_(tensor, mean, std, a, b)

    return initializer


def kaiming_uniform_(a=0, mode="fan_in", nonlinearity="leaky_relu"):
    r"""Return the initializer filling the input `Tensor` with values according to the method
    described in `Delving deep into rectifiers: Surpassing human-level
    performance on ImageNet classification` - He, K. et al. (2015), using a
    uniform distribution. The resulting tensor will have values sampled from
    :math:`\mathcal{U}(-\text{bound}, \text{bound})` where

    .. math::
        \text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}}

    Also known as 'He initialization'.

    Args:
        a (float): the negative slope of the rectifier used after this layer (only used with ``'leaky_relu'``).
        mode (str, optional): either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
                preserves the magnitude of the variance of the weights in the
                forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
                backwards pass.
        nonlinearity (str, optional): the non-linear function (`nn.functional` name),
                        recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).
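
    Example (illustrative; the shape and fan value are arbitrary)::

        >>> import torch
        >>> init = kaiming_uniform_(nonlinearity="relu")
        >>> w = init(torch.empty(16, 8), fan_in=8)  # mode='fan_in' requires fan_in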
    """

    # adapted from torch.nn.init
    def initializer(tensor: Tensor, fan_in: Optional[int] = None, fan_out: Optional[int] = None):
        if 0 in tensor.shape:
            warnings.warn("Initializing zero-element tensors is a no-op")
            return tensor

        if mode == "fan_in":
            assert fan_in is not None, "fan_in must be provided when mode='fan_in'."
            fan = fan_in
        elif mode == "fan_out":
            assert fan_out is not None, "fan_out must be provided when mode='fan_out'."
            fan = fan_out
        else:
            raise ValueError(f"Invalid initialization mode '{mode}'")

        # gain / sqrt(fan) is the He standard deviation; sqrt(3) converts it to
        # the half-width of a uniform distribution with the same variance
        std = nn.init.calculate_gain(nonlinearity, a) / math.sqrt(fan)
        bound = math.sqrt(3.0) * std
        return nn.init.uniform_(tensor, -bound, bound)

    return initializer


def kaiming_normal_(a=0, mode="fan_in", nonlinearity="leaky_relu"):
    r"""Return the initializer filling the input `Tensor` with values according to the method
    described in `Delving deep into rectifiers: Surpassing human-level
    performance on ImageNet classification` - He, K. et al. (2015), using a
    normal distribution. The resulting tensor will have values sampled from
    :math:`\mathcal{N}(0, \text{std}^2)` where

    .. math::
        \text{std} = \frac{\text{gain}}{\sqrt{\text{fan\_mode}}}

    Also known as 'He initialization'.

    Args:
        a (float): the negative slope of the rectifier used after this layer (only used with ``'leaky_relu'``).
        mode (str, optional): either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
                preserves the magnitude of the variance of the weights in the
                forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
                backwards pass.
        nonlinearity (str, optional): the non-linear function (`nn.functional` name),
                        recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).
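
    Example (illustrative; the shape and fan value are arbitrary)::

        >>> import torch
        >>> init = kaiming_normal_(mode="fan_out", nonlinearity="relu")
        >>> w = init(torch.empty(16, 8), fan_out=16)  # mode='fan_out' requires fan_out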
    """

    # adapted from torch.nn.init
    def initializer(tensor: Tensor, fan_in: Optional[int] = None, fan_out: Optional[int] = None):
        if 0 in tensor.shape:
            warnings.warn("Initializing zero-element tensors is a no-op")
            return tensor

        if mode == "fan_in":
            assert fan_in is not None, "fan_in must be provided when mode='fan_in'."
            fan = fan_in
        elif mode == "fan_out":
            assert fan_out is not None, "fan_out must be provided when mode='fan_out'."
            fan = fan_out
        else:
            raise ValueError(f"Invalid initialization mode '{mode}'")

        # He standard deviation: gain / sqrt(fan)
        std = nn.init.calculate_gain(nonlinearity, a) / math.sqrt(fan)
        return nn.init.normal_(tensor, 0, std)

    return initializer


def xavier_uniform_(a: float = math.sqrt(3.0), scale: float = 2.0, gain: float = 1.0):
    r"""Return the initializer filling the input `Tensor` with values according to the method
    described in `Understanding the difficulty of training deep feedforward
    neural networks` - Glorot, X. & Bengio, Y. (2010), using a uniform
    distribution. The resulting tensor will have values sampled from
    :math:`\mathcal{U}(-\text{bound}, \text{bound})` where

    .. math::
        \text{bound} = \text{gain} \times \sqrt{\frac{6}{\text{fan\_in} + \text{fan\_out}}}

    Also known as 'Glorot initialization'.

    Args:
        a (float, optional): an optional scaling factor used to calculate the uniform
            bound from the standard deviation. Defaults to ``math.sqrt(3.0)``.
        scale (float, optional): an optional scaling factor used to calculate the standard deviation. Defaults to 2.0.
        gain (float, optional): an optional scaling factor. Defaults to 1.0.
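
    Example (illustrative; the shape and fan values are arbitrary)::

        >>> import torch
        >>> init = xavier_uniform_()
        >>> w = init(torch.empty(16, 8), fan_in=8, fan_out=16)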
    """

    # adapted from torch.nn.init
    def initializer(tensor: Tensor, fan_in: Optional[int] = None, fan_out: Optional[int] = None):
        assert fan_in is not None, "fan_in must be provided."

        # use fan_in + fan_out when both are known; fall back to fan_in alone otherwise
        fan = fan_in
        if fan_out is not None:
            fan += fan_out

        std = gain * math.sqrt(scale / float(fan))
        bound = a * std
        return nn.init.uniform_(tensor, -bound, bound)

    return initializer


def xavier_normal_(scale: float = 2.0, gain: float = 1.0):
    r"""Return the initializer filling the input `Tensor` with values according to the method
    described in `Understanding the difficulty of training deep feedforward
    neural networks` - Glorot, X. & Bengio, Y. (2010), using a normal
    distribution. The resulting tensor will have values sampled from
    :math:`\mathcal{N}(0, \text{std}^2)` where

    .. math::
        \text{std} = \text{gain} \times \sqrt{\frac{2}{\text{fan\_in} + \text{fan\_out}}}

    Also known as 'Glorot initialization'.

    Args:
        scale (float, optional): an optional scaling factor used to calculate the standard deviation. Defaults to 2.0.
        gain (float, optional): an optional scaling factor. Defaults to 1.0.
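
    Example (illustrative; the shape and fan values are arbitrary)::

        >>> import torch
        >>> init = xavier_normal_()
        >>> w = init(torch.empty(16, 8), fan_in=8, fan_out=16)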
    """

    # adapted from torch.nn.init
    def initializer(tensor: Tensor, fan_in: Optional[int] = None, fan_out: Optional[int] = None):
        assert fan_in is not None, "fan_in must be provided."

        # use fan_in + fan_out when both are known; fall back to fan_in alone otherwise
        fan = fan_in
        if fan_out is not None:
            fan += fan_out

        std = gain * math.sqrt(scale / float(fan))

        return nn.init.normal_(tensor, 0.0, std)

    return initializer


def lecun_uniform_():
    r"""Return the initializer filling the input `Tensor` with values according to
    LeCun initialization, using a uniform distribution. The resulting tensor will
    have values sampled from :math:`\mathcal{U}(-\text{bound}, \text{bound})` where

    .. math::
        \text{bound} = \sqrt{\frac{3}{\text{fan\_in}}}
    """

    # adapted from jax.nn.initializers
    def initializer(tensor: Tensor, fan_in: Optional[int] = None, fan_out: Optional[int] = None):
        assert fan_in is not None, "fan_in must be provided."

        # uniform distribution with variance 1 / fan_in
        var = 1.0 / fan_in
        bound = math.sqrt(3 * var)
        return nn.init.uniform_(tensor, -bound, bound)

    return initializer


def lecun_normal_():
    r"""Return the initializer filling the input `Tensor` with values according to
    LeCun initialization, using a truncated normal distribution with standard deviation

    .. math::
        \text{std} = \sqrt{\frac{1}{\text{fan\_in}}}
    """

    # adapted from jax.nn.initializers
    def initializer(tensor: Tensor, fan_in: Optional[int] = None, fan_out: Optional[int] = None):
        assert fan_in is not None, "fan_in must be provided."

        std = math.sqrt(1.0 / fan_in)
        # 0.8796... is the standard deviation of a unit normal truncated to
        # [-2, 2]; dividing by it undoes the variance shrinkage from truncation
        return nn.init.trunc_normal_(tensor, std=std / 0.87962566103423978)

    return initializer
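

# Minimal usage sketch (runs only when this module is executed directly; the
# layer sizes below are illustrative assumptions, not part of the API):
if __name__ == "__main__":
    linear = nn.Linear(8, 16)
    # for nn.Linear, weight has shape (out_features, in_features),
    # so fan_in = 8 and fan_out = 16
    lecun_normal_()(linear.weight, fan_in=8)
    zeros_()(linear.bias)
    print(linear.weight.std().item())  # roughly sqrt(1 / 8) ≈ 0.35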