support bert with new api

pull/4157/head
FoolPlayer 2023-06-16 16:12:27 +08:00 committed by Frank Lee
parent 507c0ad368
commit df018fc305
2 changed files with 37 additions and 3 deletions

View File

@ -2,6 +2,7 @@ import torch.nn as nn
from transformers.models.bert.modeling_bert import BertEmbeddings, BertLayer, BertLMPredictionHead
import colossalai.shardformer.layer.layers as col_nn
from colossalai.shardformer.layer.dropout import Dropout1D
from ..shard.shard_config import ShardConfig
from ..utils import getattr_, setattr_
@ -65,7 +66,24 @@ class BertPolicy(Policy):
suffix="output.dense",
target_module=col_nn.Linear1D_Row,
),
])
SubModuleReplacementDescription(
suffix="attention.self.dropout",
target_module=Dropout1D,
),
SubModuleReplacementDescription(
suffix="attention.output.dropout",
target_module=Dropout1D,
)
]),
BertEmbeddings:
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="word_embeddings",
target_module=col_nn.VocabParallelEmbedding1D,
)
])
}
def new_model_class(self):
@ -87,6 +105,21 @@ class BertForMaskedLMPolicy(BertPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self, shard_config: ShardConfig = None):
module_policy = super().module_policy(shard_config)
addon_module = {
BertLMPredictionHead:
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(suffix="decoder",
target_module=col_nn.Linear1D_Col,
kwargs={"gather_output": True})
])
}
module_policy.update(addon_module)
return module_policy
# BertLMHeadModel
class BertLMHeadModelPolicy(BertPolicy):

View File

@ -171,12 +171,13 @@ class ModelSharder(object):
for description in sub_module_replacement:
suffix = description.suffix
target_module = description.target_module
kwargs = description.kwargs
kwargs = {} if description.kwargs is None else description.kwargs
assert target_module is not None, 'target_module should not be None'
# TODO: support different parallel mode
native_sub_module = getattr_(org_layer, suffix)
replace_layer = target_module.from_native_module(native_sub_module, self.pg_manager.pg_store['tp1d'])
replace_layer = target_module.from_native_module(native_sub_module, self.pg_manager.pg_store['tp1d'],
**kwargs)
setattr_(org_layer, suffix, replace_layer)