mirror of https://github.com/hpcaitech/ColossalAI
34 lines
1.2 KiB
Python
34 lines
1.2 KiB
Python
import math
|
|
|
|
from torch import Tensor
|
|
from torch.nn import init as init
|
|
|
|
|
|
def init_weight_(tensor: Tensor, fan_in: int, fan_out: int = None, init_method: str = 'torch'):
    """Initialize a weight tensor in place with the named scheme.

    Args:
        tensor: Tensor that is filled in place.
        fan_in: Fan-in used to scale the 'torch', 'jax' and 'jax_embed' schemes.
        fan_out: Fan-out; only read by the 'jax' scheme, where it is required.
        init_method: One of 'torch', 'jax', 'jax_embed' or 'zero'.

    Raises:
        TypeError: If ``init_method`` is 'jax' and ``fan_out`` is None.
        ValueError: If ``init_method`` is not one of the supported names.
    """
    if init_method == 'torch':
        # Uniform in [-bound, bound], bound = sqrt(3) * gain / sqrt(fan_in),
        # where gain is the leaky-ReLU gain for a negative slope of sqrt(5).
        a = math.sqrt(5)
        nonlinearity = 'leaky_relu'
        std = init.calculate_gain(nonlinearity, a) / math.sqrt(fan_in)
        bound = math.sqrt(3.0) * std
        init.uniform_(tensor, -bound, bound)
    elif init_method == 'jax':
        # Fail with a clear message instead of an opaque "int + NoneType"
        # error when the caller relies on the fan_out default.
        if fan_out is None:
            raise TypeError("init_method='jax' requires fan_out to be given, got None")
        # Uniform in [-a, a] with std = sqrt(2 / (fan_in + fan_out)).
        std = math.sqrt(2.0 / float(fan_in + fan_out))
        a = math.sqrt(3.0) * std
        init.uniform_(tensor, -a, a)
    elif init_method == 'jax_embed':
        std = math.sqrt(1.0 / fan_in)
        # NOTE(review): the divisor looks like the std of a unit normal
        # truncated to [-2, 2] (as in JAX's variance scaling) - confirm.
        init.trunc_normal_(tensor, std=std / .87962566103423978)
    elif init_method == 'zero':
        init.zeros_(tensor)
    else:
        # Previously an unrecognised name fell through silently and left the
        # tensor with whatever values it already held.
        raise ValueError(
            f"Unknown init_method {init_method!r}; expected 'torch', 'jax', 'jax_embed' or 'zero'")
|
|
|
|
def init_bias_(tensor: Tensor, fan_in: int, init_method: str = 'torch'):
    """Initialize a bias tensor in place with the named scheme.

    Args:
        tensor: Tensor that is filled in place.
        fan_in: Fan-in; only read by the 'torch' scheme.
        init_method: One of 'torch', 'jax', 'jax_embed' or 'zero'.

    Raises:
        ValueError: If ``init_method`` is not one of the supported names.
    """
    if init_method == 'torch':
        # Uniform in [-1/sqrt(fan_in), 1/sqrt(fan_in)]; a non-positive fan_in
        # collapses the range to zero rather than dividing by zero.
        bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
        init.uniform_(tensor, -bound, bound)
    elif init_method == 'jax':
        init.normal_(tensor, std=1e-6)
    elif init_method == 'jax_embed':
        init.trunc_normal_(tensor, std=.02)
    elif init_method == 'zero':
        init.zeros_(tensor)
    else:
        # Previously an unrecognised name fell through silently and left the
        # tensor with whatever values it already held.
        raise ValueError(
            f"Unknown init_method {init_method!r}; expected 'torch', 'jax', 'jax_embed' or 'zero'")
|