# InternLM/internlm/model/norm.py
# (45 lines, 1.4 KiB, Python)

# Adapted from https://github.com/NVIDIA/apex/blob/master/apex/normalization/fused_layer_norm
import numbers
import torch
from torch.nn import init
from torch.nn.parameter import Parameter
def manual_rms_norm(input, normalized_shape, weight, eps):
    """Apply RMS normalization over the trailing ``len(normalized_shape)`` axes.

    The mean of squares is computed on a float32 copy of ``input``; ``input``
    is then scaled by the reciprocal square root of that mean plus ``eps``.

    Args:
        input: Tensor to normalize.
        normalized_shape: Sequence whose length selects how many trailing
            axes are reduced (only its length is used here).
        weight: Optional scale multiplied into the normalized tensor. When
            ``None`` the normalized tensor is returned unscaled.
        eps: Constant added to the mean of squares before ``rsqrt``.

    Returns:
        The normalized tensor, multiplied by ``weight`` when one is given.
    """
    # Negative axis indices -1, -2, ..., one per entry of normalized_shape.
    num_trailing = len(normalized_shape)
    reduce_axes = tuple(range(-1, -num_trailing - 1, -1))

    # The statistics should always be accumulated in float32.
    squared = input.to(torch.float32).pow(2)
    mean_square = squared.mean(reduce_axes, keepdim=True)
    normed = input * torch.rsqrt(mean_square + eps)

    if weight is None:
        return normed

    # Follow a half-precision weight so the final product stays in its dtype.
    if weight.dtype in (torch.float16, torch.bfloat16):
        normed = normed.to(weight.dtype)
    return weight * normed
class RMSNormTorch(torch.nn.Module):
    """RMS normalization layer with a learnable per-element scale.

    The layer holds a single ``weight`` parameter of shape
    ``normalized_shape`` (initialized to ones) and delegates the computation
    to ``manual_rms_norm``.

    Args:
        normalized_shape: An integer or a sequence of integers giving the
            shape of ``weight``; an integer ``n`` is treated as ``(n,)``.
        eps: Constant forwarded to ``manual_rms_norm``. Defaults to ``1e-5``.
    """

    def __init__(self, normalized_shape, eps=1e-5):
        super().__init__()

        # Accept a bare integer as shorthand for a one-dimensional shape.
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        self.normalized_shape = torch.Size(normalized_shape)
        self.eps = eps
        self.weight = Parameter(torch.empty(*normalized_shape))
        self.reset_parameters()

    def forward(self, input: torch.Tensor):
        """Normalize ``input`` and scale it by ``self.weight``.

        NOTE(review): the trailing dimensions of ``input`` are presumably
        expected to match ``normalized_shape``; this is not checked here.
        """
        return manual_rms_norm(input, self.normalized_shape, self.weight, self.eps)

    def reset_parameters(self) -> None:
        """Fill ``weight`` with ones."""
        init.ones_(self.weight)

    def extra_repr(self) -> str:
        """Return the shape and epsilon shown inside ``repr(module)``."""
        # No trailing separator: nn.Module wraps this text in "ClassName(...)",
        # so a dangling ", " would show up as "eps=1e-05, )".
        return f"{self.normalized_shape}, eps={self.eps}"