mirror of https://github.com/InternLM/InternLM
feat(model): implement uniform_init for tensor. (#252)
* Implement uniform_init for tensor. * Fix functinal calling bugs: normal->uniform. * Format editting: remove unused torch importing.pull/260/head
parent
c92aa06bd8
commit
f79586b0c6
|
@ -3,16 +3,15 @@
|
||||||
|
|
||||||
import math
|
import math
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
|
|
||||||
|
|
||||||
def scaled_init_method_normal(sigma, num_layers):
|
def scaled_init_method_normal(sigma: float = 1.0, num_layers: int = 1):
|
||||||
"""Init method based on N(0, sigma/sqrt(2*num_layers)."""
|
"""Init method based on N(0, sigma/sqrt(2*num_layers)."""
|
||||||
std = sigma / math.sqrt(2.0 * num_layers)
|
std = sigma / math.sqrt(2.0 * num_layers)
|
||||||
|
|
||||||
def init_(tensor):
|
def init_(tensor):
|
||||||
return torch.nn.init.normal_(tensor, mean=0.0, std=std)
|
return nn.init.normal_(tensor, mean=0.0, std=std)
|
||||||
|
|
||||||
return init_
|
return init_
|
||||||
|
|
||||||
|
@ -32,3 +31,33 @@ def normal_(mean: float = 0.0, std: float = 1.0):
|
||||||
return nn.init.normal_(tensor, mean, std)
|
return nn.init.normal_(tensor, mean, std)
|
||||||
|
|
||||||
return initializer
|
return initializer
|
||||||
|
|
||||||
|
|
||||||
|
def scaled_init_method_uniform(sigma: float = 1.0, num_layers: int = 1):
|
||||||
|
"""Init method based on p(x)=Uniform(-a, a) where std(x)=sigma/sqrt(2*num_layers)."""
|
||||||
|
std = sigma / math.sqrt(2.0 * num_layers)
|
||||||
|
a = math.sqrt(3.0 * std)
|
||||||
|
|
||||||
|
def init_(tensor):
|
||||||
|
return nn.init.uniform_(tensor, -a, a)
|
||||||
|
|
||||||
|
return init_
|
||||||
|
|
||||||
|
|
||||||
|
def uniform_(mean: float = 0.0, std: float = 1.0):
|
||||||
|
r"""Return the initializer filling the input Tensor with values drawn from the uniform distribution
|
||||||
|
|
||||||
|
.. math::
|
||||||
|
\mathcal{U}(mean-a, mean+a), where a satisfies \mathcal{U}_{std}=std.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mean (float): the mean of the uniform distribution. Defaults 0.0.
|
||||||
|
std (float): the standard deviation of the uniform distribution. Defaults 1.0.
|
||||||
|
"""
|
||||||
|
|
||||||
|
a = math.sqrt(3.0 * std)
|
||||||
|
|
||||||
|
def initializer(tensor: Tensor):
|
||||||
|
return nn.init.uniform_(tensor, mean - a, mean + a)
|
||||||
|
|
||||||
|
return initializer
|
||||||
|
|
Loading…
Reference in New Issue