mirror of https://github.com/hpcaitech/ColossalAI
46 lines
1.6 KiB
Python
46 lines
1.6 KiB
Python
import torch
|
|
|
|
try:
|
|
import fused_mix_prec_layer_norm_cuda
|
|
except:
|
|
fused_mix_prec_layer_norm_cuda = None
|
|
|
|
|
|
class FusedLayerNormAffineFunction1D(torch.autograd.Function):
|
|
r"""
|
|
Layernorm
|
|
|
|
:param input: input maxtrix
|
|
:param weight: weight matrix
|
|
:param bias: bias matrix
|
|
:param normalized_shape: input shape from an expected input
|
|
of size. :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] \times \ldots \times \text{normalized_shape}[-1]]`
|
|
If a single integer is used, it is treated as a singleton list, and this module will
|
|
normalize over the last dimension which is expected to be of that specific size.
|
|
:param eps: a value added to the denominator for numerical stability
|
|
"""
|
|
|
|
@staticmethod
|
|
def forward(ctx, input, weight, bias, normalized_shape, eps):
|
|
ctx.normalized_shape = normalized_shape
|
|
ctx.eps = eps
|
|
input_ = input.contiguous()
|
|
weight_ = weight.contiguous()
|
|
bias_ = bias.contiguous()
|
|
output, mean, invvar = fused_mix_prec_layer_norm_cuda.forward_affine(
|
|
input_, ctx.normalized_shape, weight_, bias_, ctx.eps)
|
|
ctx.save_for_backward(input_, weight_, bias_, mean, invvar)
|
|
return output
|
|
|
|
|
|
@staticmethod
|
|
def backward(ctx, grad_output):
|
|
input_, weight_, bias_, mean, invvar = ctx.saved_tensors
|
|
grad_input = grad_weight = grad_bias = None
|
|
grad_input, grad_weight, grad_bias \
|
|
= fused_mix_prec_layer_norm_cuda.backward_affine(
|
|
grad_output.contiguous(), mean, invvar,
|
|
input_, ctx.normalized_shape,
|
|
weight_, bias_, ctx.eps)
|
|
|
|
return grad_input, grad_weight, grad_bias, None, None |