mirror of https://github.com/hpcaitech/ColossalAI
Topics: ai, big-model, data-parallelism, deep-learning, distributed-computing, foundation-models, heterogeneous-training, hpc, inference, large-scale, model-parallelism, pipeline-parallelism
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
72 lines
2.4 KiB
72 lines
2.4 KiB
3 years ago
|
"""This code is from NVIDIA apex:
|
||
|
https://github.com/NVIDIA/apex
|
||
|
with some changes. """
|
||
|
|
||
|
import numbers
|
||
2 years ago
|
|
||
3 years ago
|
import torch
|
||
2 years ago
|
from torch.cuda.amp import custom_bwd, custom_fwd
|
||
3 years ago
|
from torch.nn import init
|
||
2 years ago
|
from torch.nn.parameter import Parameter
|
||
3 years ago
|
|
||
10 months ago
|
from colossalai.kernel.kernel_loader import LayerNormLoader
|
||
2 years ago
|
|
||
|
try:
|
||
|
from colossalai._C import layer_norm
|
||
|
except ImportError:
|
||
|
layer_norm = None
|
||
|
|
||
3 years ago
|
|
||
|
class FusedLayerNormAffineFunction(torch.autograd.Function):
    """Autograd function for affine layer normalization backed by a compiled kernel.

    The kernel module is resolved lazily. If the prebuilt ``colossalai._C.layer_norm``
    extension imported at module load, it is used as-is. Otherwise
    ``LayerNormLoader().load()`` is called on the first forward pass, and the result
    is cached in the module-level ``layer_norm`` name for later calls.
    """

    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, input, weight, bias, normalized_shape, eps):
        """Run the fused affine layer-norm forward kernel.

        Under autocast, ``custom_fwd(cast_inputs=torch.float32)`` casts floating-point
        tensor inputs to float32 before this body runs.

        Args:
            ctx: autograd context used to stash state for ``backward``.
            input: tensor to normalize.
            weight: affine scale tensor.
            bias: affine shift tensor.
            normalized_shape: trailing shape that is normalized; passed through to the kernel.
            eps: numerical-stability epsilon; passed through to the kernel.

        Returns:
            The first output of ``layer_norm.forward_affine`` (the normalized tensor).
        """
        ctx.normalized_shape = normalized_shape
        ctx.eps = eps
        # The compiled kernel is handed contiguous buffers.
        input_ = input.contiguous()
        weight_ = weight.contiguous()
        bias_ = bias.contiguous()

        global layer_norm
        if layer_norm is None:
            # Prebuilt extension was unavailable at import time; load/build it on demand.
            layer_norm = LayerNormLoader().load()
        output, mean, invvar = layer_norm.forward_affine(input_, ctx.normalized_shape, weight_, bias_, ctx.eps)
        # Remember the exact kernel module used here so backward is paired with it.
        ctx.layernorm_op = layer_norm
        # mean/invvar are kernel-produced statistics that backward_affine consumes.
        ctx.save_for_backward(input_, weight_, bias_, mean, invvar)

        return output

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
        """Compute gradients via the same kernel module that ran ``forward``.

        Returns:
            A 5-tuple matching ``forward``'s inputs. The gradients are for
            ``input``, ``weight`` and ``bias``. The entries for the non-tensor
            ``normalized_shape`` and ``eps`` are ``None``.
        """
        input_, weight_, bias_, mean, invvar = ctx.saved_tensors
        # Use the handle saved in forward rather than re-reading the module global,
        # so forward and backward always share one kernel module.
        grad_input, grad_weight, grad_bias = ctx.layernorm_op.backward_affine(
            grad_output.contiguous(), mean, invvar, input_, ctx.normalized_shape, weight_, bias_, ctx.eps
        )

        return grad_input, grad_weight, grad_bias, None, None
|
||
|
|
||
|
|
||
|
class MixedFusedLayerNorm(torch.nn.Module):
    """LayerNorm module with learnable scale/shift, computed by ``FusedLayerNormAffineFunction``.

    Args:
        normalized_shape: an int, or an iterable of ints, giving the trailing shape to
            normalize. A bare int is treated as a 1-element shape.
        eps: epsilon forwarded to the fused kernel.
        device: device for the ``weight`` and ``bias`` parameters.
        dtype: dtype for the ``weight`` and ``bias`` parameters.
    """

    def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None):
        super().__init__()

        shape = (normalized_shape,) if isinstance(normalized_shape, numbers.Integral) else normalized_shape
        self.normalized_shape = torch.Size(shape)
        self.eps = eps

        # weight and bias share allocation settings.
        factory_kwargs = {"device": device, "dtype": dtype}
        self.weight = Parameter(torch.empty(*shape, **factory_kwargs))
        self.bias = Parameter(torch.empty(*shape, **factory_kwargs))
        self.reset_parameters()

    def reset_parameters(self):
        """Reset the affine parameters to the identity transform (scale 1, shift 0)."""
        init.ones_(self.weight)
        init.zeros_(self.bias)

    def forward(self, input):
        """Normalize ``input`` through the fused autograd function."""
        fn_args = (input, self.weight, self.bias, self.normalized_shape, self.eps)
        return FusedLayerNormAffineFunction.apply(*fn_args)

    def __repr__(self):
        return f"MixedFusedLayerNorm(normalized_shape={self.normalized_shape}, eps={self.eps})"
|