[NFC] polish colossalai/nn/init.py code style (#1292)

pull/1307/head
Ofey Chan 2022-07-13 10:51:55 +08:00 committed by GitHub
parent 556b9b7e1a
commit 2dd4d556fb
1 changed file with 9 additions and 1 deletion

colossalai/nn/init.py

@@ -7,6 +7,7 @@ import torch.nn as nn
def zeros_():
    """Return the initializer filling the input Tensor with the scalar zeros"""
    def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
        return nn.init.zeros_(tensor)
@@ -15,6 +16,7 @@ def zeros_():
def ones_():
    """Return the initializer filling the input Tensor with the scalar ones"""
    def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
        return nn.init.ones_(tensor)
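Both hunks above show the convention this module uses throughout: every public function is a factory that returns an `initializer(tensor, fan_in=None, fan_out=None)` closure, which fills the tensor in place by delegating to `torch.nn.init`. A minimal usage sketch, assuming the module can be imported from its file path (`colossalai/nn/init.py`):

```python
import torch
from colossalai.nn import init  # assumed import path for colossalai/nn/init.py

weight = torch.empty(128, 64)

# Each factory returns a closure; calling the closure fills the tensor in place.
init.zeros_()(weight)   # all zeros
init.ones_()(weight)    # all ones
```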
@@ -46,6 +48,7 @@ def normal_(mean: float = 0., std: float = 1.):
        mean (float): the mean of the normal distribution. Defaults 0.0.
        std (float): the standard deviation of the normal distribution. Defaults 1.0.
    """
    def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
        return nn.init.normal_(tensor, mean, std)
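For `normal_`, the hunk shows the factory capturing `mean` and `std` while the closure delegates to `nn.init.normal_`; `fan_in` and `fan_out` are accepted for interface uniformity but unused here. A short sketch under the same import assumption:

```python
import torch
from colossalai.nn import init

weight = torch.empty(256, 256)

# Build the initializer once with the captured hyper-parameters, then apply it.
initializer = init.normal_(mean=0.0, std=0.02)
initializer(weight)
```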
@@ -66,6 +69,7 @@ def trunc_normal_(mean: float = 0., std: float = 1., a: float = -2., b: float =
        a (float): the minimum cutoff value. Defaults -2.0.
        b (float): the maximum cutoff value. Defaults 2.0.
    """
    def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
        return nn.init.trunc_normal_(tensor, mean, std, a, b)
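`trunc_normal_` forwards `mean`, `std`, `a`, and `b` straight to `torch.nn.init.trunc_normal_`, so samples outside the `[a, b]` cutoffs are redrawn. A sketch with illustrative values:

```python
import torch
from colossalai.nn import init

bias = torch.empty(768)

# Draw from N(0.0, 0.02); any sample outside [a, b] is redrawn.
init.trunc_normal_(mean=0.0, std=0.02, a=-2.0, b=2.0)(bias)
```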
@@ -93,6 +97,7 @@ def kaiming_uniform_(a=0, mode='fan_in', nonlinearity='leaky_relu'):
        nonlinearity (str, optional): the non-linear function (`nn.functional` name),
            recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).
    """
    # adapted from torch.nn.init
    def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
        if 0 in tensor.shape:
@@ -136,6 +141,7 @@ def kaiming_normal_(a=0, mode='fan_in', nonlinearity='leaky_relu'):
        nonlinearity (str, optional): the non-linear function (`nn.functional` name),
            recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).
    """
    # adapted from torch.nn.init
    def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
        if 0 in tensor.shape:
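Both Kaiming closures start by skipping empty tensors (`if 0 in tensor.shape`). Unlike `torch.nn.init.kaiming_uniform_`, the fan values can be supplied as closure arguments rather than inferred from the tensor, presumably so that a tensor-parallel shard can pass the global fan of the full weight instead of its local shape. A hedged sketch, assuming `fan_in` must be provided when `mode='fan_in'` (mirroring the assert visible in the xavier hunks below):

```python
import torch
from colossalai.nn import init

# Hypothetical local shard of a (1024, 1024) weight split over 4 ranks
# along the output dimension.
weight_shard = torch.empty(1024 // 4, 1024)

# Pass the *global* fan_in explicitly; it is not derived from the shard's shape.
init.kaiming_uniform_(a=0, mode='fan_in', nonlinearity='leaky_relu')(
    weight_shard, fan_in=1024)
```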
@@ -175,6 +181,7 @@ def xavier_uniform_(a: float = math.sqrt(3.), scale: float = 2., gain: float = 1
        scale (float, optional): an optional scaling factor used to calculate standard deviation. Defaults 2.0.
        gain (float, optional): an optional scaling factor. Defaults 1.0.
    """
    # adapted from torch.nn.init
    def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
        assert fan_in is not None, 'Fan_in is not provided.'
@@ -206,6 +213,7 @@ def xavier_normal_(scale: float = 2., gain: float = 1.):
        scale (float, optional): an optional scaling factor used to calculate standard deviation. Defaults 2.0.
        gain (float, optional): an optional scaling factor. Defaults 1.0.
    """
    # adapted from torch.nn.init
    def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
        assert fan_in is not None, 'Fan_in is not provided.'
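The xavier hunks make the contract explicit: `fan_in` must be supplied (the closure asserts on it), and `scale`/`gain` rescale the resulting standard deviation; `fan_out` has no such assert in the visible lines, so it is presumably optional. A sketch under those assumptions:

```python
import torch
from colossalai.nn import init

weight = torch.empty(512, 2048)

# fan_in is mandatory (see the assert above); fan_out is also passed here
# so that both fans can enter the variance computation.
init.xavier_uniform_(scale=2.0, gain=1.0)(weight, fan_in=2048, fan_out=512)
init.xavier_normal_(scale=2.0, gain=1.0)(weight, fan_in=2048, fan_out=512)
```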
@@ -241,4 +249,4 @@ def lecun_normal_():
        std = math.sqrt(1.0 / fan_in)
        return nn.init.trunc_normal_(tensor, std=std / .87962566103423978)
-    return initializer
+    return initializer
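In `lecun_normal_`, the target standard deviation is `sqrt(1 / fan_in)` and the tensor is filled from a truncated normal; the divisor `0.87962566103423978` is the standard deviation of a standard normal truncated to `[-2, 2]`, so dividing by it appears to compensate for the variance removed by truncation. A usage sketch, import path assumed as above:

```python
import torch
from colossalai.nn import init

weight = torch.empty(300, 100)

# fan_in is passed explicitly; the target std is sqrt(1 / 100) = 0.1,
# inflated by 1 / 0.8796... before the truncated-normal draw.
init.lecun_normal_()(weight, fan_in=100)
```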