ColossalAI/colossalai/shardformer/layer/layernorm.py

#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from typing import List, Union

import torch
import torch.nn as nn
from torch.distributed import ProcessGroup

from colossalai.kernel import LayerNorm
from colossalai.nn import init as init

from .parallel_module import ParallelModule

__all__ = ['LayerNorm1D']

# Optional fused kernel from apex.contrib; it is only used for the hidden sizes
# listed in LayerNorm1D._fast_ln_supported_sizes below.
Fast_LN = None
try:
    from apex.contrib.layer_norm.layer_norm import FastLayerNorm
    Fast_LN = FastLayerNorm
except ImportError:
    pass


class LayerNorm1D(ParallelModule):
r"""
Layer Normalization for colossalai
Args:
normalized_shape (int): input shape from an expected input of size.
:math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1]
\times \ldots \times \text{normalized_shape}[-1]]`
If a single integer is used, it is treated as a singleton list, and this module will
normalize over the last dimension which is expected to be of that specific size.
eps (float): a value added to the denominator for numerical stability, defaults to 1e-05.
bias (bool, optional): Whether to add a bias, defaults to ``True``.
dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
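
    Example:
        An illustrative sketch only; which kernel is actually selected depends on the installed
        apex extensions, and the selected kernel is exposed as ``self.norm``::

            >>> ln = LayerNorm1D(1024, eps=1e-5, dtype=torch.float16).to('cuda')
            >>> x = torch.randn(4, 1024, dtype=torch.float16, device='cuda')
            >>> y = ln.norm(x)    # apply the wrapped kernel directly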
"""

    _fast_ln_supported_sizes = [
        1024, 1536, 2048, 2304, 3072, 3840, 4096, 5120, 6144, 8192, 10240, 12288, 12800, 15360, 16384, 18432, 20480,
        24576, 25600, 30720, 32768, 40960, 49152, 65536
    ]

    def __init__(self,
                 normalized_shape: int,
                 eps: float = 1e-05,
                 bias: bool = True,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        super().__init__()
        # NOTE: `bias` is kept for API compatibility with nn.LayerNorm; the fused
        # kernels below create their own weight and bias parameters.
        if Fast_LN is not None and normalized_shape in self._fast_ln_supported_sizes:
            # apex FastLayerNorm only supports a fixed set of hidden sizes
            norm = Fast_LN(normalized_shape, eps=eps).to(dtype)
        else:
            norm = None
            try:
                # prefer the apex fused kernel when available
                from apex.normalization import FusedLayerNorm
                norm = FusedLayerNorm(normalized_shape, eps=eps).to(dtype)
            except ImportError:
                # fall back to the colossalai LayerNorm kernel
                norm = LayerNorm(normalized_shape, eps=eps, device=device, dtype=dtype)
        self.norm = norm

    @staticmethod
    def from_native_module(module: nn.LayerNorm, process_group: Union[ProcessGroup, List[ProcessGroup]], *args,
                           **kwargs) -> ParallelModule:
        r"""
        Convert a native PyTorch :class:`~torch.nn.LayerNorm` module to a colossalai layer norm module.

        Note that the unwrapped kernel module (``LayerNorm1D(...).norm``) is returned rather than
        the ``LayerNorm1D`` wrapper itself. A usage sketch is given at the end of this file.
        """
        normalized_shape = module.normalized_shape
        eps = module.eps
        bias = module.bias is not None
        dtype = module.weight.dtype
        device = module.weight.device

        # ensure only one process group is passed
        if isinstance(process_group, (list, tuple)):
            assert len(process_group) == 1, \
                f'Expected only one process group, got {len(process_group)}.'
            process_group = process_group[0]

        # create the layer norm and unwrap the underlying kernel module
        layer_norm = LayerNorm1D(normalized_shape, eps=eps, bias=bias, device=device, dtype=dtype).norm

        with torch.no_grad():
            # copy weight and bias from the native module
            layer_norm.weight.copy_(module.weight)
            if bias:
                layer_norm.bias.copy_(module.bias)

        return layer_norm
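

# ---------------------------------------------------------------------------
# Minimal usage sketch for `from_native_module` (illustrative only, not part of
# the upstream module). It assumes a CUDA device and that apex (or the
# colossalai fused kernel) is installed; `process_group` may be None here
# because `from_native_module` only validates it.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # build a native LayerNorm and convert it; the converted module is the
    # unwrapped fused kernel, so it can be called directly
    native_ln = nn.LayerNorm(1024, eps=1e-5).half().cuda()
    fused_ln = LayerNorm1D.from_native_module(native_ln, process_group=None).to("cuda")

    x = torch.randn(2, 16, 1024, dtype=torch.float16, device="cuda")
    # the converted module should match the native LayerNorm up to fp16 tolerance
    torch.testing.assert_close(fused_ln(x), native_ln(x), rtol=1e-2, atol=1e-2)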