You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
ColossalAI/colossalai/shardformer/policies/bert.py

294 lines
12 KiB

import torch.nn as nn
import colossalai.shardformer.layer as col_nn
from .._utils import getattr_, setattr_
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = [
'BertPolicy', 'BertModelPolicy', 'BertForPretrainingPolicy', 'BertLMHeadModelPolicy', 'BertForMaskedLMPolicy',
'BertForNextSentencePredictionPolicy', 'BertForSequenceClassificationPolicy', 'BertForTokenClassificationPolicy',
'BertForMultipleChoicePolicy'
]
class BertPolicy(Policy):
def config_sanity_check(self):
pass
def preprocess(self):
# reshape the embedding layer
r"""
Reshape the Embedding layer to make the embedding dimension divisible by world_size
"""
# TODO:
vocab_size = self.model.config.vocab_size
world_size = self.shard_config.tensor_parallel_size
if vocab_size % world_size != 0:
new_vocab_size = vocab_size + world_size - vocab_size % world_size
self.model.resize_token_embeddings(new_vocab_size)
return self.model
def module_policy(self):
from transformers.models.bert.modeling_bert import BertEmbeddings, BertLayer
policy = {}
if self.shard_config.enable_tensor_parallelism:
policy[BertLayer] = ModulePolicyDescription(attribute_replacement={
"attention.self.all_head_size":
self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"crossattention.self.all_head_size":
self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"attention.self.num_attention_heads":
self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
"crossattention.self.num_attention_heads":
self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
},
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="attention.self.query",
target_module=col_nn.Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="attention.self.key",
target_module=col_nn.Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="attention.self.value",
target_module=col_nn.Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="attention.self.dropout",
target_module=col_nn.DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="attention.output.dense",
target_module=col_nn.Linear1D_Row,
),
SubModuleReplacementDescription(
suffix="attention.output.dropout",
target_module=col_nn.DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="intermediate.dense",
target_module=col_nn.Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="output.dense",
target_module=col_nn.Linear1D_Row,
),
SubModuleReplacementDescription(
suffix="output.dropout",
target_module=col_nn.DropoutForParallelInput,
)
])
policy[BertEmbeddings] = ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="word_embeddings",
target_module=col_nn.VocabParallelEmbedding1D,
),
SubModuleReplacementDescription(
suffix="dropout",
target_module=col_nn.DropoutForReplicatedInput,
)
])
# optimization configuration
if self.shard_config.enable_fused_normalization:
# Handle bert layer
self.append_or_create_submodule_replacement(description=[
SubModuleReplacementDescription(
suffix="attention.output.LayerNorm",
target_module=col_nn.FusedLayerNorm,
),
SubModuleReplacementDescription(
suffix="output.LayerNorm",
target_module=col_nn.FusedLayerNorm,
)
],
policy=policy,
target_key=BertLayer)
# handle embedding layer
self.append_or_create_submodule_replacement(
description=[SubModuleReplacementDescription(
suffix="LayerNorm",
target_module=col_nn.FusedLayerNorm,
)],
policy=policy,
target_key=BertEmbeddings)
return policy
def add_lm_head_policy(self, base_policy):
from transformers.models.bert.modeling_bert import BertLMPredictionHead
# optimize for tensor parallelism
if self.shard_config.enable_tensor_parallelism:
self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
suffix="decoder", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}),
policy=base_policy,
target_key=BertLMPredictionHead)
# optimize with fused normalization
if self.shard_config.enable_fused_normalization:
# Handle bert lm prediction head
self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
suffix="transform.LayerNorm",
target_module=col_nn.FusedLayerNorm,
),
policy=base_policy,
target_key=BertLMPredictionHead)
return base_policy
def postprocess(self):
return self.model
# BertModel
class BertModelPolicy(BertPolicy):
def __init__(self) -> None:
super().__init__()
# BertForPreTraining
class BertForPretrainingPolicy(BertPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
module_policy = super().module_policy()
module_policy = self.add_lm_head_policy(module_policy)
return module_policy
def postprocess(self):
binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"}
for k, v in binding_map.items():
param = getattr_(self.model, k)
setattr_(self.model, v, param)
return self.model
# BertLMHeadModel
class BertLMHeadModelPolicy(BertPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
module_policy = super().module_policy()
module_policy = self.add_lm_head_policy(module_policy)
return module_policy
def postprocess(self):
binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"}
for k, v in binding_map.items():
param = getattr_(self.model, k)
setattr_(self.model, v, param)
return self.model
# BertForMaskedLM
class BertForMaskedLMPolicy(BertPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
module_policy = super().module_policy()
module_policy = self.add_lm_head_policy(module_policy)
return module_policy
def postprocess(self):
binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"}
for k, v in binding_map.items():
param = getattr_(self.model, k)
setattr_(self.model, v, param)
return self.model
# BertForSequenceClassification
class BertForSequenceClassificationPolicy(BertPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
from transformers.models.bert.modeling_bert import BertForSequenceClassification
module_policy = super().module_policy()
if self.shard_config.enable_tensor_parallelism:
addon_module = {
BertForSequenceClassification:
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=col_nn.DropoutForParallelInput,
)
])
}
module_policy.update(addon_module)
return module_policy
# BertForTokenClassification
class BertForTokenClassificationPolicy(BertPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
from transformers.models.bert.modeling_bert import BertForTokenClassification
module_policy = super().module_policy()
if self.shard_config.enable_tensor_parallelism:
addon_module = {
BertForTokenClassification:
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=col_nn.DropoutForParallelInput,
)
])
}
module_policy.update(addon_module)
return module_policy
# BertForNextSentencePrediction
class BertForNextSentencePredictionPolicy(BertPolicy):
def __init__(self) -> None:
super().__init__()
# BertForMultipleChoice
class BertForMultipleChoicePolicy(BertPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
from transformers.models.bert.modeling_bert import BertForMultipleChoice
module_policy = super().module_policy()
if self.shard_config.enable_tensor_parallelism:
addon_module = {
BertForMultipleChoice:
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=col_nn.DropoutForParallelInput,
)
])
}
module_policy.update(addon_module)
return module_policy