mirror of https://github.com/hpcaitech/ColossalAI
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
141 lines
4.2 KiB
141 lines
4.2 KiB
import math
|
|
import warnings
|
|
|
|
from torch import Tensor
|
|
import torch.nn as nn
|
|
|
|
|
|
def zeros_():
    """Return an initializer that overwrites a tensor with zeros.

    The returned callable takes ``(tensor, fan_in=None, fan_out=None)``, the
    same signature as the other initializer factories in this file.
    ``fan_in`` and ``fan_out`` are accepted but not used here.
    """

    def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
        # Delegate to torch and hand back whatever nn.init.zeros_ returns.
        zeroed = nn.init.zeros_(tensor)
        return zeroed

    return initializer
|
|
|
|
|
|
def ones_():
    """Return an initializer that overwrites a tensor with ones.

    The returned callable takes ``(tensor, fan_in=None, fan_out=None)``;
    ``fan_in`` and ``fan_out`` are accepted but not used here.
    """

    def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
        # Delegate to torch and hand back whatever nn.init.ones_ returns.
        filled = nn.init.ones_(tensor)
        return filled

    return initializer
|
|
|
|
|
|
def uniform_(a: float = 0., b: float = 1.):
    """Return an initializer that fills a tensor with uniform samples.

    Args:
        a: Lower bound forwarded to ``nn.init.uniform_``.
        b: Upper bound forwarded to ``nn.init.uniform_``.

    The returned callable takes ``(tensor, fan_in=None, fan_out=None)``;
    ``fan_in`` and ``fan_out`` are accepted but not used here.
    """

    def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
        # The bounds are captured from the enclosing factory call.
        return nn.init.uniform_(tensor, a=a, b=b)

    return initializer
|
|
|
|
|
|
def normal_(mean: float = 0., std: float = 1.):
    """Return an initializer that fills a tensor with normal samples.

    Args:
        mean: Mean forwarded to ``nn.init.normal_``.
        std: Standard deviation forwarded to ``nn.init.normal_``.

    The returned callable takes ``(tensor, fan_in=None, fan_out=None)``;
    ``fan_in`` and ``fan_out`` are accepted but not used here.
    """

    def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
        # Distribution parameters are captured from the enclosing factory call.
        return nn.init.normal_(tensor, mean=mean, std=std)

    return initializer
|
|
|
|
|
|
def trunc_normal_(mean: float = 0., std: float = 1., a: float = -2., b: float = 2.):
    """Return an initializer that fills a tensor with truncated-normal samples.

    Args:
        mean: Mean forwarded to ``nn.init.trunc_normal_``.
        std: Standard deviation forwarded to ``nn.init.trunc_normal_``.
        a: Lower cut-off forwarded to ``nn.init.trunc_normal_``.
        b: Upper cut-off forwarded to ``nn.init.trunc_normal_``.

    The returned callable takes ``(tensor, fan_in=None, fan_out=None)``;
    ``fan_in`` and ``fan_out`` are accepted but not used here.
    """

    def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
        # All four distribution parameters come from the enclosing factory call.
        return nn.init.trunc_normal_(tensor, mean=mean, std=std, a=a, b=b)

    return initializer
|
|
|
|
|
|
def kaiming_uniform_(a=0, mode='fan_in', nonlinearity='leaky_relu'):
    """Return a Kaiming-style uniform initializer driven by caller-supplied fans.

    Args:
        a: Parameter passed to ``nn.init.calculate_gain`` together with
            ``nonlinearity``.
        mode: ``'fan_in'`` or ``'fan_out'``; selects which fan value the
            returned initializer uses.
        nonlinearity: Name passed to ``nn.init.calculate_gain``.

    The returned callable takes ``(tensor, fan_in=None, fan_out=None)`` and
    samples from ``U(-bound, bound)`` with
    ``bound = sqrt(3) * gain / sqrt(fan)``. A tensor with a zero-sized
    dimension is returned untouched after a warning.

    Raises (when the returned initializer is called):
        ValueError: ``mode`` is neither ``'fan_in'`` nor ``'fan_out'``.
        AssertionError: the fan selected by ``mode`` was not provided.
    """
    # adapted from torch.nn.init

    def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
        if 0 in tensor.shape:
            warnings.warn("Initializing zero-element tensors is a no-op")
            return tensor

        # Reject unknown modes up front, then pick the matching fan value.
        if mode not in ('fan_in', 'fan_out'):
            raise ValueError(f'Invalid initialization mode \'{mode}\'')
        if mode == 'fan_in':
            assert fan_in is not None, 'Fan_in is not provided.'
            fan = fan_in
        else:
            assert fan_out is not None, 'Fan_out is not provided.'
            fan = fan_out

        gain = nn.init.calculate_gain(nonlinearity, a)
        bound = math.sqrt(3.) * (gain / math.sqrt(fan))
        return nn.init.uniform_(tensor, -bound, bound)

    return initializer
|
|
|
|
|
|
def kaiming_normal_(a=0, mode='fan_in', nonlinearity='leaky_relu'):
    """Return a Kaiming-style normal initializer driven by caller-supplied fans.

    Args:
        a: Parameter passed to ``nn.init.calculate_gain`` together with
            ``nonlinearity``.
        mode: ``'fan_in'`` or ``'fan_out'``; selects which fan value the
            returned initializer uses.
        nonlinearity: Name passed to ``nn.init.calculate_gain``.

    The returned callable takes ``(tensor, fan_in=None, fan_out=None)`` and
    samples from ``N(0, std**2)`` with ``std = gain / sqrt(fan)``. A tensor
    with a zero-sized dimension is returned untouched after a warning.

    Raises (when the returned initializer is called):
        ValueError: ``mode`` is neither ``'fan_in'`` nor ``'fan_out'``.
        AssertionError: the fan selected by ``mode`` was not provided.
    """
    # adapted from torch.nn.init

    def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
        if 0 in tensor.shape:
            warnings.warn("Initializing zero-element tensors is a no-op")
            return tensor

        # Reject unknown modes up front, then pick the matching fan value.
        if mode not in ('fan_in', 'fan_out'):
            raise ValueError(f'Invalid initialization mode \'{mode}\'')
        if mode == 'fan_in':
            assert fan_in is not None, 'Fan_in is not provided.'
            fan = fan_in
        else:
            assert fan_out is not None, 'Fan_out is not provided.'
            fan = fan_out

        gain = nn.init.calculate_gain(nonlinearity, a)
        return nn.init.normal_(tensor, 0, gain / math.sqrt(fan))

    return initializer
|
|
|
|
|
|
def xavier_uniform_(a: float = math.sqrt(3.), scale: float = 2., gain: float = 1.):
    """Return a Xavier-style uniform initializer driven by caller-supplied fans.

    Args:
        a: Multiplier turning the computed std into the uniform bound.
        scale: Numerator of the variance term ``scale / fan``.
        gain: Multiplier applied to the std.

    The returned callable takes ``(tensor, fan_in=None, fan_out=None)``. It
    uses ``fan = fan_in + fan_out`` when ``fan_out`` is given, otherwise
    ``fan = fan_in``. It samples from ``U(-bound, bound)`` where
    ``bound = a * gain * sqrt(scale / fan)``.

    Raises (when the returned initializer is called):
        AssertionError: ``fan_in`` was not provided.
    """
    # adapted from torch.nn.init

    def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
        assert fan_in is not None, 'Fan_in is not provided.'

        # fan_out is optional; without it only fan_in enters the denominator.
        fan = fan_in if fan_out is None else fan_in + fan_out

        std = gain * math.sqrt(scale / float(fan))
        bound = a * std
        return nn.init.uniform_(tensor, -bound, bound)

    return initializer
|
|
|
|
|
|
def xavier_normal_(scale: float = 2., gain: float = 1.):
    """Return a Xavier-style normal initializer driven by caller-supplied fans.

    Args:
        scale: Numerator of the variance term ``scale / fan``.
        gain: Multiplier applied to the std.

    The returned callable takes ``(tensor, fan_in=None, fan_out=None)``. It
    uses ``fan = fan_in + fan_out`` when ``fan_out`` is given, otherwise
    ``fan = fan_in``. It samples from ``N(0, std**2)`` where
    ``std = gain * sqrt(scale / fan)``.

    Raises (when the returned initializer is called):
        AssertionError: ``fan_in`` was not provided.
    """
    # adapted from torch.nn.init

    def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
        assert fan_in is not None, 'Fan_in is not provided.'

        # fan_out is optional; without it only fan_in enters the denominator.
        fan = fan_in if fan_out is None else fan_in + fan_out

        return nn.init.normal_(tensor, 0., gain * math.sqrt(scale / float(fan)))

    return initializer
|
|
|
|
|
|
def lecun_uniform_():
    """Return a LeCun-style uniform initializer driven by a caller-supplied fan_in.

    The returned callable takes ``(tensor, fan_in=None, fan_out=None)`` and
    samples from ``U(-bound, bound)`` with ``bound = sqrt(3 * (1 / fan_in))``.
    ``fan_out`` is accepted but not used.

    Raises (when the returned initializer is called):
        AssertionError: ``fan_in`` was not provided.
    """
    # adapted from jax.nn.initializers

    def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
        assert fan_in is not None, 'Fan_in is not provided.'

        # A uniform on [-bound, bound] has variance bound**2 / 3 = 1 / fan_in.
        bound = math.sqrt(3 * (1.0 / fan_in))
        return nn.init.uniform_(tensor, -bound, bound)

    return initializer
|
|
|
|
|
|
def lecun_normal_():
    """Return a LeCun-style truncated-normal initializer driven by fan_in.

    The returned callable takes ``(tensor, fan_in=None, fan_out=None)`` and
    fills ``tensor`` in place with samples from a zero-mean normal truncated
    at two standard deviations, so the result has standard deviation
    ``sqrt(1 / fan_in)``. ``fan_out`` is accepted but not used.

    Raises (when the returned initializer is called):
        AssertionError: ``fan_in`` was not provided.
    """
    # adapted from jax.nn.initializers

    def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
        assert fan_in is not None, 'Fan_in is not provided.'

        # Per jax.nn.initializers, .87962566103423978 is the std of a unit
        # normal truncated to [-2, 2]. Dividing by it widens the pre-truncation
        # std so the truncated samples end up with std sqrt(1 / fan_in).
        std = math.sqrt(1.0 / fan_in) / .87962566103423978

        # nn.init.trunc_normal_ takes absolute cut-offs, and its defaults of
        # +/-2 barely truncate when std is small. The correction above is only
        # valid for a cut at two standard deviations, so the bounds are scaled
        # to match.
        return nn.init.trunc_normal_(tensor, mean=0., std=std, a=-2. * std, b=2. * std)

    return initializer
|