mirror of https://github.com/hpcaitech/ColossalAI

[shardformer] refactored layernorm (#4086)

parent c4b1b65931
commit d33a44e8c3
@@ -1,11 +1,11 @@
 from .dropout import Dropout1D
 from .embedding import Embedding1D, VocabParallelEmbedding1D
-from .layernorm import LayerNorm1D
+from .layernorm import FusedLayerNorm
 from .linear import Linear1D_Col, Linear1D_Row
 from .linear_conv import LinearConv1D_Col, LinearConv1D_Row
 from .loss import cross_entropy_1d
 
 __all__ = [
     "Embedding1D", "VocabParallelEmbedding1D", "Linear1D_Col", "Linear1D_Row", "LinearConv1D_Col", "LinearConv1D_Row",
-    "Dropout1D", "cross_entropy_1d", 'LayerNorm1D'
+    "Dropout1D", "cross_entropy_1d", 'FusedLayerNorm'
 ]
@@ -1,89 +1,64 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
 
-from typing import List, Union
-
 import torch
 import torch.nn as nn
-from torch.distributed import ProcessGroup
 
-from colossalai.kernel import LayerNorm
-from colossalai.nn import init as init
+__all__ = ['FusedLayerNorm']
 
-from .parallel_module import ParallelModule
-
-__all__ = ['LayerNorm1D']
-
-Fast_LN = None
-try:
-    from apex.contrib.layer_norm.layer_norm import FastLayerNorm
-    Fast_LN = FastLayerNorm
-except ImportError:
-    pass
+FAST_LAYERNORM_SUPPORTED_SIZE = [
+    1024, 1536, 2048, 2304, 3072, 3840, 4096, 5120, 6144, 8192, 10240, 12288, 12800, 15360, 16384, 18432, 20480, 24576,
+    25600, 30720, 32768, 40960, 49152, 65536
+]
 
 
-class LayerNorm1D(ParallelModule):
+class FusedLayerNorm():
     r"""
-    Layer Normalization for colossalai
-
-    Args:
-        normalized_shape (int): input shape from an expected input of size.
-            :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1]
-            \times \ldots \times \text{normalized_shape}[-1]]`
-            If a single integer is used, it is treated as a singleton list, and this module will
-            normalize over the last dimension which is expected to be of that specific size.
-        eps (float): a value added to the denominator for numerical stability, defaults to 1e-05.
-        bias (bool, optional): Whether to add a bias, defaults to ``True``.
-        dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
+    This is a wrapper around the apex fused layernorm implementation. It is meant to be used only with the from_native_module interface.
     """
 
-    _fast_ln_supported_sizes = [
-        1024, 1536, 2048, 2304, 3072, 3840, 4096, 5120, 6144, 8192, 10240, 12288, 12800, 15360, 16384, 18432, 20480,
-        24576, 25600, 30720, 32768, 40960, 49152, 65536
-    ]
-
-    def __init__(self,
-                 normalized_shape: int,
-                 eps: int = 1e-05,
-                 bias: bool = True,
-                 dtype: torch.dtype = None,
-                 device: torch.device = None):
-        super().__init__()
-        if Fast_LN is not None and normalized_shape in self._fast_ln_supported_sizes:
-            norm = Fast_LN(normalized_shape, eps=eps).to(dtype)
-        else:
-            norm = None
-            try:
-                from apex.normalization import FusedLayerNorm
-                norm = FusedLayerNorm(normalized_shape, eps=eps).to(dtype)
-            except ImportError:
-                norm = LayerNorm(normalized_shape, eps=eps, device=device, dtype=dtype)
-        self.norm = norm
+    def __init__(self) -> None:
+        raise NotImplementedError(
+            'FusedLayerNorm is not implemented as a physical class. '
+            'It is meant to be used only with the from_native_module interface to wrap the fused layernorm implementation provided by apex.'
+        )
 
     @staticmethod
-    def from_native_module(module: nn.LayerNorm, process_group: Union[ProcessGroup, List[ProcessGroup]], *args,
-                           **kwargs) -> ParallelModule:
+    def from_native_module(module: nn.LayerNorm, *args, **kwargs) -> nn.Module:
         r"""
         Convert a native pytorch layer norm module to colossalai layer norm module
         """
+        # check if apex is installed
+        try:
+            import apex
+        except ImportError:
+            raise ImportError(
+                'Please install apex from source (https://github.com/NVIDIA/apex) to use the fused layernorm kernel')
+
+        # get the attributes of the module
         normalized_shape = module.normalized_shape
         eps = module.eps
-        bias = module.bias is not None
+        elementwise_affine = module.elementwise_affine
         dtype = module.weight.dtype
         device = module.weight.device
 
-        # ensure only one process group is passed
-        if isinstance(process_group, (list, tuple)):
-            assert len(process_group) == 1, \
-                f'Expected only one process group, got {len(process_group)}.'
-            process_group = process_group[0]
+        # pick the suitable layernorm implementation
+        use_fast_ln = normalized_shape in FAST_LAYERNORM_SUPPORTED_SIZE
 
-        # create layer norm
-        layer_norm = LayerNorm1D(normalized_shape, eps=eps, bias=bias, device=device, dtype=dtype).norm
+        if use_fast_ln:
+            try:
+                from apex.contrib.layer_norm.layer_norm import FastLayerNorm as ApexFusedLayerNorm
+            except ImportError:
+                # fall back to the normal fused layernorm if the fast layernorm kernel is not built
+                from apex.normalization import FusedLayerNorm as ApexFusedLayerNorm
+        else:
+            from apex.normalization import FusedLayerNorm as ApexFusedLayerNorm
+
+        layernorm = ApexFusedLayerNorm(normalized_shape, eps=eps,
+                                       elementwise_affine=elementwise_affine).to(dtype).to(device)
 
         with torch.no_grad():
             # copy weight and bias
-            layer_norm.weight.copy_(module.weight)
-            if bias:
-                layer_norm.bias.copy_(module.bias)
-
-        return layer_norm
+            layernorm.weight.copy_(module.weight)
+            layernorm.bias.copy_(module.bias)
+
+        return layernorm
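For reference, a minimal usage sketch of the new interface (not part of the commit; it assumes apex is installed and a CUDA device is available, and the tolerances are only illustrative):

import torch
import torch.nn as nn

from colossalai.shardformer.layer import FusedLayerNorm

# a hidden size listed in FAST_LAYERNORM_SUPPORTED_SIZE routes to apex FastLayerNorm;
# any other size falls back to the generic apex FusedLayerNorm
native_norm = nn.LayerNorm(1024, eps=1e-5).half().cuda()

# from_native_module reads normalized_shape, eps, elementwise_affine, dtype and device
# from the native module and copies its weight and bias into the apex module
fused_norm = FusedLayerNorm.from_native_module(native_norm)

x = torch.randn(4, 16, 1024, device='cuda', dtype=torch.float16)
torch.testing.assert_close(fused_norm(x), native_norm(x), rtol=1e-3, atol=1e-3)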
@@ -103,17 +103,17 @@ class BertPolicy(Policy):
         base_policy[BertLayer].sub_module_replacement.append(
             SubModuleReplacementDescription(
                 suffix="attention.output.LayerNorm",
-                target_module=col_nn.LayerNorm1D,
+                target_module=col_nn.FusedLayerNorm,
             ))
         base_policy[BertLayer].sub_module_replacement.append(
             SubModuleReplacementDescription(
                 suffix="output.LayerNorm",
-                target_module=col_nn.LayerNorm1D,
+                target_module=col_nn.FusedLayerNorm,
             ))
         base_policy[BertEmbeddings].sub_module_replacement.append(
             SubModuleReplacementDescription(
                 suffix="LayerNorm",
-                target_module=col_nn.LayerNorm1D,
+                target_module=col_nn.FusedLayerNorm,
             ),)
         return base_policy
 
@@ -154,7 +154,7 @@ class BertForPretrainingPolicy(BertPolicy):
         addon_module[BertLMPredictionHead].sub_module_replacement.append(
             SubModuleReplacementDescription(
                 suffix="transform.LayerNorm",
-                target_module=col_nn.LayerNorm1D,
+                target_module=col_nn.FusedLayerNorm,
             ))
         module_policy.update(addon_module)
         return module_policy
 
@@ -191,7 +191,7 @@ class BertLMHeadModelPolicy(BertPolicy):
         addon_module[BertLMPredictionHead].sub_module_replacement.append(
             SubModuleReplacementDescription(
                 suffix="transform.LayerNorm",
-                target_module=col_nn.LayerNorm1D,
+                target_module=col_nn.FusedLayerNorm,
             ))
         module_policy.update(addon_module)
         return module_policy
 
@@ -228,7 +228,7 @@ class BertForMaskedLMPolicy(BertPolicy):
         addon_module[BertLMPredictionHead].sub_module_replacement.append(
             SubModuleReplacementDescription(
                 suffix="transform.LayerNorm",
-                target_module=col_nn.LayerNorm1D,
+                target_module=col_nn.FusedLayerNorm,
             ))
         module_policy.update(addon_module)
         return module_policy
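The policy changes above only swap the target class; the replacement itself still goes through from_native_module. A hedged illustration of what one of these descriptions amounts to at runtime (transformers and apex are assumed to be installed; doing it by hand like this is purely illustrative, shardformer applies the policy automatically):

from transformers import BertConfig, BertModel

from colossalai.shardformer.layer import FusedLayerNorm

model = BertModel(BertConfig()).cuda()
layer = model.encoder.layer[0]

# manual equivalent of SubModuleReplacementDescription(
#     suffix="attention.output.LayerNorm", target_module=col_nn.FusedLayerNorm)
layer.attention.output.LayerNorm = FusedLayerNorm.from_native_module(
    layer.attention.output.LayerNorm)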
@@ -1,16 +1,15 @@
 import torch
-import torch.distributed as dist
 import torch.nn as nn
 from torch.testing import assert_close
 
 import colossalai
-from colossalai.shardformer.layer import LayerNorm1D
+from colossalai.shardformer.layer import FusedLayerNorm
 from colossalai.testing import rerun_if_address_is_in_use, spawn
 
 
-def check_layernorm_1d():
+def check_layernorm():
     norm = nn.LayerNorm(128, 0.00001).cuda()
-    norm1d = LayerNorm1D.from_native_module(norm, process_group=None)
+    norm1d = FusedLayerNorm.from_native_module(norm, process_group=None)
 
     assert norm1d.weight.shape == torch.Size([128])
 
@@ -33,11 +32,11 @@ def check_layernorm_1d():
 
 
 def run_dist(rank, world_size, port):
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    check_layernorm_1d()
+    check_layernorm()
 
 
 @rerun_if_address_is_in_use()
-def test_layernorm_1d():
+def test_layernorm():
     spawn(run_dist, nprocs=2)