[shardformer] supported fused normalization (#4112)

pull/4157/head
Frank Lee 2023-06-30 09:32:37 +08:00
parent b1c2901530
commit f3b6aaa6b7
12 changed files with 207 additions and 31 deletions

View File

@@ -1,12 +1,12 @@
from .dropout import DropoutForParallelInput, DropoutForReplicatedInput
from .embedding import Embedding1D, VocabParallelEmbedding1D
from .layernorm import FusedLayerNorm
from .linear import Linear1D_Col, Linear1D_Row
from .loss import cross_entropy_1d
from .normalization import FusedLayerNorm, FusedRMSNorm
from .qkv_fused_linear import GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row
__all__ = [
"Embedding1D", "VocabParallelEmbedding1D", "Linear1D_Col", "Linear1D_Row", 'GPT2FusedLinearConv1D_Col',
'GPT2FusedLinearConv1D_Row', 'DropoutForParallelInput', 'DropoutForReplicatedInput', "cross_entropy_1d",
'FusedLayerNorm'
'FusedLayerNorm', 'FusedRMSNorm'
]

View File

@@ -4,7 +4,7 @@
import torch
import torch.nn as nn
__all__ = ['FusedLayerNorm']
__all__ = ['FusedLayerNorm', 'FusedRMSNorm']
FAST_LAYERNORM_SUPPORTED_SIZE = [
1024, 1536, 2048, 2304, 3072, 3840, 4096, 5120, 6144, 8192, 10240, 12288, 12800, 15360, 16384, 18432, 20480, 24576,
@@ -61,4 +61,44 @@ class FusedLayerNorm():
# copy weight and bias
layernorm.weight.copy_(module.weight)
layernorm.bias.copy_(module.bias)
return layernorm
class FusedRMSNorm():
"""
This is a wrapper around the apex fused rms norm implementation. It is meant to be used only with the from_native_module interface.
"""
def __init__(self) -> None:
raise NotImplementedError(
'FusedRMSNorm is not implemented as a physical class. '
'It is meant to be used only with the from_native_module interface to wrap the fused rms norm implementation provided by apex.'
)
@staticmethod
def from_native_module(module: nn.Module, *args, **kwargs) -> nn.Module:
try:
from apex.normalization import FusedRMSNorm as ApexFusedRMSNorm
except ImportError:
raise ImportError(
'Please install apex from source (https://github.com/NVIDIA/apex) to use the fused RMS normalization kernel'
)
# to check if it is huggingface LlamaRMSNorm
if module.__class__.__name__ == "LlamaRMSNorm":
normalized_shape = module.weight.shape[0]
eps = module.variance_epsilon
elementwise_affine = True
else:
# get the attributes of the module
normalized_shape = module.normalized_shape
eps = module.eps
elementwise_affine = module.elementwise_affine
rmsnorm = ApexFusedRMSNorm(normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine)
with torch.no_grad():
# copy weight and bias
rmsnorm.weight.copy_(module.weight)
return rmsnorm

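For context, here is a minimal usage sketch of the new wrapper (not part of the diff; it assumes apex is installed and a CUDA device is available, and the tensor shapes are illustrative). from_native_module reads the normalized shape and epsilon off a HuggingFace LlamaRMSNorm and returns an apex FusedRMSNorm that carries over the original weight:

import torch
from transformers.models.llama.modeling_llama import LlamaRMSNorm
from colossalai.shardformer.layer import FusedRMSNorm

# build a native HuggingFace RMSNorm and wrap it with the apex fused kernel
native_norm = LlamaRMSNorm(hidden_size=4096, eps=1e-6).cuda()
fused_norm = FusedRMSNorm.from_native_module(native_norm)

# the wrapper copies the original weight, so outputs should match closely
x = torch.randn(2, 16, 4096, device='cuda')
torch.testing.assert_close(fused_norm(x), native_norm(x), rtol=1e-3, atol=1e-3)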
View File

@@ -98,6 +98,14 @@ class Policy(ABC):
shard_config (:class:`ShardConfig`): The shard config to be applied
"""
self.shard_config = shard_config
self.config_sanity_check()
@abstractmethod
def config_sanity_check(self):
"""
Check if the shard config is valid for the model. Raise an exception if the config is invalid.
"""
pass
@abstractmethod
def preprocess(self) -> nn.Module:

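The new config_sanity_check hook runs as soon as the shard config is attached to a policy (see the hunk above), giving each policy a place to reject configurations it cannot honor before any sharding happens; the concrete policies in this commit simply pass. A hypothetical implementation (not part of this commit) might look like:

def config_sanity_check(self):
    # defined on a Policy subclass; fail fast if the requested optimization cannot run
    if self.shard_config.enable_fused_normalization:
        try:
            import apex.normalization  # noqa: F401
        except ImportError:
            raise ImportError('enable_fused_normalization=True requires apex, '
                              'install it from https://github.com/NVIDIA/apex')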
View File

@@ -16,6 +16,9 @@ from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDes
class BertPolicy(Policy):
def config_sanity_check(self):
pass
def preprocess(self):
# reshape the embedding layer
r"""
@@ -99,7 +102,8 @@ class BertPolicy(Policy):
])
}
if self.shard_config.fused_layernorm:
# optimization configuration
if self.shard_config.enable_fused_normalization:
base_policy[BertLayer].sub_module_replacement.append(
SubModuleReplacementDescription(
suffix="attention.output.LayerNorm",
@@ -150,12 +154,16 @@ class BertForPretrainingPolicy(BertPolicy):
kwargs={"gather_output": True}),
])
}
if self.shard_config.fused_layernorm:
# optimization configuration
if self.shard_config.enable_fused_normalization:
addon_module[BertLMPredictionHead].sub_module_replacement.append(
SubModuleReplacementDescription(
suffix="transform.LayerNorm",
target_module=col_nn.FusedLayerNorm,
))
# append extra policy
module_policy.update(addon_module)
return module_policy
@@ -187,7 +195,7 @@ class BertLMHeadModelPolicy(BertPolicy):
kwargs={"gather_output": True}),
])
}
if self.shard_config.fused_layernorm:
if self.shard_config.enable_fused_normalization:
addon_module[BertLMPredictionHead].sub_module_replacement.append(
SubModuleReplacementDescription(
suffix="transform.LayerNorm",
@@ -224,12 +232,15 @@ class BertForMaskedLMPolicy(BertPolicy):
kwargs={"gather_output": True}),
])
}
if self.shard_config.fused_layernorm:
# optimization configuration
if self.shard_config.enable_fused_normalization:
addon_module[BertLMPredictionHead].sub_module_replacement.append(
SubModuleReplacementDescription(
suffix="transform.LayerNorm",
target_module=col_nn.FusedLayerNorm,
))
module_policy.update(addon_module)
return module_policy
@@ -316,4 +327,4 @@ class BertForMultipleChoicePolicy(BertPolicy):
])
}
module_policy.update(addon_module)
return module_policy

View File

@@ -65,6 +65,9 @@ def build_bloom_alibi_tensor(self, attention_mask: torch.Tensor, num_heads: int,
class BloomPolicy(Policy):
def config_sanity_check(self):
pass
def preprocess(self):
# reshape the embedding layer
r"""
@@ -81,7 +84,7 @@ class BloomPolicy(Policy):
def module_policy(self):
from transformers.models.bloom.modeling_bloom import BloomBlock, BloomModel
return {
base_policy = {
BloomBlock:
ModulePolicyDescription(
attribute_replacement={
@@ -99,7 +102,6 @@ class BloomPolicy(Policy):
SubModuleReplacementDescription(
suffix="self_attention.query_key_value",
target_module=col_nn.Linear1D_Col,
# kwargs={'n_fused': 3}
),
SubModuleReplacementDescription(
suffix="self_attention.dense",
@@ -132,6 +134,31 @@ class BloomPolicy(Policy):
])
}
# optimization configuration
if self.shard_config.enable_fused_normalization:
base_policy[BloomModel].sub_module_replacement.extend([
SubModuleReplacementDescription(
suffix="ln_f",
target_module=col_nn.FusedLayerNorm,
),
SubModuleReplacementDescription(
suffix="word_embeddings_layernorm",
target_module=col_nn.FusedLayerNorm,
)
])
base_policy[BloomBlock].sub_module_replacement.extend([
SubModuleReplacementDescription(
suffix="input_layernorm",
target_module=col_nn.FusedLayerNorm,
),
SubModuleReplacementDescription(
suffix="post_attention_layernorm",
target_module=col_nn.FusedLayerNorm,
)
])
return base_policy
def new_model_class(self):
# do nothing
return self.model

View File

@@ -9,6 +9,9 @@ from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDes
class GPT2Policy(Policy):
def config_sanity_check(self):
pass
def preprocess(self):
# reshape the embedding layer
r"""
@@ -22,7 +25,7 @@ class GPT2Policy(Policy):
return self.model
def module_policy(self):
return {
base_policy = {
GPT2Model:
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
@@ -77,6 +80,30 @@ class GPT2Policy(Policy):
])
}
# optimization configuration
if self.shard_config.enable_fused_normalization:
base_policy[GPT2Model].sub_module_replacement.append(
SubModuleReplacementDescription(
suffix="ln_f",
target_module=col_nn.FusedLayerNorm,
))
base_policy[GPT2Block].sub_module_replacement.extend([
SubModuleReplacementDescription(
suffix="ln_1",
target_module=col_nn.FusedLayerNorm,
),
SubModuleReplacementDescription(
suffix="ln_2",
target_module=col_nn.FusedLayerNorm,
),
SubModuleReplacementDescription(suffix="ln_cross_attn",
target_module=col_nn.FusedLayerNorm,
ignore_if_not_exist=True)
])
return base_policy
def new_model_class(self):
return self.model

View File

@@ -4,13 +4,16 @@ import torch.nn as nn
from transformers import LlamaForCausalLM, LlamaForSequenceClassification
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel
from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
class LlamaPolicy(Policy):
def config_sanity_check(self):
pass
def preprocess(self):
# Resize embedding
vocab_size = self.model.config.vocab_size
@@ -23,7 +26,7 @@ class LlamaPolicy(Policy):
return self.model
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
return {
base_policy = {
LlamaDecoderLayer:
ModulePolicyDescription(
attribute_replacement={
@@ -75,6 +78,27 @@ class LlamaPolicy(Policy):
])
}
# optimization configuration
if self.shard_config.enable_fused_normalization:
base_policy[LlamaDecoderLayer].sub_module_replacement.extend([
SubModuleReplacementDescription(
suffix="input_layernorm",
target_module=FusedRMSNorm,
),
SubModuleReplacementDescription(
suffix="post_attention_layernorm",
target_module=FusedRMSNorm,
)
])
base_policy[LlamaModel].sub_module_replacement.append(
SubModuleReplacementDescription(
suffix="norm",
target_module=FusedRMSNorm,
))
return base_policy
def new_model_class(self):
return None

View File

@@ -13,6 +13,9 @@ from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDes
class OPTPolicy(Policy):
def config_sanity_check(self):
pass
def preprocess(self):
# reshape the embedding layer
r"""
@@ -74,7 +77,9 @@ class OPTPolicy(Policy):
),
]),
}
if self.shard_config.fused_layernorm:
# optimization configuration
if self.shard_config.enable_fused_normalization:
base_policy[OPTDecoder].sub_module_replacement.append(
SubModuleReplacementDescription(suffix="final_layer_norm",
target_module=FusedLayerNorm,
@@ -87,6 +92,7 @@ class OPTPolicy(Policy):
target_module=FusedLayerNorm,
ignore_if_not_exist=True)
])
return base_policy
def new_model_class(self):

View File

@@ -9,7 +9,7 @@ from transformers.models.t5.modeling_t5 import (
T5Stack,
)
from colossalai.shardformer.layer import DropoutForParallelInput, Embedding1D, Linear1D_Col, Linear1D_Row
from colossalai.shardformer.layer import DropoutForParallelInput, Embedding1D, FusedRMSNorm, Linear1D_Col, Linear1D_Row
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
@@ -18,6 +18,9 @@ __all__ = ["T5ModelPolicy", "T5ForConditionalGenerationPolicy", "T5EncoderPolicy
class T5ModelPolicy(Policy):
def config_sanity_check(self):
pass
def preprocess(self):
# reshape the embedding layer
r"""
@@ -31,7 +34,7 @@ class T5ModelPolicy(Policy):
return self.model
def module_policy(self):
return {
base_policy = {
T5Stack:
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
@@ -139,6 +142,19 @@ class T5ModelPolicy(Policy):
])
}
# optimization configuration
if self.shard_config.enable_fused_normalization:
base_policy[T5LayerFF].sub_module_replacement.append(
SubModuleReplacementDescription(suffix="layer_norm", target_module=FusedRMSNorm))
base_policy[T5LayerSelfAttention].sub_module_replacement.append(
SubModuleReplacementDescription(suffix="layer_norm", target_module=FusedRMSNorm))
base_policy[T5LayerCrossAttention].sub_module_replacement.append(
SubModuleReplacementDescription(suffix="layer_norm", target_module=FusedRMSNorm))
base_policy[T5Stack].sub_module_replacement.append(
SubModuleReplacementDescription(suffix="final_layer_norm", target_module=FusedRMSNorm))
return base_policy
def new_model_class(self):
return None
@@ -167,4 +183,4 @@ class T5ForConditionalGenerationPolicy(T5ModelPolicy):
class T5EncoderPolicy(T5ModelPolicy):
pass

View File

@@ -3,13 +3,16 @@ from typing import Dict, Union
import torch.nn as nn
from transformers.models.vit.modeling_vit import ViTAttention, ViTEmbeddings, ViTLayer, ViTModel
from colossalai.shardformer.layer import DropoutForReplicatedInput, Linear1D_Col, Linear1D_Row
from colossalai.shardformer.layer import DropoutForReplicatedInput, FusedLayerNorm, Linear1D_Col, Linear1D_Row
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
class ViTPolicy(Policy):
def config_sanity_check(self):
pass
def preprocess(self):
# Resize embedding
vocab_size = self.model.config.vocab_size
@@ -22,7 +25,7 @@ class ViTPolicy(Policy):
return self.model
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
return {
base_policy = {
ViTEmbeddings:
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
@@ -80,6 +83,26 @@ class ViTPolicy(Policy):
]),
}
# optimization configuration
if self.shard_config.enable_fused_normalization:
base_policy[ViTAttention].sub_module_replacement.extend([
SubModuleReplacementDescription(
suffix="layernorm_before",
target_module=FusedLayerNorm,
),
SubModuleReplacementDescription(
suffix="layernorm_after",
target_module=FusedLayerNorm,
)
])
base_policy[ViTModel].sub_module_replacement.append(
SubModuleReplacementDescription(
suffix="layernorm",
target_module=FusedLayerNorm,
))
return base_policy
def new_model_class(self):
return None

View File

@@ -12,16 +12,10 @@ class ShardConfig:
Args:
tensor_parallel_size (int): The size of tensor parallel
use_mixedfusedLN (bool): Whether to use the `MixedFusedLayerNorm`
data_parallel_size (int): The size of data parallel
pipeline_parallel_size (int): The size of pipeline parallel
tensor_parallel_mode (List): The mode of tensor parallel, choose from `['1d','2d','2.5d','3d']
inference_only (bool): Whether to use the inference only mode, when setting to `True`, the model
will not calculate the loss and just return the output.
gather_output (bool): Whether to gather the output of the model of the last layer
enable_fused_normalization (bool): Whether to use fused layernorm, default is False
"""
tensor_parallel_size: int
fused_layernorm: bool = False
enable_fused_normalization: bool = False
# TODO: add support for tensor parallel
# pipeline_parallel_size: int

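Since the flag was renamed, existing call sites only need a one-line change; a before/after sketch (the parallel size is illustrative and the import path is an assumption):

from colossalai.shardformer import ShardConfig  # import path assumed

# before this commit:
#   shard_config = ShardConfig(tensor_parallel_size=2, fused_layernorm=True)
# after this commit:
shard_config = ShardConfig(tensor_parallel_size=2, enable_fused_normalization=True)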
View File

@@ -8,11 +8,11 @@ def build_model(world_size, model_fn):
org_model = model_fn().cuda()
# shard model
shard_config = ShardConfig(tensor_parallel_size=world_size, fused_layernorm=True)
shard_config = ShardConfig(tensor_parallel_size=world_size, enable_fused_normalization=True)
model_copy = copy.deepcopy(org_model)
shard_former = ShardFormer(shard_config=shard_config)
shard_former.init_distributed()
sharded_model = shard_former.shard_model(model_copy)
sharded_model = shard_former.shard_model(model_copy).cuda()
return org_model, sharded_model
@@ -33,4 +33,4 @@ def run_forward(original_model, sharded_model, data_gen_fn, output_transform_fn,
shard_output = sharded_model(**data)
shard_output = output_transform_fn(shard_output)
shard_loss = loss_fn(shard_output)
return org_output, org_loss, shard_output, shard_loss
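
Putting the pieces together, a minimal end-to-end sketch mirroring the test helper above (not part of the diff; it assumes a 2-GPU distributed environment has already been launched, that apex is installed, and that ShardConfig/ShardFormer are importable from colossalai.shardformer):

import copy
from transformers import BertConfig, BertForMaskedLM
from colossalai.shardformer import ShardConfig, ShardFormer  # import path assumed

org_model = BertForMaskedLM(BertConfig()).cuda()

# shard a copy with tensor parallelism and fused normalization enabled
shard_config = ShardConfig(tensor_parallel_size=2, enable_fused_normalization=True)
shard_former = ShardFormer(shard_config=shard_config)
shard_former.init_distributed()
sharded_model = shard_former.shard_model(copy.deepcopy(org_model)).cuda()

# if the Bert policy applied the replacement, this LayerNorm is now apex-fused
print(type(sharded_model.bert.encoder.layer[0].attention.output.LayerNorm))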