mirror of https://github.com/hpcaitech/ColossalAI
[NFC] polish colossalai/nn/init.py code style (#1292)
parent 556b9b7e1a
commit 2dd4d556fb
@@ -7,6 +7,7 @@ import torch.nn as nn

def zeros_():
    """Return the initializer filling the input Tensor with the scalar zeros"""

    def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
        return nn.init.zeros_(tensor)

@@ -15,6 +16,7 @@ def zeros_():

def ones_():
    """Return the initializer filling the input Tensor with the scalar ones"""

    def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
        return nn.init.ones_(tensor)

@@ -46,6 +48,7 @@ def normal_(mean: float = 0., std: float = 1.):
        mean (float): the mean of the normal distribution. Defaults 0.0.
        std (float): the standard deviation of the normal distribution. Defaults 1.0.
    """

    def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
        return nn.init.normal_(tensor, mean, std)

@@ -66,6 +69,7 @@ def trunc_normal_(mean: float = 0., std: float = 1., a: float = -2., b: float =
        a (float): the minimum cutoff value. Defaults -2.0.
        b (float): the maximum cutoff value. Defaults 2.0.
    """

    def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
        return nn.init.trunc_normal_(tensor, mean, std, a, b)

@@ -93,6 +97,7 @@ def kaiming_uniform_(a=0, mode='fan_in', nonlinearity='leaky_relu'):
        nonlinearity (str, optional): the non-linear function (`nn.functional` name),
            recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).
    """

    # adapted from torch.nn.init
    def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
        if 0 in tensor.shape:
@@ -136,6 +141,7 @@ def kaiming_normal_(a=0, mode='fan_in', nonlinearity='leaky_relu'):
        nonlinearity (str, optional): the non-linear function (`nn.functional` name),
            recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).
    """

    # adapted from torch.nn.init
    def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
        if 0 in tensor.shape:
@@ -175,6 +181,7 @@ def xavier_uniform_(a: float = math.sqrt(3.), scale: float = 2., gain: float = 1
        scale (float, optional): an optional scaling factor used to calculate standard deviation. Defaults 2.0.
        gain (float, optional): an optional scaling factor. Defaults 1.0.
    """

    # adapted from torch.nn.init
    def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
        assert fan_in is not None, 'Fan_in is not provided.'
@@ -206,6 +213,7 @@ def xavier_normal_(scale: float = 2., gain: float = 1.):
        scale (float, optional): an optional scaling factor used to calculate standard deviation. Defaults 2.0.
        gain (float, optional): an optional scaling factor. Defaults 1.0.
    """

    # adapted from torch.nn.init
    def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
        assert fan_in is not None, 'Fan_in is not provided.'
@@ -241,4 +249,4 @@ def lecun_normal_():
        std = math.sqrt(1.0 / fan_in)
        return nn.init.trunc_normal_(tensor, std=std / .87962566103423978)

-    return initializer
+    return initializer
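For reference, every hunk above follows the same pattern: each public function in colossalai/nn/init.py is a factory that returns an initializer(tensor, fan_in, fan_out) closure wrapping the corresponding torch.nn.init routine. Below is a minimal usage sketch, not part of the diff; the import path colossalai.nn.init and the nn.Linear layer shapes are illustrative assumptions.

import torch.nn as nn
from colossalai.nn import init  # assumed import path for the module touched by this commit

linear = nn.Linear(128, 64)
fan_in, fan_out = 128, 64  # fan values matching nn.Linear(128, 64): in_features, out_features

weight_init = init.xavier_uniform_()  # calling the factory returns a closure, not an initialized tensor
bias_init = init.zeros_()

weight_init(linear.weight, fan_in=fan_in, fan_out=fan_out)  # xavier_uniform_ asserts fan_in is provided
bias_init(linear.bias)  # zeros_ ignores fan_in/fan_out

On the last hunk: dividing std by .87962566103423978 in lecun_normal_ compensates for the variance lost when nn.init.trunc_normal_ clips at two standard deviations; that constant is the standard deviation of a unit normal truncated to [-2, 2].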