mirror of https://github.com/hpcaitech/ColossalAI
Frank Lee
1 year ago
4 changed files with 51 additions and 77 deletions
@@ -1,11 +1,11 @@
 from .dropout import Dropout1D
 from .embedding import Embedding1D, VocabParallelEmbedding1D
-from .layernorm import LayerNorm1D
+from .layernorm import FusedLayerNorm
 from .linear import Linear1D_Col, Linear1D_Row
 from .linear_conv import LinearConv1D_Col, LinearConv1D_Row
 from .loss import cross_entropy_1d

 __all__ = [
     "Embedding1D", "VocabParallelEmbedding1D", "Linear1D_Col", "Linear1D_Row", "LinearConv1D_Col", "LinearConv1D_Row",
-    "Dropout1D", "cross_entropy_1d", 'LayerNorm1D'
+    "Dropout1D", "cross_entropy_1d", 'FusedLayerNorm'
 ]
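For orientation, here is a minimal sketch of how calling code would pick up the renamed export after this hunk. The `colossalai.shardformer.layer` package path is an assumption inferred from the shard-parallel layer names in the import list; it is not shown in this diff.

```python
# Hypothetical import path (assumption): the package that owns this __init__.py.
# Before this change the symbol was exported as LayerNorm1D:
#     from colossalai.shardformer.layer import LayerNorm1D
# After this change the fused wrapper is exported instead:
from colossalai.shardformer.layer import FusedLayerNorm
```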
@@ -1,89 +1,64 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-

-from typing import List, Union
-
 import torch
 import torch.nn as nn
-from torch.distributed import ProcessGroup
-
-from colossalai.kernel import LayerNorm
-from colossalai.nn import init as init
-
-from .parallel_module import ParallelModule

-__all__ = ['LayerNorm1D']
+__all__ = ['FusedLayerNorm']

-Fast_LN = None
-try:
-    from apex.contrib.layer_norm.layer_norm import FastLayerNorm
-    Fast_LN = FastLayerNorm
-except ImportError:
-    pass
+FAST_LAYERNORM_SUPPORTED_SIZE = [
+    1024, 1536, 2048, 2304, 3072, 3840, 4096, 5120, 6144, 8192, 10240, 12288, 12800, 15360, 16384, 18432, 20480, 24576,
+    25600, 30720, 32768, 40960, 49152, 65536
+]


-class LayerNorm1D(ParallelModule):
+class FusedLayerNorm():
     r"""
-    Layer Normalization for colossalai
-
-    Args:
-        normalized_shape (int): input shape from an expected input of size.
-            :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1]
-                \times \ldots \times \text{normalized_shape}[-1]]`
-            If a single integer is used, it is treated as a singleton list, and this module will
-            normalize over the last dimension which is expected to be of that specific size.
-        eps (float): a value added to the denominator for numerical stability, defaults to 1e-05.
-        bias (bool, optional): Whether to add a bias, defaults to ``True``.
-        dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
+    This is a wrapper around the apex fused layernorm implementation. It is meant to be used only with the from_native_module interface.
     """

-    _fast_ln_supported_sizes = [
-        1024, 1536, 2048, 2304, 3072, 3840, 4096, 5120, 6144, 8192, 10240, 12288, 12800, 15360, 16384, 18432, 20480,
-        24576, 25600, 30720, 32768, 40960, 49152, 65536
-    ]
-
-    def __init__(self,
-                 normalized_shape: int,
-                 eps: int = 1e-05,
-                 bias: bool = True,
-                 dtype: torch.dtype = None,
-                 device: torch.device = None):
-        super().__init__()
-        if Fast_LN is not None and normalized_shape in self._fast_ln_supported_sizes:
-            norm = Fast_LN(normalized_shape, eps=eps).to(dtype)
-        else:
-            norm = None
-            try:
-                from apex.normalization import FusedLayerNorm
-                norm = FusedLayerNorm(normalized_shape, eps=eps).to(dtype)
-            except ImportError:
-                norm = LayerNorm(normalized_shape, eps=eps, device=device, dtype=dtype)
-        self.norm = norm
+    def __init__(self) -> None:
+        raise NotImplementedError(
+            'FusedLayerNorm is not implemented as a physical class. '
+            'It is meant to be used only with the from_native_module interface to wrap the fused layernorm implementation provided by apex.'
+        )

     @staticmethod
-    def from_native_module(module: nn.LayerNorm, process_group: Union[ProcessGroup, List[ProcessGroup]], *args,
-                           **kwargs) -> ParallelModule:
+    def from_native_module(module: nn.LayerNorm, *args, **kwargs) -> nn.Module:
         r"""
         Convert a native pytorch layer norm module to colossalai layer norm module
         """
+        # check if apex is installed
+        try:
+            import apex
+        except ImportError:
+            raise ImportError(
+                'Please install apex from source (https://github.com/NVIDIA/apex) to use the fused layernorm kernel')
+
         # get the attributes of the module
         normalized_shape = module.normalized_shape
         eps = module.eps
-        bias = module.bias is not None
+        elementwise_affine = module.elementwise_affine
         dtype = module.weight.dtype
         device = module.weight.device

-        # ensure only one process group is passed
-        if isinstance(process_group, (list, tuple)):
-            assert len(process_group) == 1, \
-                f'Expected only one process group, got {len(process_group)}.'
-            process_group = process_group[0]
+        # pick the suitable layernorm implementation
+        use_fast_ln = normalized_shape in FAST_LAYERNORM_SUPPORTED_SIZE
+
+        if use_fast_ln:
+            try:
+                from apex.contrib.layer_norm.layer_norm import FastLayerNorm as ApexFusedLayerNorm
+            except ImportError:
+                # fall back to the normal fused layernorm if FastLayerNorm is not built
+                from apex.normalization import FusedLayerNorm as ApexFusedLayerNorm
+        else:
+            from apex.normalization import FusedLayerNorm as ApexFusedLayerNorm

         # create layer norm
-        layer_norm = LayerNorm1D(normalized_shape, eps=eps, bias=bias, device=device, dtype=dtype).norm
+        layernorm = ApexFusedLayerNorm(normalized_shape, eps=eps,
+                                       elementwise_affine=elementwise_affine).to(dtype).to(device)

         with torch.no_grad():
             # copy weight and bias
-            layer_norm.weight.copy_(module.weight)
-            if bias:
-                layer_norm.bias.copy_(module.bias)
-        return layer_norm
+            layernorm.weight.copy_(module.weight)
+            layernorm.bias.copy_(module.bias)
+        return layernorm
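Because the new `FusedLayerNorm` cannot be instantiated directly, `from_native_module` is the intended entry point. Below is a minimal usage sketch, not part of the diff: it assumes apex is installed with the fast layer-norm extension, a CUDA device is available, and that the class is importable as `colossalai.shardformer.layer.FusedLayerNorm` (the file path is not shown in this view).

```python
import torch
import torch.nn as nn

# Hypothetical import path (assumption); the diff does not show the file location.
from colossalai.shardformer.layer import FusedLayerNorm

# 1024 is listed in FAST_LAYERNORM_SUPPORTED_SIZE, so the FastLayerNorm kernel
# is chosen when apex was built with the layer_norm contrib extension.
native_ln = nn.LayerNorm(1024, eps=1e-5).cuda().half()

# Wrap the native module: weight and bias are copied under torch.no_grad(),
# and the returned module lives on the same device and dtype as the original.
fused_ln = FusedLayerNorm.from_native_module(native_ln)

x = torch.randn(4, 1024, device='cuda', dtype=torch.float16)
print(torch.allclose(fused_ln(x), native_ln(x), atol=1e-3))  # expected: True

# Direct construction is intentionally blocked:
# FusedLayerNorm(1024)  -> raises NotImplementedError
```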