mirror of https://github.com/InternLM/InternLM
47 lines
1.5 KiB
Python
47 lines
1.5 KiB
Python
# adapted from https://github.com/NVIDIA/apex/blob/master/apex/normalization/fused_layer_norm
|
|
|
|
import numbers
|
|
|
|
import torch
|
|
from torch.nn import init
|
|
from torch.nn.parameter import Parameter
|
|
|
|
|
|
def manual_rms_norm(my_input, normalized_shape, weight, eps):
    """Apply RMS normalization over the trailing dimensions of ``my_input``.

    The mean of squares is taken over the last ``len(normalized_shape)``
    dimensions of a float32 copy of the input, and the input is multiplied
    by the reciprocal square root of that mean plus ``eps``.

    Args:
        my_input: Tensor to normalize.
        normalized_shape: Sequence whose length selects how many trailing
            dimensions are reduced.
        weight: Optional elementwise scale; ``None`` skips scaling.
        eps: Constant added to the mean of squares before ``rsqrt``.

    Returns:
        The normalized tensor, multiplied by ``weight`` when one is given.
        If ``weight`` is float16 or bfloat16, the normalized tensor is cast
        to that dtype before the multiplication.
    """
    # Reduce over the last len(normalized_shape) axes: (-1, -2, ...).
    trailing_axes = tuple(range(-1, -len(normalized_shape) - 1, -1))

    # The statistic is always computed in float32, whatever the input dtype.
    mean_square = my_input.float().pow(2).mean(trailing_axes, keepdim=True)
    inv_rms = torch.rsqrt(mean_square + eps)
    normalized = my_input * inv_rms

    if weight is None:
        return normalized

    # Match a half-precision weight's dtype before applying the scale.
    if weight.dtype in (torch.float16, torch.bfloat16):
        normalized = normalized.to(weight.dtype)

    return weight * normalized
|
|
|
|
|
|
class RMSNormTorch(torch.nn.Module):
    """A custom PyTorch module for RMS normalization.

    Holds a learnable ``weight`` of shape ``normalized_shape`` (initialized
    to ones) and applies ``manual_rms_norm`` with it in the forward pass.

    Args:
        normalized_shape: An integer, or a sequence of integers, giving the
            shape of the trailing dimensions to normalize over.
        eps: Constant forwarded to ``manual_rms_norm``. Defaults to 1e-5.
    """

    def __init__(self, normalized_shape, eps=1e-5):
        super().__init__()

        # A bare integer is shorthand for a one-dimensional shape.
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        self.normalized_shape = torch.Size(normalized_shape)
        self.eps = eps
        # Allocated uninitialized; reset_parameters() fills it with ones.
        self.weight = Parameter(torch.empty(*normalized_shape))
        self.reset_parameters()

    def forward(self, _input: torch.Tensor):
        """Return ``_input`` normalized by ``manual_rms_norm`` using this module's weight and eps."""
        return manual_rms_norm(_input, self.normalized_shape, self.weight, self.eps)

    def reset_parameters(self):
        """Reset ``weight`` to all ones."""
        init.ones_(self.weight)

    def extra_repr(self):
        """Return the shape and eps shown inside the module's repr."""
        # No trailing separator: the old format string ended in ", ", which
        # printed as "RMSNormTorch(torch.Size([4]), eps=1e-05, )".
        return "{normalized_shape}, eps={eps}".format(**self.__dict__)
|