[shardformer] Add layernorm (#4072)

* add layernorm to bert

* add layernorm test

* add layernorm test with load state dict

* add use_mixedfusedLN in shard config

* refactor policy to support fused_layernorm
FoolPlayer 2023-06-23 18:00:22 +08:00 committed by Frank Lee
parent 70c58cfd4f
commit 92f6791095
7 changed files with 252 additions and 17 deletions
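
In short, this PR adds a LayerNorm1D wrapper to colossalai.shardformer.layer and a fused_layernorm flag to ShardConfig; when the flag is set, the BERT policies replace the model's nn.LayerNorm submodules with col_nn.LayerNorm1D. A minimal sketch of enabling it, modeled on the test helper at the end of this diff (the top-level import path and the sharding step after init_distributed() are assumptions, as they fall outside the shown hunks):

import copy
from colossalai.shardformer import ShardConfig, ShardFormer  # import path assumed

def build_model(world_size, model_fn):
    org_model = model_fn().cuda()
    # fused_layernorm=True makes the policies swap nn.LayerNorm for col_nn.LayerNorm1D
    shard_config = ShardConfig(tensor_parallel_size=world_size, fused_layernorm=True)
    model_copy = copy.deepcopy(org_model)
    shard_former = ShardFormer(shard_config=shard_config)
    shard_former.init_distributed()
    # ... shard model_copy with the matching policy (not shown in this diff)
    return org_model, model_copy, shard_former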

View File

@@ -1,10 +1,11 @@
from .dropout import Dropout1D
from .embedding import Embedding1D, VocabParallelEmbedding1D
from .layernorm import LayerNorm1D
from .linear import Linear1D_Col, Linear1D_Row
from .linear_conv import LinearConv1D_Col, LinearConv1D_Row
from .loss import cross_entropy_1d
__all__ = [
"Embedding1D", "VocabParallelEmbedding1D", "Linear1D_Col", "Linear1D_Row", "LinearConv1D_Col", "LinearConv1D_Row",
"Dropout1D", "cross_entropy_1d"
"Dropout1D", "cross_entropy_1d", 'LayerNorm1D'
]

View File

@@ -0,0 +1,89 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from typing import List, Union
import torch
import torch.nn as nn
from torch.distributed import ProcessGroup
from colossalai.kernel import LayerNorm
from colossalai.nn import init as init
from .parallel_module import ParallelModule
__all__ = ['LayerNorm1D']
Fast_LN = None
try:
from apex.contrib.layer_norm.layer_norm import FastLayerNorm
Fast_LN = FastLayerNorm
except ImportError:
pass
class LayerNorm1D(ParallelModule):
r"""
Layer Normalization for colossalai
Args:
normalized_shape (int): input shape from an expected input of size.
:math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1]
\times \ldots \times \text{normalized_shape}[-1]]`
If a single integer is used, it is treated as a singleton list, and this module will
normalize over the last dimension which is expected to be of that specific size.
eps (float): a value added to the denominator for numerical stability, defaults to 1e-05.
bias (bool, optional): Whether to add a bias, defaults to ``True``.
dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
"""
_fast_ln_supported_sizes = [
1024, 1536, 2048, 2304, 3072, 3840, 4096, 5120, 6144, 8192, 10240, 12288, 12800, 15360, 16384, 18432, 20480,
24576, 25600, 30720, 32768, 40960, 49152, 65536
]
def __init__(self,
normalized_shape: int,
eps: float = 1e-05,
bias: bool = True,
dtype: torch.dtype = None,
device: torch.device = None):
super().__init__()
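# kernel selection: prefer apex FastLayerNorm for supported hidden sizes,
# then apex FusedLayerNorm, and fall back to colossalai.kernel.LayerNorm
# if apex is not installed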
if Fast_LN is not None and normalized_shape in self._fast_ln_supported_sizes:
norm = Fast_LN(normalized_shape, eps=eps).to(dtype)
else:
norm = None
try:
from apex.normalization import FusedLayerNorm
norm = FusedLayerNorm(normalized_shape, eps=eps).to(dtype)
except ImportError:
norm = LayerNorm(normalized_shape, eps=eps, device=device, dtype=dtype)
self.norm = norm
@staticmethod
def from_native_module(module: nn.LayerNorm, process_group: Union[ProcessGroup, List[ProcessGroup]], *args,
**kwargs) -> ParallelModule:
r"""
Convert a native PyTorch LayerNorm module to a colossalai LayerNorm1D module
"""
normalized_shape = module.normalized_shape
eps = module.eps
bias = module.bias is not None
dtype = module.weight.dtype
device = module.weight.device
# ensure only one process group is passed
if isinstance(process_group, (list, tuple)):
assert len(process_group) == 1, \
f'Expected only one process group, got {len(process_group)}.'
process_group = process_group[0]
# create layer norm
layer_norm = LayerNorm1D(normalized_shape, eps=eps, bias=bias, device=device, dtype=dtype).norm
with torch.no_grad():
# copy weight and bias
layer_norm.weight.copy_(module.weight)
if bias:
layer_norm.bias.copy_(module.bias)
return layer_norm
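
Conversion is intended to be drop-in, as the new unit test further down exercises: wrap an existing nn.LayerNorm with from_native_module and the returned module keeps a compatible state dict and identical outputs. A minimal single-process sketch (CUDA assumed, mirroring the test):

import torch
import torch.nn as nn
from colossalai.shardformer.layer import LayerNorm1D

norm = nn.LayerNorm(128, eps=1e-5).cuda()
norm1d = LayerNorm1D.from_native_module(norm, process_group=None)
# weight and bias are copied over, so the two modules produce matching outputs
x = torch.rand(4, 128).cuda()
torch.testing.assert_close(norm(x), norm1d(x))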

View File

@@ -1,8 +1,14 @@
import torch.nn as nn
from transformers.models.bert.modeling_bert import BertEmbeddings, BertLayer, BertLMPredictionHead
from transformers.models.bert.modeling_bert import (
BertEmbeddings,
BertForMultipleChoice,
BertForSequenceClassification,
BertForTokenClassification,
BertLayer,
BertLMPredictionHead,
)
import colossalai.shardformer.layer as col_nn
from colossalai.shardformer.layer.dropout import Dropout1D
from .._utils import getattr_, setattr_
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
@@ -24,7 +30,7 @@ class BertPolicy(Policy):
return self.model
def module_policy(self):
return {
base_policy = {
BertLayer:
ModulePolicyDescription(
attribute_replacement={
@@ -53,10 +59,18 @@ class BertPolicy(Policy):
suffix="attention.self.value",
target_module=col_nn.Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="attention.self.dropout",
target_module=col_nn.Dropout1D,
),
SubModuleReplacementDescription(
suffix="attention.output.dense",
target_module=col_nn.Linear1D_Row,
),
SubModuleReplacementDescription(
suffix="attention.output.dropout",
target_module=col_nn.Dropout1D,
),
SubModuleReplacementDescription(
suffix="intermediate.dense",
target_module=col_nn.Linear1D_Col,
@@ -66,12 +80,8 @@ class BertPolicy(Policy):
target_module=col_nn.Linear1D_Row,
),
SubModuleReplacementDescription(
suffix="attention.self.dropout",
target_module=Dropout1D,
),
SubModuleReplacementDescription(
suffix="attention.output.dropout",
target_module=Dropout1D,
suffix="output.dropout",
target_module=col_nn.Dropout1D,
)
]),
BertEmbeddings:
@@ -81,10 +91,32 @@ class BertPolicy(Policy):
SubModuleReplacementDescription(
suffix="word_embeddings",
target_module=col_nn.VocabParallelEmbedding1D,
),
SubModuleReplacementDescription(
suffix="dropout",
target_module=col_nn.Dropout1D,
)
])
}
if self.shard_config.fused_layernorm:
base_policy[BertLayer].sub_module_replacement.append(
SubModuleReplacementDescription(
suffix="attention.output.LayerNorm",
target_module=col_nn.LayerNorm1D,
))
base_policy[BertLayer].sub_module_replacement.append(
SubModuleReplacementDescription(
suffix="output.LayerNorm",
target_module=col_nn.LayerNorm1D,
))
base_policy[BertEmbeddings].sub_module_replacement.append(
SubModuleReplacementDescription(
suffix="LayerNorm",
target_module=col_nn.LayerNorm1D,
),)
return base_policy
def new_model_class(self):
# do nothing
return self.model
@@ -115,9 +147,15 @@ class BertForPretrainingPolicy(BertPolicy):
sub_module_replacement=[
SubModuleReplacementDescription(suffix="decoder",
target_module=col_nn.Linear1D_Col,
kwargs={"gather_output": True})
kwargs={"gather_output": True}),
])
}
if self.shard_config.fused_layernorm:
addon_module[BertLMPredictionHead].sub_module_replacement.append(
SubModuleReplacementDescription(
suffix="transform.LayerNorm",
target_module=col_nn.LayerNorm1D,
))
module_policy.update(addon_module)
return module_policy
@@ -146,9 +184,15 @@ class BertLMHeadModelPolicy(BertPolicy):
sub_module_replacement=[
SubModuleReplacementDescription(suffix="decoder",
target_module=col_nn.Linear1D_Col,
kwargs={"gather_output": True})
kwargs={"gather_output": True}),
])
}
if self.shard_config.fused_layernorm:
addon_module[BertLMPredictionHead].sub_module_replacement.append(
SubModuleReplacementDescription(
suffix="transform.LayerNorm",
target_module=col_nn.LayerNorm1D,
))
module_policy.update(addon_module)
return module_policy
@@ -177,9 +221,15 @@ class BertForMaskedLMPolicy(BertPolicy):
sub_module_replacement=[
SubModuleReplacementDescription(suffix="decoder",
target_module=col_nn.Linear1D_Col,
kwargs={"gather_output": True})
kwargs={"gather_output": True}),
])
}
if self.shard_config.fused_layernorm:
addon_module[BertLMPredictionHead].sub_module_replacement.append(
SubModuleReplacementDescription(
suffix="transform.LayerNorm",
target_module=col_nn.LayerNorm1D,
))
module_policy.update(addon_module)
return module_policy
@@ -199,6 +249,22 @@ class BertForSequenceClassificationPolicy(BertPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
module_policy = super().module_policy()
addon_module = {
BertForSequenceClassification:
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=col_nn.Dropout1D,
)
])
}
module_policy.update(addon_module)
return module_policy
# BertForTokenClassification
class BertForTokenClassificationPolicy(BertPolicy):
@@ -206,6 +272,22 @@ class BertForTokenClassificationPolicy(BertPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
module_policy = super().module_policy()
addon_module = {
BertForTokenClassification:
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=col_nn.Dropout1D,
)
])
}
module_policy.update(addon_module)
return module_policy
# BertForNextSentencePrediction
class BertForNextSentencePredictionPolicy(BertPolicy):
@@ -219,3 +301,19 @@ class BertForMultipleChoicePolicy(BertPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
module_policy = super().module_policy()
addon_module = {
BertForMultipleChoice:
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=col_nn.Dropout1D,
)
])
}
module_policy.update(addon_module)
return module_policy

View File

@@ -11,8 +11,9 @@ class ShardConfig:
The config for sharding the huggingface model
Args:
data_parallel_size (int): The size of data parallel
tensor_parallel_size (int): The size of tensor parallel
fused_layernorm (bool): Whether to use the fused layernorm kernel (`MixedFusedLayerNorm`), defaults to False
data_parallel_size (int): The size of data parallel
pipeline_parallel_size (int): The size of pipeline parallel
tensor_parallel_mode (List): The mode of tensor parallel, choose from `['1d','2d','2.5d','3d']`
inference_only (bool): Whether to use the inference only mode, when setting to `True`, the model
@@ -20,6 +21,7 @@ class ShardConfig:
gather_output (bool): Whether to gather the output of the last layer of the model
"""
tensor_parallel_size: int
fused_layernorm: bool = False
# TODO: add support for tensor parallel
# pipeline_parallel_size: int

View File

@@ -0,0 +1,45 @@
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.testing import assert_close
import colossalai
from colossalai.shardformer.layer import LayerNorm1D
from colossalai.testing import rerun_if_address_is_in_use, spawn
def check_layernorm_1d():
norm = nn.LayerNorm(128, 0.00001).cuda()
norm1d = LayerNorm1D.from_native_module(norm, process_group=None)
assert norm1d.weight.shape == torch.Size([128])
# ensure state dict is reversibly loadable
norm.load_state_dict(norm1d.state_dict())
norm1d.load_state_dict(norm.state_dict())
# check computation correctness
x = torch.rand(4, 128).cuda()
out = norm(x)
gather_out = norm1d(x)
assert_close(out, gather_out)
# check backward correctness
out.sum().backward()
gather_out.sum().backward()
assert_close(norm.weight.grad, norm1d.weight.grad)
def run_dist(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
check_layernorm_1d()
@rerun_if_address_is_in_use()
def test_layernorm_1d():
spawn(run_dist, nprocs=2)
if __name__ == '__main__':
test_layernorm_1d()

View File

@@ -77,7 +77,7 @@ def check_linear_conv_1d_col():
assert_close(target_grad, linear_conv_col.weight.grad)
def check_linear_1d_row():
def check_linear_conv_1d_row():
linear = Conv1D(192, 48).cuda()
linear_row = LinearConv1D_Row.from_native_module(linear, process_group=None, parallel_input=False)
@@ -103,7 +103,7 @@ def check_linear_1d_row():
def run_dist(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
check_linear_conv_1d_col()
check_linear_1d_row()
check_linear_conv_1d_row()
@rerun_if_address_is_in_use()

View File

@@ -8,7 +8,7 @@ def build_model(world_size, model_fn):
org_model = model_fn().cuda()
# shard model
shard_config = ShardConfig(tensor_parallel_size=world_size)
shard_config = ShardConfig(tensor_parallel_size=world_size, fused_layernorm=True)
model_copy = copy.deepcopy(org_model)
shard_former = ShardFormer(shard_config=shard_config)
shard_former.init_distributed()