From 92f6791095491e44c5712e14f00f2e19b52dc9f6 Mon Sep 17 00:00:00 2001
From: FoolPlayer <45593998+FoolPlayer@users.noreply.github.com>
Date: Fri, 23 Jun 2023 18:00:22 +0800
Subject: [PATCH] [shardformer] Add layernorm (#4072)

* add layernorm to bert

* add layernorm test

* add layernorm test with load state dict

* add use_mixedfusedLN in shard config

* refactor policy to support fused_layernorm
---
 colossalai/shardformer/layer/__init__.py     |   3 +-
 colossalai/shardformer/layer/layernorm.py    |  89 +++++++++++++
 colossalai/shardformer/policies/bert.py      | 122 ++++++++++++++++--
 colossalai/shardformer/shard/shard_config.py |   4 +-
 .../test_layer/test_layernorm.py             |  45 +++++++
 .../test_layer/test_linearconv_1d.py         |   4 +-
 tests/test_shardformer/test_model/_utils.py  |   2 +-
 7 files changed, 252 insertions(+), 17 deletions(-)
 create mode 100644 colossalai/shardformer/layer/layernorm.py
 create mode 100644 tests/test_shardformer/test_layer/test_layernorm.py

diff --git a/colossalai/shardformer/layer/__init__.py b/colossalai/shardformer/layer/__init__.py
index 808ebbc12..3ce0ef68a 100644
--- a/colossalai/shardformer/layer/__init__.py
+++ b/colossalai/shardformer/layer/__init__.py
@@ -1,10 +1,11 @@
 from .dropout import Dropout1D
 from .embedding import Embedding1D, VocabParallelEmbedding1D
+from .layernorm import LayerNorm1D
 from .linear import Linear1D_Col, Linear1D_Row
 from .linear_conv import LinearConv1D_Col, LinearConv1D_Row
 from .loss import cross_entropy_1d
 
 __all__ = [
     "Embedding1D", "VocabParallelEmbedding1D", "Linear1D_Col", "Linear1D_Row", "LinearConv1D_Col", "LinearConv1D_Row",
-    "Dropout1D", "cross_entropy_1d"
+    "Dropout1D", "cross_entropy_1d", "LayerNorm1D"
 ]
diff --git a/colossalai/shardformer/layer/layernorm.py b/colossalai/shardformer/layer/layernorm.py
new file mode 100644
index 000000000..a8e1d7a2c
--- /dev/null
+++ b/colossalai/shardformer/layer/layernorm.py
@@ -0,0 +1,89 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+
+from typing import List, Union
+
+import torch
+import torch.nn as nn
+from torch.distributed import ProcessGroup
+
+from colossalai.kernel import LayerNorm
+from colossalai.nn import init as init
+
+from .parallel_module import ParallelModule
+
+__all__ = ['LayerNorm1D']
+
+Fast_LN = None
+try:
+    from apex.contrib.layer_norm.layer_norm import FastLayerNorm
+    Fast_LN = FastLayerNorm
+except ImportError:
+    pass
+
+
+class LayerNorm1D(ParallelModule):
+    r"""
+    Layer normalization for Colossal-AI.
+
+    Args:
+        normalized_shape (int): input shape from an expected input of size
+            :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1]
+            \times \ldots \times \text{normalized_shape}[-1]]`.
+            If a single integer is used, it is treated as a singleton list, and this module will
+            normalize over the last dimension, which is expected to be of that specific size.
+        eps (float): a value added to the denominator for numerical stability, defaults to 1e-05.
+        bias (bool, optional): Whether to add a bias, defaults to ``True``.
+        dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
+        device (:class:`torch.device`, optional): The device of parameters, defaults to None.
+    """
+
+    _fast_ln_supported_sizes = [
+        1024, 1536, 2048, 2304, 3072, 3840, 4096, 5120, 6144, 8192, 10240, 12288, 12800, 15360, 16384, 18432, 20480,
+        24576, 25600, 30720, 32768, 40960, 49152, 65536
+    ]
+
+    def __init__(self,
+                 normalized_shape: int,
+                 eps: float = 1e-05,
+                 bias: bool = True,
+                 dtype: torch.dtype = None,
+                 device: torch.device = None):
+        super().__init__()
+        if Fast_LN is not None and normalized_shape in self._fast_ln_supported_sizes:
+            norm = Fast_LN(normalized_shape, eps=eps).to(dtype)
+        else:
+            norm = None
+            try:
+                from apex.normalization import FusedLayerNorm
+                norm = FusedLayerNorm(normalized_shape, eps=eps).to(dtype)
+            except ImportError:
+                norm = LayerNorm(normalized_shape, eps=eps, device=device, dtype=dtype)
+        self.norm = norm
+
+    @staticmethod
+    def from_native_module(module: nn.LayerNorm, process_group: Union[ProcessGroup, List[ProcessGroup]], *args,
+                           **kwargs) -> ParallelModule:
+        r"""
+        Convert a native PyTorch layer norm module to a Colossal-AI layer norm module.
+        """
+        normalized_shape = module.normalized_shape
+        eps = module.eps
+        bias = module.bias is not None
+        dtype = module.weight.dtype
+        device = module.weight.device
+
+        # ensure only one process group is passed
+        if isinstance(process_group, (list, tuple)):
+            assert len(process_group) == 1, \
+                f'Expected only one process group, got {len(process_group)}.'
+            process_group = process_group[0]
+
+        # create layer norm
+        layer_norm = LayerNorm1D(normalized_shape, eps=eps, bias=bias, device=device, dtype=dtype).norm
+
+        with torch.no_grad():
+            # copy weight and bias
+            layer_norm.weight.copy_(module.weight)
+            if bias:
+                layer_norm.bias.copy_(module.bias)
+        return layer_norm
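[Reviewer note, not part of the patch] A minimal usage sketch of the new module; it assumes a CUDA device and relies on the apex fallbacks above when the fused kernels are unavailable:

    import torch
    import torch.nn as nn

    from colossalai.shardformer.layer import LayerNorm1D

    # wrap an existing LayerNorm; from_native_module copies weight and bias over
    native_norm = nn.LayerNorm(1024).cuda()
    norm_1d = LayerNorm1D.from_native_module(native_norm, process_group=None)

    x = torch.rand(8, 1024).cuda()
    # the converted module should match the native one numerically
    torch.testing.assert_close(native_norm(x), norm_1d(x))

Note that `from_native_module` returns the wrapped `.norm` submodule directly, so the caller receives the fused layer itself rather than the `LayerNorm1D` container.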
+ """ + + _fast_ln_supported_sizes = [ + 1024, 1536, 2048, 2304, 3072, 3840, 4096, 5120, 6144, 8192, 10240, 12288, 12800, 15360, 16384, 18432, 20480, + 24576, 25600, 30720, 32768, 40960, 49152, 65536 + ] + + def __init__(self, + normalized_shape: int, + eps: int = 1e-05, + bias: bool = True, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + if Fast_LN is not None and normalized_shape in self._fast_ln_supported_sizes: + norm = Fast_LN(normalized_shape, eps=eps).to(dtype) + else: + norm = None + try: + from apex.normalization import FusedLayerNorm + norm = FusedLayerNorm(normalized_shape, eps=eps).to(dtype) + except ImportError: + norm = LayerNorm(normalized_shape, eps=eps, device=device, dtype=dtype) + self.norm = norm + + @staticmethod + def from_native_module(module: nn.LayerNorm, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, + **kwargs) -> ParallelModule: + r""" + Convert a native pytorch layer norm module to colossalai layer norm module + """ + normalized_shape = module.normalized_shape + eps = module.eps + bias = module.bias is not None + dtype = module.weight.dtype + device = module.weight.device + + # ensure only one process group is passed + if isinstance(process_group, (list, tuple)): + assert len(process_group) == 1, \ + f'Expected only one process group, got {len(process_group)}.' + process_group = process_group[0] + + # create layer norm + layer_norm = LayerNorm1D(normalized_shape, eps=eps, bias=bias, device=device, dtype=dtype).norm + + with torch.no_grad(): + # copy weight and bias + layer_norm.weight.copy_(module.weight) + if bias: + layer_norm.bias.copy_(module.bias) + return layer_norm diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index 8649c0dbe..1baf67ef9 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -1,8 +1,14 @@ import torch.nn as nn -from transformers.models.bert.modeling_bert import BertEmbeddings, BertLayer, BertLMPredictionHead +from transformers.models.bert.modeling_bert import ( + BertEmbeddings, + BertForMultipleChoice, + BertForSequenceClassification, + BertForTokenClassification, + BertLayer, + BertLMPredictionHead, +) import colossalai.shardformer.layer as col_nn -from colossalai.shardformer.layer.dropout import Dropout1D from .._utils import getattr_, setattr_ from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription @@ -24,7 +30,7 @@ class BertPolicy(Policy): return self.model def module_policy(self): - return { + base_policy = { BertLayer: ModulePolicyDescription( attribute_replacement={ @@ -53,10 +59,18 @@ class BertPolicy(Policy): suffix="attention.self.value", target_module=col_nn.Linear1D_Col, ), + SubModuleReplacementDescription( + suffix="attention.self.dropout", + target_module=col_nn.Dropout1D, + ), SubModuleReplacementDescription( suffix="attention.output.dense", target_module=col_nn.Linear1D_Row, ), + SubModuleReplacementDescription( + suffix="attention.output.dropout", + target_module=col_nn.Dropout1D, + ), SubModuleReplacementDescription( suffix="intermediate.dense", target_module=col_nn.Linear1D_Col, @@ -66,12 +80,8 @@ class BertPolicy(Policy): target_module=col_nn.Linear1D_Row, ), SubModuleReplacementDescription( - suffix="attention.self.dropout", - target_module=Dropout1D, - ), - SubModuleReplacementDescription( - suffix="attention.output.dropout", - target_module=Dropout1D, + suffix="output.dropout", + target_module=col_nn.Dropout1D, ) ]), BertEmbeddings: @@ 
diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py
index 7379a8208..8d3fc225e 100644
--- a/colossalai/shardformer/shard/shard_config.py
+++ b/colossalai/shardformer/shard/shard_config.py
@@ -11,8 +11,9 @@ class ShardConfig:
     The config for sharding the huggingface model
 
     Args:
-        data_parallel_size (int): The size of data parallel
         tensor_parallel_size (int): The size of tensor parallel
+        fused_layernorm (bool): Whether to replace the native LayerNorm with the fused `LayerNorm1D`, defaults to False
+        data_parallel_size (int): The size of data parallel
         pipeline_parallel_size (int): The size of pipeline parallel
         tensor_parallel_mode (List): The mode of tensor parallel, choose from `['1d','2d','2.5d','3d']
         inference_only (bool): Whether to use the inference only mode, when setting to `True`, the model
@@ -20,6 +21,7 @@ class ShardConfig:
         gather_output (bool): Whether to gather the output of the model of the last layer
     """
     tensor_parallel_size: int
+    fused_layernorm: bool = False
 
     # TODO: add support for tensor parallel
     # pipeline_parallel_size: int
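[Reviewer note, not part of the patch] The new field keeps `ShardConfig` backward compatible, since it defaults to `False`:

    from colossalai.shardformer.shard.shard_config import ShardConfig

    config = ShardConfig(tensor_parallel_size=2)    # fused layernorm stays off
    assert not config.fused_layernorm

    config = ShardConfig(tensor_parallel_size=2, fused_layernorm=True)
    assert config.fused_layernorm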
diff --git a/tests/test_shardformer/test_layer/test_layernorm.py b/tests/test_shardformer/test_layer/test_layernorm.py
new file mode 100644
index 000000000..334ae05be
--- /dev/null
+++ b/tests/test_shardformer/test_layer/test_layernorm.py
@@ -0,0 +1,45 @@
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from torch.testing import assert_close
+
+import colossalai
+from colossalai.shardformer.layer import LayerNorm1D
+from colossalai.testing import rerun_if_address_is_in_use, spawn
+
+
+def check_layernorm_1d():
+    norm = nn.LayerNorm(128, 0.00001).cuda()
+    norm1d = LayerNorm1D.from_native_module(norm, process_group=None)
+
+    assert norm1d.weight.shape == torch.Size([128])
+
+    # ensure state dict is reversibly loadable
+    norm.load_state_dict(norm1d.state_dict())
+    norm1d.load_state_dict(norm.state_dict())
+
+    # check computation correctness
+    x = torch.rand(4, 128).cuda()
+    out = norm(x)
+    gather_out = norm1d(x)
+    assert_close(out, gather_out)
+
+    # check backward correctness
+    out.sum().backward()
+    gather_out.sum().backward()
+
+    assert_close(norm.weight.grad, norm1d.weight.grad)
+
+
+def run_dist(rank, world_size, port):
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    check_layernorm_1d()
+
+
+@rerun_if_address_is_in_use()
+def test_layernorm_1d():
+    spawn(run_dist, nprocs=2)
+
+
+if __name__ == '__main__':
+    test_layernorm_1d()
diff --git a/tests/test_shardformer/test_layer/test_linearconv_1d.py b/tests/test_shardformer/test_layer/test_linearconv_1d.py
index efdb88351..774e6340e 100644
--- a/tests/test_shardformer/test_layer/test_linearconv_1d.py
+++ b/tests/test_shardformer/test_layer/test_linearconv_1d.py
@@ -77,7 +77,7 @@ def check_linear_conv_1d_col():
     assert_close(target_grad, linear_conv_col.weight.grad)
 
 
-def check_linear_1d_row():
+def check_linear_conv_1d_row():
     linear = Conv1D(192, 48).cuda()
     linear_row = LinearConv1D_Row.from_native_module(linear, process_group=None, parallel_input=False)
 
@@ -103,7 +103,7 @@ def check_linear_1d_row():
 def run_dist(rank, world_size, port):
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     check_linear_conv_1d_col()
-    check_linear_1d_row()
+    check_linear_conv_1d_row()
 
 
 @rerun_if_address_is_in_use()
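[Reviewer note, not part of the patch] Both tests spawn two NCCL processes, so they need at least two GPUs; assuming a standard pytest setup, they can be run with:

    pytest tests/test_shardformer/test_layer/test_layernorm.py
    pytest tests/test_shardformer/test_layer/test_linearconv_1d.py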
diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py
index 52ca7fce8..a282e0bb9 100644
--- a/tests/test_shardformer/test_model/_utils.py
+++ b/tests/test_shardformer/test_model/_utils.py
@@ -8,7 +8,7 @@ def build_model(world_size, model_fn):
     org_model = model_fn().cuda()
 
     # shard model
-    shard_config = ShardConfig(tensor_parallel_size=world_size)
+    shard_config = ShardConfig(tensor_parallel_size=world_size, fused_layernorm=True)
     model_copy = copy.deepcopy(org_model)
     shard_former = ShardFormer(shard_config=shard_config)
     shard_former.init_distributed()
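[Reviewer note, not part of the patch] For context, a sketch of how the pieces compose end to end, modeled on `build_model` above; the import path and the `shard_model` call are assumed from the surrounding test utilities:

    import copy

    from colossalai.shardformer import ShardConfig, ShardFormer

    def shard_with_fused_layernorm(model_fn, world_size):
        org_model = model_fn().cuda()

        # fused_layernorm=True makes every policy swap nn.LayerNorm for LayerNorm1D
        shard_config = ShardConfig(tensor_parallel_size=world_size, fused_layernorm=True)
        model_copy = copy.deepcopy(org_model)
        shard_former = ShardFormer(shard_config=shard_config)
        shard_former.init_distributed()
        sharded_model = shard_former.shard_model(model_copy)
        return org_model, sharded_model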