|
|
|
import torch.nn as nn
|
|
|
|
|
|
|
|
import colossalai.shardformer.layer as col_nn
|
|
|
|
|
|
|
|
from .._utils import getattr_, setattr_
|
|
|
|
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
|
|
|
|
|
|
|
# Public policy classes exported by this module, one per supported BERT variant.
__all__ = [
    'BertPolicy',
    'BertModelPolicy',
    'BertForPretrainingPolicy',
    'BertLMHeadModelPolicy',
    'BertForMaskedLMPolicy',
    'BertForNextSentencePredictionPolicy',
    'BertForSequenceClassificationPolicy',
    'BertForTokenClassificationPolicy',
    'BertForMultipleChoicePolicy',
]
|
|
|
|
|
|
|
|
|
|
|
|
class BertPolicy(Policy):
    """Base sharding policy shared by all BERT policies in this module.

    Pads the vocabulary so it divides evenly by the tensor-parallel size,
    describes which ``BertLayer`` / ``BertEmbeddings`` sub-modules are swapped
    for ``colossalai.shardformer.layer`` counterparts, and optionally swaps
    LayerNorms for ``FusedLayerNorm``. Task-specific subclasses build on
    ``module_policy`` and ``add_lm_head_policy``.
    """

    def config_sanity_check(self):
        # No BERT-specific configuration checks are performed.
        pass

    def preprocess(self):
        """Pad the token-embedding (vocabulary) dimension to a multiple of the TP size.

        When ``config.vocab_size`` is not divisible by
        ``shard_config.tensor_parallel_size``, the model's token embeddings are
        resized up to the next multiple via ``resize_token_embeddings``.

        Returns:
            The (possibly resized) model.
        """
        vocab_size = self.model.config.vocab_size
        world_size = self.shard_config.tensor_parallel_size
        remainder = vocab_size % world_size
        if remainder:
            # Round vocab_size up to the next multiple of world_size.
            self.model.resize_token_embeddings(vocab_size + world_size - remainder)
        return self.model

    def module_policy(self):
        """Build the replacement policy for ``BertLayer`` and ``BertEmbeddings``.

        Returns:
            dict mapping module classes to ``ModulePolicyDescription`` objects.
        """
        from transformers.models.bert.modeling_bert import BertEmbeddings, BertLayer

        policy = {}

        if self.shard_config.enable_tensor_parallelism:
            tp_size = self.shard_config.tensor_parallel_size
            config = self.model.config
            # Per-layer head bookkeeping is divided by the TP size to match the
            # column-sharded query/key/value projections below.
            sharded_head_size = config.hidden_size // tp_size
            sharded_num_heads = config.num_attention_heads // tp_size

            # (suffix, replacement class) pairs for each BertLayer, in the order
            # they are applied.
            layer_replacements = (
                ("attention.self.query", col_nn.Linear1D_Col),
                ("attention.self.key", col_nn.Linear1D_Col),
                ("attention.self.value", col_nn.Linear1D_Col),
                ("attention.self.dropout", col_nn.DropoutForParallelInput),
                ("attention.output.dense", col_nn.Linear1D_Row),
                ("attention.output.dropout", col_nn.DropoutForParallelInput),
                ("intermediate.dense", col_nn.Linear1D_Col),
                ("output.dense", col_nn.Linear1D_Row),
                ("output.dropout", col_nn.DropoutForParallelInput),
            )
            # NOTE(review): the crossattention.* head attributes are rescaled, but
            # no crossattention.* linear layers appear in layer_replacements —
            # confirm this is intended for configs with cross-attention enabled.
            policy[BertLayer] = ModulePolicyDescription(
                attribute_replacement={
                    "attention.self.all_head_size": sharded_head_size,
                    "crossattention.self.all_head_size": sharded_head_size,
                    "attention.self.num_attention_heads": sharded_num_heads,
                    "crossattention.self.num_attention_heads": sharded_num_heads,
                },
                sub_module_replacement=[
                    SubModuleReplacementDescription(suffix=suffix, target_module=target)
                    for suffix, target in layer_replacements
                ],
            )

            # Vocab-parallel word embedding; embedding dropout uses the
            # replicated-input dropout variant.
            policy[BertEmbeddings] = ModulePolicyDescription(sub_module_replacement=[
                SubModuleReplacementDescription(
                    suffix="word_embeddings",
                    target_module=col_nn.VocabParallelEmbedding1D,
                ),
                SubModuleReplacementDescription(
                    suffix="dropout",
                    target_module=col_nn.DropoutForReplicatedInput,
                ),
            ])

        if self.shard_config.enable_fused_normalization:
            fused_norm = col_nn.FusedLayerNorm
            # LayerNorms inside each BertLayer.
            self.append_or_create_submodule_replacement(
                description=[
                    SubModuleReplacementDescription(suffix=suffix, target_module=fused_norm)
                    for suffix in ("attention.output.LayerNorm", "output.LayerNorm")
                ],
                policy=policy,
                target_key=BertLayer,
            )
            # LayerNorm of the embedding block.
            self.append_or_create_submodule_replacement(
                description=[SubModuleReplacementDescription(suffix="LayerNorm", target_module=fused_norm)],
                policy=policy,
                target_key=BertEmbeddings,
            )

        return policy

    def add_lm_head_policy(self, base_policy):
        """Extend ``base_policy`` with replacements for ``BertLMPredictionHead``.

        Args:
            base_policy: policy dict (typically from ``module_policy``) that is
                extended in place via ``append_or_create_submodule_replacement``.

        Returns:
            ``base_policy``.
        """
        from transformers.models.bert.modeling_bert import BertLMPredictionHead

        shard_config = self.shard_config

        if shard_config.enable_tensor_parallelism:
            # Column-parallel decoder, created with gather_output=True
            # (presumably so callers receive unsharded logits).
            decoder_replacement = SubModuleReplacementDescription(
                suffix="decoder",
                target_module=col_nn.Linear1D_Col,
                kwargs={"gather_output": True},
            )
            self.append_or_create_submodule_replacement(
                description=decoder_replacement,
                policy=base_policy,
                target_key=BertLMPredictionHead,
            )

        if shard_config.enable_fused_normalization:
            norm_replacement = SubModuleReplacementDescription(
                suffix="transform.LayerNorm",
                target_module=col_nn.FusedLayerNorm,
            )
            self.append_or_create_submodule_replacement(
                description=norm_replacement,
                policy=base_policy,
                target_key=BertLMPredictionHead,
            )

        return base_policy

    def postprocess(self):
        # No post-sharding fix-ups for the base policy.
        return self.model
|
|
|
|
|
|
|
|
|
|
|
|
# BertModel
|
|
|
|
class BertModelPolicy(BertPolicy):
    """Policy for the bare ``BertModel``; uses the ``BertPolicy`` behavior as-is."""

    def __init__(self) -> None:
        super(BertModelPolicy, self).__init__()
|
|
|
|
|
|
|
|
|
|
|
|
# BertForPreTraining
|
|
|
|
class BertForPretrainingPolicy(BertPolicy):
    """Policy for ``BertForPreTraining``: base encoder policy plus the LM prediction head."""

    def __init__(self) -> None:
        super().__init__()

    def module_policy(self):
        """Return the base BERT policy extended with the LM-head replacements."""
        return self.add_lm_head_policy(super().module_policy())

    def postprocess(self):
        """Point the decoder weight at the word-embedding weight again.

        Sub-module replacement may leave the two parameters as distinct objects,
        so the decoder is re-bound to the embedding parameter.

        Returns:
            The model with the tie re-established.
        """
        shared_weight = getattr_(self.model, "bert.embeddings.word_embeddings.weight")
        setattr_(self.model, "cls.predictions.decoder.weight", shared_weight)
        return self.model
|
|
|
|
|
|
|
|
|
|
|
|
# BertLMHeadModel
|
|
|
|
class BertLMHeadModelPolicy(BertPolicy):
    """Policy for ``BertLMHeadModel``: base encoder policy plus the LM prediction head."""

    def __init__(self) -> None:
        super().__init__()

    def module_policy(self):
        """Return the base BERT policy extended with the LM-head replacements."""
        return self.add_lm_head_policy(super().module_policy())

    def postprocess(self):
        """Point the decoder weight at the word-embedding weight again.

        Sub-module replacement may leave the two parameters as distinct objects,
        so the decoder is re-bound to the embedding parameter.

        Returns:
            The model with the tie re-established.
        """
        shared_weight = getattr_(self.model, "bert.embeddings.word_embeddings.weight")
        setattr_(self.model, "cls.predictions.decoder.weight", shared_weight)
        return self.model
|
|
|
|
|
|
|
|
|
|
|
|
# BertForMaskedLM
|
|
|
|
class BertForMaskedLMPolicy(BertPolicy):
    """Policy for ``BertForMaskedLM``: base encoder policy plus the LM prediction head."""

    def __init__(self) -> None:
        super().__init__()

    def module_policy(self):
        """Return the base BERT policy extended with the LM-head replacements."""
        return self.add_lm_head_policy(super().module_policy())

    def postprocess(self):
        """Point the decoder weight at the word-embedding weight again.

        Sub-module replacement may leave the two parameters as distinct objects,
        so the decoder is re-bound to the embedding parameter.

        Returns:
            The model with the tie re-established.
        """
        shared_weight = getattr_(self.model, "bert.embeddings.word_embeddings.weight")
        setattr_(self.model, "cls.predictions.decoder.weight", shared_weight)
        return self.model
|
|
|
|
|
|
|
|
|
|
|
|
# BertForSequenceClassification
|
|
|
|
class BertForSequenceClassificationPolicy(BertPolicy):
    """Policy for ``BertForSequenceClassification``."""

    def __init__(self) -> None:
        super().__init__()

    def module_policy(self):
        """Return the base BERT policy, plus a dropout replacement under tensor parallelism."""
        from transformers.models.bert.modeling_bert import BertForSequenceClassification

        module_policy = super().module_policy()
        if not self.shard_config.enable_tensor_parallelism:
            return module_policy

        # Replace the model's top-level `dropout` sub-module with the parallel-input variant.
        module_policy[BertForSequenceClassification] = ModulePolicyDescription(sub_module_replacement=[
            SubModuleReplacementDescription(suffix="dropout", target_module=col_nn.DropoutForParallelInput),
        ])
        return module_policy
|
|
|
|
|
|
|
|
|
|
|
|
# BertForTokenClassification
|
|
|
|
class BertForTokenClassificationPolicy(BertPolicy):
    """Policy for ``BertForTokenClassification``."""

    def __init__(self) -> None:
        super().__init__()

    def module_policy(self):
        """Return the base BERT policy, plus a dropout replacement under tensor parallelism."""
        from transformers.models.bert.modeling_bert import BertForTokenClassification

        module_policy = super().module_policy()
        if not self.shard_config.enable_tensor_parallelism:
            return module_policy

        # Replace the model's top-level `dropout` sub-module with the parallel-input variant.
        module_policy[BertForTokenClassification] = ModulePolicyDescription(sub_module_replacement=[
            SubModuleReplacementDescription(suffix="dropout", target_module=col_nn.DropoutForParallelInput),
        ])
        return module_policy
|
|
|
|
|
|
|
|
|
|
|
|
# BertForNextSentencePrediction
|
|
|
|
class BertForNextSentencePredictionPolicy(BertPolicy):
    """Policy for ``BertForNextSentencePrediction``; adds nothing beyond ``BertPolicy``."""

    def __init__(self) -> None:
        super(BertForNextSentencePredictionPolicy, self).__init__()
|
|
|
|
|
|
|
|
|
|
|
|
# BertForMultipleChoice
|
|
|
|
class BertForMultipleChoicePolicy(BertPolicy):
    """Policy for ``BertForMultipleChoice``."""

    def __init__(self) -> None:
        super().__init__()

    def module_policy(self):
        """Return the base BERT policy, plus a dropout replacement under tensor parallelism."""
        from transformers.models.bert.modeling_bert import BertForMultipleChoice

        module_policy = super().module_policy()
        if not self.shard_config.enable_tensor_parallelism:
            return module_policy

        # Replace the model's top-level `dropout` sub-module with the parallel-input variant.
        module_policy[BertForMultipleChoice] = ModulePolicyDescription(sub_module_replacement=[
            SubModuleReplacementDescription(suffix="dropout", target_module=col_nn.DropoutForParallelInput),
        ])
        return module_policy
|