diff --git a/colossalai/shardformer/layer/__init__.py b/colossalai/shardformer/layer/__init__.py
index 3ce0ef68a..3ece25831 100644
--- a/colossalai/shardformer/layer/__init__.py
+++ b/colossalai/shardformer/layer/__init__.py
@@ -1,11 +1,11 @@
 from .dropout import Dropout1D
 from .embedding import Embedding1D, VocabParallelEmbedding1D
-from .layernorm import LayerNorm1D
+from .layernorm import FusedLayerNorm
 from .linear import Linear1D_Col, Linear1D_Row
 from .linear_conv import LinearConv1D_Col, LinearConv1D_Row
 from .loss import cross_entropy_1d
 
 __all__ = [
     "Embedding1D", "VocabParallelEmbedding1D", "Linear1D_Col", "Linear1D_Row", "LinearConv1D_Col", "LinearConv1D_Row",
-    "Dropout1D", "cross_entropy_1d", 'LayerNorm1D'
+    "Dropout1D", "cross_entropy_1d", 'FusedLayerNorm'
 ]
diff --git a/colossalai/shardformer/layer/layernorm.py b/colossalai/shardformer/layer/layernorm.py
index a8e1d7a2c..83854239c 100644
--- a/colossalai/shardformer/layer/layernorm.py
+++ b/colossalai/shardformer/layer/layernorm.py
@@ -1,89 +1,64 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
 
-from typing import List, Union
-
 import torch
 import torch.nn as nn
-from torch.distributed import ProcessGroup
-
-from colossalai.kernel import LayerNorm
-from colossalai.nn import init as init
 
-from .parallel_module import ParallelModule
+__all__ = ['FusedLayerNorm']
 
-__all__ = ['LayerNorm1D']
+FAST_LAYERNORM_SUPPORTED_SIZE = [
+    1024, 1536, 2048, 2304, 3072, 3840, 4096, 5120, 6144, 8192, 10240, 12288, 12800, 15360, 16384, 18432, 20480, 24576,
+    25600, 30720, 32768, 40960, 49152, 65536
+]
 
-Fast_LN = None
-try:
-    from apex.contrib.layer_norm.layer_norm import FastLayerNorm
-    Fast_LN = FastLayerNorm
-except ImportError:
-    pass
 
-
-class LayerNorm1D(ParallelModule):
+class FusedLayerNorm():
     r"""
-    Layer Normalization for colossalai
-
-    Args:
-        normalized_shape (int): input shape from an expected input of size.
-            :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1]
-            \times \ldots \times \text{normalized_shape}[-1]]`
-            If a single integer is used, it is treated as a singleton list, and this module will
-            normalize over the last dimension which is expected to be of that specific size.
-        eps (float): a value added to the denominator for numerical stability, defaults to 1e-05.
-        bias (bool, optional): Whether to add a bias, defaults to ``True``.
-        dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
+    This is a wrapper around the apex fused layernorm implementation. It is meant to be used only with the from_native_module interface.
     """
-    _fast_ln_supported_sizes = [
-        1024, 1536, 2048, 2304, 3072, 3840, 4096, 5120, 6144, 8192, 10240, 12288, 12800, 15360, 16384, 18432, 20480,
-        24576, 25600, 30720, 32768, 40960, 49152, 65536
-    ]
-
-    def __init__(self,
-                 normalized_shape: int,
-                 eps: int = 1e-05,
-                 bias: bool = True,
-                 dtype: torch.dtype = None,
-                 device: torch.device = None):
-        super().__init__()
-        if Fast_LN is not None and normalized_shape in self._fast_ln_supported_sizes:
-            norm = Fast_LN(normalized_shape, eps=eps).to(dtype)
-        else:
-            norm = None
-            try:
-                from apex.normalization import FusedLayerNorm
-                norm = FusedLayerNorm(normalized_shape, eps=eps).to(dtype)
-            except ImportError:
-                norm = LayerNorm(normalized_shape, eps=eps, device=device, dtype=dtype)
-        self.norm = norm
+
+    def __init__(self) -> None:
+        raise NotImplementedError(
+            'FusedLayerNorm is not implemented as a physical class. '
+            'It is meant to be used only with the from_native_module interface to wrap the fused layernorm implementation provided by apex.'
+        )
 
     @staticmethod
-    def from_native_module(module: nn.LayerNorm, process_group: Union[ProcessGroup, List[ProcessGroup]], *args,
-                           **kwargs) -> ParallelModule:
+    def from_native_module(module: nn.LayerNorm, *args, **kwargs) -> nn.Module:
         r"""
         Convert a native pytorch layer norm module to colossalai layer norm module
         """
+        # check if apex is installed
+        try:
+            import apex
+        except ImportError:
+            raise ImportError(
+                'Please install apex from source (https://github.com/NVIDIA/apex) to use the fused layernorm kernel')
+
+        # get the attributes of the module
        normalized_shape = module.normalized_shape
         eps = module.eps
-        bias = module.bias is not None
+        elementwise_affine = module.elementwise_affine
         dtype = module.weight.dtype
         device = module.weight.device
 
-        # ensure only one process group is passed
-        if isinstance(process_group, (list, tuple)):
-            assert len(process_group) == 1, \
-                f'Expected only one process group, got {len(process_group)}.'
-            process_group = process_group[0]
+        # pick the suitable layernorm implementation
+        use_fast_ln = normalized_shape in FAST_LAYERNORM_SUPPORTED_SIZE
+
+        if use_fast_ln:
+            try:
+                from apex.contrib.layer_norm.layer_norm import FastLayerNorm as ApexFusedLayerNorm
+            except ImportError:
+                # fall back to the normal fused layernorm if the fast kernel is not built
+                from apex.normalization import FusedLayerNorm as ApexFusedLayerNorm
+        else:
+            from apex.normalization import FusedLayerNorm as ApexFusedLayerNorm
 
-        # create layer norm
-        layer_norm = LayerNorm1D(normalized_shape, eps=eps, bias=bias, device=device, dtype=dtype).norm
+        layernorm = ApexFusedLayerNorm(normalized_shape, eps=eps,
+                                       elementwise_affine=elementwise_affine).to(dtype).to(device)
 
         with torch.no_grad():
             # copy weight and bias
-            layer_norm.weight.copy_(module.weight)
-            if bias:
-                layer_norm.bias.copy_(module.bias)
-        return layer_norm
+            layernorm.weight.copy_(module.weight)
+            layernorm.bias.copy_(module.bias)
+        return layernorm
diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py
index 1baf67ef9..7b0eaa5d8 100644
--- a/colossalai/shardformer/policies/bert.py
+++ b/colossalai/shardformer/policies/bert.py
@@ -103,17 +103,17 @@ class BertPolicy(Policy):
         base_policy[BertLayer].sub_module_replacement.append(
             SubModuleReplacementDescription(
                 suffix="attention.output.LayerNorm",
-                target_module=col_nn.LayerNorm1D,
+                target_module=col_nn.FusedLayerNorm,
             ))
         base_policy[BertLayer].sub_module_replacement.append(
             SubModuleReplacementDescription(
                 suffix="output.LayerNorm",
-                target_module=col_nn.LayerNorm1D,
+                target_module=col_nn.FusedLayerNorm,
             ))
         base_policy[BertEmbeddings].sub_module_replacement.append(
             SubModuleReplacementDescription(
                 suffix="LayerNorm",
-                target_module=col_nn.LayerNorm1D,
+                target_module=col_nn.FusedLayerNorm,
             ),)
         return base_policy
 
@@ -154,7 +154,7 @@ class BertForPretrainingPolicy(BertPolicy):
         addon_module[BertLMPredictionHead].sub_module_replacement.append(
             SubModuleReplacementDescription(
                 suffix="transform.LayerNorm",
-                target_module=col_nn.LayerNorm1D,
+                target_module=col_nn.FusedLayerNorm,
             ))
         module_policy.update(addon_module)
         return module_policy
@@ -191,7 +191,7 @@ class BertLMHeadModelPolicy(BertPolicy):
         addon_module[BertLMPredictionHead].sub_module_replacement.append(
             SubModuleReplacementDescription(
                 suffix="transform.LayerNorm",
-                target_module=col_nn.LayerNorm1D,
+                target_module=col_nn.FusedLayerNorm,
             ))
         module_policy.update(addon_module)
         return module_policy
@@ -228,7 +228,7 @@ class BertForMaskedLMPolicy(BertPolicy):
         addon_module[BertLMPredictionHead].sub_module_replacement.append(
             SubModuleReplacementDescription(
                 suffix="transform.LayerNorm",
-                target_module=col_nn.LayerNorm1D,
+                target_module=col_nn.FusedLayerNorm,
             ))
         module_policy.update(addon_module)
         return module_policy
diff --git a/tests/test_shardformer/test_layer/test_layernorm.py b/tests/test_shardformer/test_layer/test_layernorm.py
index 334ae05be..a11784554 100644
--- a/tests/test_shardformer/test_layer/test_layernorm.py
+++ b/tests/test_shardformer/test_layer/test_layernorm.py
@@ -1,16 +1,15 @@
 import torch
-import torch.distributed as dist
 import torch.nn as nn
 from torch.testing import assert_close
 
 import colossalai
-from colossalai.shardformer.layer import LayerNorm1D
+from colossalai.shardformer.layer import FusedLayerNorm
 from colossalai.testing import rerun_if_address_is_in_use, spawn
 
 
-def check_layernorm_1d():
+def check_layernorm():
     norm = nn.LayerNorm(128, 0.00001).cuda()
-    norm1d = LayerNorm1D.from_native_module(norm, process_group=None)
+    norm1d = FusedLayerNorm.from_native_module(norm, process_group=None)
 
     assert norm1d.weight.shape == torch.Size([128])
 
@@ -33,11 +32,11 @@ def check_layernorm_1d():
 
 def run_dist(rank, world_size, port):
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    check_layernorm_1d()
+    check_layernorm()
 
 
 @rerun_if_address_is_in_use()
-def test_layernorm_1d():
+def test_layernorm():
     spawn(run_dist, nprocs=2)
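
For reference, a minimal usage sketch of the FusedLayerNorm.from_native_module interface introduced above, mirroring what the updated test exercises. It assumes apex is installed from source and a CUDA device is available; the output comparison at the end is illustrative and not part of the diff.

import torch
import torch.nn as nn
from torch.testing import assert_close

from colossalai.shardformer.layer import FusedLayerNorm

# Wrap a native PyTorch LayerNorm with the apex fused implementation.
# 128 is not in FAST_LAYERNORM_SUPPORTED_SIZE, so the generic
# apex.normalization.FusedLayerNorm is selected rather than FastLayerNorm.
norm = nn.LayerNorm(128, eps=1e-5).cuda()
fused_norm = FusedLayerNorm.from_native_module(norm)

# The wrapper copies weight and bias from the native module, so the two
# should agree within floating point tolerance.
x = torch.rand(4, 128).cuda()
assert_close(fused_norm(x), norm(x))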