mirror of https://github.com/hpcaitech/ColossalAI
[shardformer] supported fused normalization (#4112)
parent
b1c2901530
commit
f3b6aaa6b7
|
@ -1,12 +1,12 @@
|
|||
from .dropout import DropoutForParallelInput, DropoutForReplicatedInput
|
||||
from .embedding import Embedding1D, VocabParallelEmbedding1D
|
||||
from .layernorm import FusedLayerNorm
|
||||
from .linear import Linear1D_Col, Linear1D_Row
|
||||
from .loss import cross_entropy_1d
|
||||
from .normalization import FusedLayerNorm, FusedRMSNorm
|
||||
from .qkv_fused_linear import GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row
|
||||
|
||||
__all__ = [
|
||||
"Embedding1D", "VocabParallelEmbedding1D", "Linear1D_Col", "Linear1D_Row", 'GPT2FusedLinearConv1D_Col',
|
||||
'GPT2FusedLinearConv1D_Row', 'DropoutForParallelInput', 'DropoutForReplicatedInput', "cross_entropy_1d",
|
||||
'FusedLayerNorm'
|
||||
'FusedLayerNorm', 'FusedRMSNorm'
|
||||
]
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
__all__ = ['FusedLayerNorm']
|
||||
__all__ = ['FusedLayerNorm', 'FusedRMSNorm']
|
||||
|
||||
FAST_LAYERNORM_SUPPORTED_SIZE = [
|
||||
1024, 1536, 2048, 2304, 3072, 3840, 4096, 5120, 6144, 8192, 10240, 12288, 12800, 15360, 16384, 18432, 20480, 24576,
|
||||
|
@ -61,4 +61,44 @@ class FusedLayerNorm():
|
|||
# copy weight and bias
|
||||
layernorm.weight.copy_(module.weight)
|
||||
layernorm.bias.copy_(module.bias)
|
||||
return layernorm
|
||||
return layernorm
|
||||
|
||||
|
||||
class FusedRMSNorm():
|
||||
"""
|
||||
This is a wrapper around the apex fused rms norm implementation. It is meant to be used only with the from_native_module interface.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
raise NotImplementedError(
|
||||
'FusedRMSNorm is not implemented as a physical class. '
|
||||
'It is meant to be used only with the from_native_module interface to wrap the fused rms norm implementation provided by apex.'
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def from_native_module(module: nn.Module, *args, **kwargs) -> nn.Module:
|
||||
try:
|
||||
from apex.normalization import FusedRMSNorm as ApexFusedRMSNorm
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
'Please install apex from source (https://github.com/NVIDIA/apex) to use the fused RMS normalization kernel'
|
||||
)
|
||||
|
||||
# to check if it is huggingface LlamaRMSNorm
|
||||
if module.__class__.__name__ == "LlamaRMSNorm":
|
||||
normalized_shape = module.weight.shape[0]
|
||||
eps = module.variance_epsilon
|
||||
elementwise_affine = True
|
||||
else:
|
||||
# get the attributes of the module
|
||||
normalized_shape = module.normalized_shape
|
||||
eps = module.eps
|
||||
elementwise_affine = module.elementwise_affine
|
||||
|
||||
rmsnorm = ApexFusedRMSNorm(normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine)
|
||||
|
||||
with torch.no_grad():
|
||||
# copy weight and bias
|
||||
rmsnorm.weight.copy_(module.weight)
|
||||
|
||||
return rmsnorm
|
|
@ -98,6 +98,14 @@ class Policy(ABC):
|
|||
shard_config (:class:`ShardConfig`): The shard config to be perform
|
||||
"""
|
||||
self.shard_config = shard_config
|
||||
self.config_sanity_check()
|
||||
|
||||
@abstractmethod
|
||||
def config_sanity_check(self):
|
||||
"""
|
||||
Check if the shard config is valid for the model. Raise an exception if the config is invalid.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def preprocess(self) -> nn.Module:
|
||||
|
|
|
@ -16,6 +16,9 @@ from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDes
|
|||
|
||||
class BertPolicy(Policy):
|
||||
|
||||
def config_sanity_check(self):
|
||||
pass
|
||||
|
||||
def preprocess(self):
|
||||
# reshape the embedding layer
|
||||
r"""
|
||||
|
@ -99,7 +102,8 @@ class BertPolicy(Policy):
|
|||
])
|
||||
}
|
||||
|
||||
if self.shard_config.fused_layernorm:
|
||||
# optimization configuration
|
||||
if self.shard_config.enable_fused_normalization:
|
||||
base_policy[BertLayer].sub_module_replacement.append(
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attention.output.LayerNorm",
|
||||
|
@ -150,12 +154,16 @@ class BertForPretrainingPolicy(BertPolicy):
|
|||
kwargs={"gather_output": True}),
|
||||
])
|
||||
}
|
||||
if self.shard_config.fused_layernorm:
|
||||
|
||||
# optimization configuration
|
||||
if self.shard_config.enable_fused_normalization:
|
||||
addon_module[BertLMPredictionHead].sub_module_replacement.append(
|
||||
SubModuleReplacementDescription(
|
||||
suffix="transform.LayerNorm",
|
||||
target_module=col_nn.FusedLayerNorm,
|
||||
))
|
||||
|
||||
# append extra policy
|
||||
module_policy.update(addon_module)
|
||||
return module_policy
|
||||
|
||||
|
@ -187,7 +195,7 @@ class BertLMHeadModelPolicy(BertPolicy):
|
|||
kwargs={"gather_output": True}),
|
||||
])
|
||||
}
|
||||
if self.shard_config.fused_layernorm:
|
||||
if self.shard_config.enable_fused_normalization:
|
||||
addon_module[BertLMPredictionHead].sub_module_replacement.append(
|
||||
SubModuleReplacementDescription(
|
||||
suffix="transform.LayerNorm",
|
||||
|
@ -224,12 +232,15 @@ class BertForMaskedLMPolicy(BertPolicy):
|
|||
kwargs={"gather_output": True}),
|
||||
])
|
||||
}
|
||||
if self.shard_config.fused_layernorm:
|
||||
|
||||
# optimization configuration
|
||||
if self.shard_config.enable_fused_normalization:
|
||||
addon_module[BertLMPredictionHead].sub_module_replacement.append(
|
||||
SubModuleReplacementDescription(
|
||||
suffix="transform.LayerNorm",
|
||||
target_module=col_nn.FusedLayerNorm,
|
||||
))
|
||||
|
||||
module_policy.update(addon_module)
|
||||
return module_policy
|
||||
|
||||
|
@ -316,4 +327,4 @@ class BertForMultipleChoicePolicy(BertPolicy):
|
|||
])
|
||||
}
|
||||
module_policy.update(addon_module)
|
||||
return module_policy
|
||||
return module_policy
|
||||
|
|
|
@ -65,6 +65,9 @@ def build_bloom_alibi_tensor(self, attention_mask: torch.Tensor, num_heads: int,
|
|||
|
||||
class BloomPolicy(Policy):
|
||||
|
||||
def config_sanity_check(self):
|
||||
pass
|
||||
|
||||
def preprocess(self):
|
||||
# reshape the embedding layer
|
||||
r"""
|
||||
|
@ -81,7 +84,7 @@ class BloomPolicy(Policy):
|
|||
def module_policy(self):
|
||||
from transformers.models.bloom.modeling_bloom import BloomBlock, BloomModel
|
||||
|
||||
return {
|
||||
base_policy = {
|
||||
BloomBlock:
|
||||
ModulePolicyDescription(
|
||||
attribute_replacement={
|
||||
|
@ -99,7 +102,6 @@ class BloomPolicy(Policy):
|
|||
SubModuleReplacementDescription(
|
||||
suffix="self_attention.query_key_value",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
# kwargs={'n_fused': 3}
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attention.dense",
|
||||
|
@ -132,6 +134,31 @@ class BloomPolicy(Policy):
|
|||
])
|
||||
}
|
||||
|
||||
# optimization configuration
|
||||
if self.shard_config.enable_fused_normalization:
|
||||
base_policy[BloomModel].sub_module_replacement.extend([
|
||||
SubModuleReplacementDescription(
|
||||
suffix="ln_f",
|
||||
target_module=col_nn.FusedLayerNorm,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="word_embeddings_layernorm",
|
||||
target_module=col_nn.FusedLayerNorm,
|
||||
)
|
||||
])
|
||||
base_policy[BloomBlock].sub_module_replacement.extend([
|
||||
SubModuleReplacementDescription(
|
||||
suffix="input_layernorm",
|
||||
target_module=col_nn.FusedLayerNorm,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="post_attention_layernorm",
|
||||
target_module=col_nn.FusedLayerNorm,
|
||||
)
|
||||
])
|
||||
|
||||
return base_policy
|
||||
|
||||
def new_model_class(self):
|
||||
# do nothing
|
||||
return self.model
|
||||
|
|
|
@ -9,6 +9,9 @@ from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDes
|
|||
|
||||
class GPT2Policy(Policy):
|
||||
|
||||
def config_sanity_check(self):
|
||||
pass
|
||||
|
||||
def preprocess(self):
|
||||
# reshape the embedding layer
|
||||
r"""
|
||||
|
@ -22,7 +25,7 @@ class GPT2Policy(Policy):
|
|||
return self.model
|
||||
|
||||
def module_policy(self):
|
||||
return {
|
||||
base_policy = {
|
||||
GPT2Model:
|
||||
ModulePolicyDescription(attribute_replacement={},
|
||||
param_replacement=[],
|
||||
|
@ -77,6 +80,30 @@ class GPT2Policy(Policy):
|
|||
])
|
||||
}
|
||||
|
||||
# optimization configuration
|
||||
if self.shard_config.enable_fused_normalization:
|
||||
base_policy[GPT2Model].sub_module_replacement.append(
|
||||
SubModuleReplacementDescription(
|
||||
suffix="ln_f",
|
||||
target_module=col_nn.FusedLayerNorm,
|
||||
))
|
||||
|
||||
base_policy[GPT2Block].sub_module_replacement.extend([
|
||||
SubModuleReplacementDescription(
|
||||
suffix="ln_1",
|
||||
target_module=col_nn.FusedLayerNorm,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="ln_2",
|
||||
target_module=col_nn.FusedLayerNorm,
|
||||
),
|
||||
SubModuleReplacementDescription(suffix="ln_cross_attn",
|
||||
target_module=col_nn.FusedLayerNorm,
|
||||
ignore_if_not_exist=True)
|
||||
])
|
||||
|
||||
return base_policy
|
||||
|
||||
def new_model_class(self):
|
||||
return self.model
|
||||
|
||||
|
|
|
@ -4,13 +4,16 @@ import torch.nn as nn
|
|||
from transformers import LlamaForCausalLM, LlamaForSequenceClassification
|
||||
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel
|
||||
|
||||
from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
|
||||
from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
|
||||
|
||||
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||
|
||||
|
||||
class LlamaPolicy(Policy):
|
||||
|
||||
def config_sanity_check(self):
|
||||
pass
|
||||
|
||||
def preprocess(self):
|
||||
# Resize embedding
|
||||
vocab_size = self.model.config.vocab_size
|
||||
|
@ -23,7 +26,7 @@ class LlamaPolicy(Policy):
|
|||
return self.model
|
||||
|
||||
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
|
||||
return {
|
||||
base_policy = {
|
||||
LlamaDecoderLayer:
|
||||
ModulePolicyDescription(
|
||||
attribute_replacement={
|
||||
|
@ -75,6 +78,27 @@ class LlamaPolicy(Policy):
|
|||
])
|
||||
}
|
||||
|
||||
# optimization configuration
|
||||
if self.shard_config.enable_fused_normalization:
|
||||
base_policy[LlamaDecoderLayer].sub_module_replacement.extend([
|
||||
SubModuleReplacementDescription(
|
||||
suffix="input_layernorm",
|
||||
target_module=FusedRMSNorm,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="post_attention_layernorm",
|
||||
target_module=FusedRMSNorm,
|
||||
)
|
||||
])
|
||||
|
||||
base_policy[LlamaModel].sub_module_replacement.append(
|
||||
SubModuleReplacementDescription(
|
||||
suffix="norm",
|
||||
target_module=FusedRMSNorm,
|
||||
))
|
||||
|
||||
return base_policy
|
||||
|
||||
def new_model_class(self):
|
||||
return None
|
||||
|
||||
|
|
|
@ -13,6 +13,9 @@ from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDes
|
|||
|
||||
class OPTPolicy(Policy):
|
||||
|
||||
def config_sanity_check(self):
|
||||
pass
|
||||
|
||||
def preprocess(self):
|
||||
# reshape the embedding layer
|
||||
r"""
|
||||
|
@ -74,7 +77,9 @@ class OPTPolicy(Policy):
|
|||
),
|
||||
]),
|
||||
}
|
||||
if self.shard_config.fused_layernorm:
|
||||
|
||||
# optimization configuration
|
||||
if self.shard_config.enable_fused_normalization:
|
||||
base_policy[OPTDecoder].sub_module_replacement.append(
|
||||
SubModuleReplacementDescription(suffix="final_layer_norm",
|
||||
target_module=FusedLayerNorm,
|
||||
|
@ -87,6 +92,7 @@ class OPTPolicy(Policy):
|
|||
target_module=FusedLayerNorm,
|
||||
ignore_if_not_exist=True)
|
||||
])
|
||||
|
||||
return base_policy
|
||||
|
||||
def new_model_class(self):
|
||||
|
|
|
@ -9,7 +9,7 @@ from transformers.models.t5.modeling_t5 import (
|
|||
T5Stack,
|
||||
)
|
||||
|
||||
from colossalai.shardformer.layer import DropoutForParallelInput, Embedding1D, Linear1D_Col, Linear1D_Row
|
||||
from colossalai.shardformer.layer import DropoutForParallelInput, Embedding1D, FusedRMSNorm, Linear1D_Col, Linear1D_Row
|
||||
|
||||
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||
|
||||
|
@ -18,6 +18,9 @@ __all__ = ["T5ModelPolicy", "T5ForConditionalGenerationPolicy", "T5EncoderPolicy
|
|||
|
||||
class T5ModelPolicy(Policy):
|
||||
|
||||
def config_sanity_check(self):
|
||||
pass
|
||||
|
||||
def preprocess(self):
|
||||
# reshape the embedding layer
|
||||
r"""
|
||||
|
@ -31,7 +34,7 @@ class T5ModelPolicy(Policy):
|
|||
return self.model
|
||||
|
||||
def module_policy(self):
|
||||
return {
|
||||
base_policy = {
|
||||
T5Stack:
|
||||
ModulePolicyDescription(attribute_replacement={},
|
||||
param_replacement=[],
|
||||
|
@ -139,6 +142,19 @@ class T5ModelPolicy(Policy):
|
|||
])
|
||||
}
|
||||
|
||||
# optimization configuration
|
||||
if self.shard_config.enable_fused_normalization:
|
||||
base_policy[T5LayerFF].sub_module_replacement.append(
|
||||
SubModuleReplacementDescription(suffix="layer_norm", target_module=FusedRMSNorm))
|
||||
base_policy[T5LayerSelfAttention].sub_module_replacement.append(
|
||||
SubModuleReplacementDescription(suffix="layer_norm", target_module=FusedRMSNorm))
|
||||
base_policy[T5LayerCrossAttention].sub_module_replacement.append(
|
||||
SubModuleReplacementDescription(suffix="layer_norm", target_module=FusedRMSNorm))
|
||||
base_policy[T5Stack].sub_module_replacement.append(
|
||||
SubModuleReplacementDescription(suffix="final_layer_norm", target_module=FusedRMSNorm))
|
||||
|
||||
return base_policy
|
||||
|
||||
def new_model_class(self):
|
||||
return None
|
||||
|
||||
|
@ -167,4 +183,4 @@ class T5ForConditionalGenerationPolicy(T5ModelPolicy):
|
|||
|
||||
|
||||
class T5EncoderPolicy(T5ModelPolicy):
|
||||
pass
|
||||
pass
|
||||
|
|
|
@ -3,13 +3,16 @@ from typing import Dict, Union
|
|||
import torch.nn as nn
|
||||
from transformers.models.vit.modeling_vit import ViTAttention, ViTEmbeddings, ViTLayer, ViTModel
|
||||
|
||||
from colossalai.shardformer.layer import DropoutForReplicatedInput, Linear1D_Col, Linear1D_Row
|
||||
from colossalai.shardformer.layer import DropoutForReplicatedInput, FusedLayerNorm, Linear1D_Col, Linear1D_Row
|
||||
|
||||
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||
|
||||
|
||||
class ViTPolicy(Policy):
|
||||
|
||||
def config_sanity_check(self):
|
||||
pass
|
||||
|
||||
def preprocess(self):
|
||||
# Resize embedding
|
||||
vocab_size = self.model.config.vocab_size
|
||||
|
@ -22,7 +25,7 @@ class ViTPolicy(Policy):
|
|||
return self.model
|
||||
|
||||
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
|
||||
return {
|
||||
base_policy = {
|
||||
ViTEmbeddings:
|
||||
ModulePolicyDescription(attribute_replacement={},
|
||||
param_replacement=[],
|
||||
|
@ -80,6 +83,26 @@ class ViTPolicy(Policy):
|
|||
]),
|
||||
}
|
||||
|
||||
# optimization configuration
|
||||
if self.shard_config.enable_fused_normalization:
|
||||
base_policy[ViTAttention].sub_module_replacement.extend([
|
||||
SubModuleReplacementDescription(
|
||||
suffix="layernorm_before",
|
||||
target_module=FusedLayerNorm,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="layernorm_after",
|
||||
target_module=FusedLayerNorm,
|
||||
)
|
||||
])
|
||||
base_policy[ViTModel].sub_module_replacement.append(
|
||||
SubModuleReplacementDescription(
|
||||
suffix="layernorm",
|
||||
target_module=FusedLayerNorm,
|
||||
))
|
||||
|
||||
return base_policy
|
||||
|
||||
def new_model_class(self):
|
||||
return None
|
||||
|
||||
|
|
|
@ -12,16 +12,10 @@ class ShardConfig:
|
|||
|
||||
Args:
|
||||
tensor_parallel_size (int): The size of tensor parallel
|
||||
use_mixedfusedLN (bool): Whether to use the `MixedFusedLayerNorm`
|
||||
data_parallel_size (int): The size of data parallel
|
||||
pipeline_parallel_size (int): The size of pipeline parallel
|
||||
tensor_parallel_mode (List): The mode of tensor parallel, choose from `['1d','2d','2.5d','3d']
|
||||
inference_only (bool): Whether to use the inference only mode, when setting to `True`, the model
|
||||
will not calculate the loss and just return the output.
|
||||
gather_output (bool): Whether to gather the output of the model of the last layer
|
||||
enable_fused_normalization (bool): Whether to use fused layernorm, default is False
|
||||
"""
|
||||
tensor_parallel_size: int
|
||||
fused_layernorm: bool = False
|
||||
enable_fused_normalization: bool = False
|
||||
|
||||
# TODO: add support for tensor parallel
|
||||
# pipeline_parallel_size: int
|
||||
|
|
|
@ -8,11 +8,11 @@ def build_model(world_size, model_fn):
|
|||
org_model = model_fn().cuda()
|
||||
|
||||
# shard model
|
||||
shard_config = ShardConfig(tensor_parallel_size=world_size, fused_layernorm=True)
|
||||
shard_config = ShardConfig(tensor_parallel_size=world_size, enable_fused_normalization=True)
|
||||
model_copy = copy.deepcopy(org_model)
|
||||
shard_former = ShardFormer(shard_config=shard_config)
|
||||
shard_former.init_distributed()
|
||||
sharded_model = shard_former.shard_model(model_copy)
|
||||
sharded_model = shard_former.shard_model(model_copy).cuda()
|
||||
|
||||
return org_model, sharded_model
|
||||
|
||||
|
@ -33,4 +33,4 @@ def run_forward(original_model, sharded_model, data_gen_fn, output_transform_fn,
|
|||
shard_output = sharded_model(**data)
|
||||
shard_output = output_transform_fn(shard_output)
|
||||
shard_loss = loss_fn(shard_output)
|
||||
return org_output, org_loss, shard_output, shard_loss
|
||||
return org_output, org_loss, shard_output, shard_loss
|
||||
|
|
Loading…
Reference in New Issue