mirror of https://github.com/hpcaitech/ColossalAI
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
34 lines
1.2 KiB
34 lines
1.2 KiB
3 years ago
|
import math
|
||
|
|
||
|
from torch import Tensor
|
||
|
from torch.nn import init as init
|
||
|
|
||
|
|
||
|
def init_weight_(tensor: Tensor, fan_in: int, fan_out: int = None, init_method: str = 'torch'):
    """Fill ``tensor`` in place using the scheme named by ``init_method``.

    Supported schemes:

    * ``'torch'``: uniform in ``[-bound, bound]`` where
      ``bound = sqrt(3) * calculate_gain('leaky_relu', sqrt(5)) / sqrt(fan_in)``.
    * ``'jax'``: uniform in ``[-a, a]`` where
      ``a = sqrt(3) * sqrt(2 / (fan_in + fan_out))``; requires ``fan_out``.
    * ``'jax_embed'``: truncated normal with
      ``std = sqrt(1 / fan_in) / 0.87962566103423978``.
    * ``'zero'``: all zeros.

    Args:
        tensor: Tensor to initialize; it is modified in place.
        fan_in: Fan-in used to scale the 'torch', 'jax' and 'jax_embed' schemes.
        fan_out: Fan-out; only read by the 'jax' scheme.
        init_method: One of 'torch', 'jax', 'jax_embed' or 'zero'.

    Raises:
        TypeError: If ``init_method`` is 'jax' and ``fan_out`` is ``None``.
        ValueError: If ``init_method`` is not one of the supported schemes.
    """
    if init_method == 'torch':
        # Same formula as a Kaiming-uniform init with negative slope sqrt(5).
        a = math.sqrt(5)
        nonlinearity = 'leaky_relu'
        std = init.calculate_gain(nonlinearity, a) / math.sqrt(fan_in)
        bound = math.sqrt(3.0) * std
        init.uniform_(tensor, -bound, bound)
    elif init_method == 'jax':
        if fan_out is None:
            # Fail with an actionable message instead of the opaque
            # "unsupported operand type(s) for +" raised by fan_in + None.
            raise TypeError("init_method='jax' requires fan_out to be specified")
        # Glorot/Xavier-style uniform: variance 2 / (fan_in + fan_out).
        std = math.sqrt(2.0 / float(fan_in + fan_out))
        a = math.sqrt(3.0) * std
        init.uniform_(tensor, -a, a)
    elif init_method == 'jax_embed':
        std = math.sqrt(1.0 / fan_in)
        # NOTE(review): .87962566103423978 looks like the std of a unit normal
        # truncated to [-2, 2] (the correction JAX's variance_scaling applies
        # so the truncated samples keep the requested std) -- confirm.
        init.trunc_normal_(tensor, std=std / .87962566103423978)
    elif init_method == 'zero':
        init.zeros_(tensor)
    else:
        # Previously an unknown name fell through and silently left the tensor
        # untouched, which hides typos and yields uninitialized weights.
        raise ValueError(
            f"Unknown init_method {init_method!r}; "
            "expected one of 'torch', 'jax', 'jax_embed', 'zero'")
|
||
|
|
||
|
def init_bias_(tensor: Tensor, fan_in: int, init_method: str = 'torch'):
    """Fill the bias ``tensor`` in place using the scheme named by ``init_method``.

    Supported schemes:

    * ``'torch'``: uniform in ``[-1/sqrt(fan_in), 1/sqrt(fan_in)]``; the bound
      collapses to 0 when ``fan_in`` is not positive.
    * ``'jax'``: normal with ``std=1e-6``.
    * ``'jax_embed'``: truncated normal with ``std=0.02``.
    * ``'zero'``: all zeros.

    Args:
        tensor: Tensor to initialize; it is modified in place.
        fan_in: Fan-in; only read by the 'torch' scheme.
        init_method: One of 'torch', 'jax', 'jax_embed' or 'zero'.

    Raises:
        ValueError: If ``init_method`` is not one of the supported schemes.
    """
    if init_method == 'torch':
        # Guard against division by zero for degenerate (fan_in <= 0) layers.
        bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
        init.uniform_(tensor, -bound, bound)
    elif init_method == 'jax':
        init.normal_(tensor, std=1e-6)
    elif init_method == 'jax_embed':
        init.trunc_normal_(tensor, std=.02)
    elif init_method == 'zero':
        init.zeros_(tensor)
    else:
        # Previously an unknown name fell through and silently left the tensor
        # untouched, which hides typos and yields uninitialized biases.
        raise ValueError(
            f"Unknown init_method {init_method!r}; "
            "expected one of 'torch', 'jax', 'jax_embed', 'zero'")
|