ColossalAI/colossalai/kernel/cuda_native/layer_norm.py

79 lines
2.5 KiB
Python
Raw Normal View History

2021-12-21 04:19:52 +00:00
"""This code is from NVIDIA apex:
https://github.com/NVIDIA/apex
with some changes. """
import numbers
2021-12-21 04:19:52 +00:00
import torch
from torch.cuda.amp import custom_bwd, custom_fwd
2021-12-21 04:19:52 +00:00
from torch.nn import init
from torch.nn.parameter import Parameter
2021-12-21 04:19:52 +00:00
from colossalai.kernel.op_builder.layernorm import LayerNormBuilder
try:
from colossalai._C import layer_norm
except ImportError:
layer_norm = None
2021-12-21 04:19:52 +00:00
class FusedLayerNormAffineFunction(torch.autograd.Function):
    """Autograd function for affine layer norm backed by a fused extension op.

    The extension module is resolved lazily. If the precompiled
    ``colossalai._C.layer_norm`` import failed (leaving the module-level
    ``layer_norm`` as ``None``), it is loaded through ``LayerNormBuilder`` on
    the first forward call and cached in that module-level global.
    """

    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, input, weight, bias, normalized_shape, eps):
        """Run the fused affine layer-norm forward kernel.

        Args:
            ctx: autograd context; stores the shape, eps, op handle and the
                tensors needed by ``backward``.
            input: tensor to normalize.
            weight: affine scale parameter.
            bias: affine shift parameter.
            normalized_shape: shape argument forwarded to the kernel.
            eps: epsilon forwarded to the kernel.

        Returns:
            The ``output`` tensor returned by ``forward_affine``.
        """
        ctx.normalized_shape = normalized_shape
        ctx.eps = eps
        # The kernel receives contiguous tensors; the same contiguous copies
        # are saved so backward sees exactly what forward used.
        input_ = input.contiguous()
        weight_ = weight.contiguous()
        bias_ = bias.contiguous()

        global layer_norm
        if layer_norm is None:
            # Fallback path when the prebuilt extension could not be imported.
            layer_norm = LayerNormBuilder().load()
        output, mean, invvar = layer_norm.forward_affine(input_, ctx.normalized_shape, weight_, bias_, ctx.eps)
        # Keep a handle to the exact op used here so backward does not have to
        # rely on the module-level global.
        ctx.layernorm_op = layer_norm
        ctx.save_for_backward(input_, weight_, bias_, mean, invvar)
        return output

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
        """Run the fused affine layer-norm backward kernel.

        Args:
            ctx: autograd context populated by ``forward``.
            grad_output: gradient w.r.t. the forward output.

        Returns:
            Gradients for ``(input, weight, bias, normalized_shape, eps)``.
            The last two are ``None`` because they are not tensors.
        """
        input_, weight_, bias_, mean, invvar = ctx.saved_tensors
        # Use the op captured in forward rather than re-reading module state.
        grad_input, grad_weight, grad_bias = ctx.layernorm_op.backward_affine(
            grad_output.contiguous(), mean, invvar,
            input_, ctx.normalized_shape,
            weight_, bias_, ctx.eps)
        return grad_input, grad_weight, grad_bias, None, None
class MixedFusedLayerNorm(torch.nn.Module):
    """Layer-norm module with learnable affine weight and bias.

    The computation is delegated to ``FusedLayerNormAffineFunction``.

    Args:
        normalized_shape: an int, or an iterable of ints, describing the shape
            of the affine parameters and the shape passed to the kernel.
        eps: epsilon forwarded to the kernel.
        device: optional device for the parameter tensors.
        dtype: optional dtype for the parameter tensors.
    """

    def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None):
        super().__init__()
        # A bare integer is promoted to a 1-tuple so it can be used as a shape.
        shape = (normalized_shape,) if isinstance(normalized_shape, numbers.Integral) else normalized_shape
        self.normalized_shape = torch.Size(shape)
        self.eps = eps

        tensor_opts = {'device': device, 'dtype': dtype}
        # Registration order (weight, then bias) is kept stable.
        self.weight = Parameter(torch.empty(*shape, **tensor_opts))
        self.bias = Parameter(torch.empty(*shape, **tensor_opts))
        self.reset_parameters()

    def reset_parameters(self):
        """Reset the affine parameters to identity: weight = 1, bias = 0."""
        for param, initializer in ((self.weight, init.ones_), (self.bias, init.zeros_)):
            initializer(param)

    def forward(self, input):
        """Apply the fused affine layer norm to ``input``."""
        fn = FusedLayerNormAffineFunction
        return fn.apply(input, self.weight, self.bias, self.normalized_shape, self.eps)

    def __repr__(self):
        shape, eps = self.normalized_shape, self.eps
        return f'MixedFusedLayerNorm(normalized_shape={shape}, eps={eps})'