mirror of https://github.com/hpcaitech/ColossalAI
support bert with new api
parent
507c0ad368
commit
df018fc305
|
@ -2,6 +2,7 @@ import torch.nn as nn
|
|||
from transformers.models.bert.modeling_bert import BertEmbeddings, BertLayer, BertLMPredictionHead
|
||||
|
||||
import colossalai.shardformer.layer.layers as col_nn
|
||||
from colossalai.shardformer.layer.dropout import Dropout1D
|
||||
|
||||
from ..shard.shard_config import ShardConfig
|
||||
from ..utils import getattr_, setattr_
|
||||
|
@ -65,7 +66,24 @@ class BertPolicy(Policy):
|
|||
suffix="output.dense",
|
||||
target_module=col_nn.Linear1D_Row,
|
||||
),
|
||||
])
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attention.self.dropout",
|
||||
target_module=Dropout1D,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attention.output.dropout",
|
||||
target_module=Dropout1D,
|
||||
)
|
||||
]),
|
||||
BertEmbeddings:
|
||||
ModulePolicyDescription(attribute_replacement={},
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="word_embeddings",
|
||||
target_module=col_nn.VocabParallelEmbedding1D,
|
||||
)
|
||||
])
|
||||
}
|
||||
|
||||
def new_model_class(self):
|
||||
|
@ -87,6 +105,21 @@ class BertForMaskedLMPolicy(BertPolicy):
|
|||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def module_policy(self, shard_config: ShardConfig = None):
|
||||
module_policy = super().module_policy(shard_config)
|
||||
addon_module = {
|
||||
BertLMPredictionHead:
|
||||
ModulePolicyDescription(attribute_replacement={},
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(suffix="decoder",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={"gather_output": True})
|
||||
])
|
||||
}
|
||||
module_policy.update(addon_module)
|
||||
return module_policy
|
||||
|
||||
|
||||
# BertLMHeadModel
|
||||
class BertLMHeadModelPolicy(BertPolicy):
|
||||
|
|
|
@ -171,12 +171,13 @@ class ModelSharder(object):
|
|||
for description in sub_module_replacement:
|
||||
suffix = description.suffix
|
||||
target_module = description.target_module
|
||||
kwargs = description.kwargs
|
||||
kwargs = {} if description.kwargs is None else description.kwargs
|
||||
|
||||
assert target_module is not None, 'target_module should not be None'
|
||||
|
||||
# TODO: support different parallel mode
|
||||
native_sub_module = getattr_(org_layer, suffix)
|
||||
replace_layer = target_module.from_native_module(native_sub_module, self.pg_manager.pg_store['tp1d'])
|
||||
replace_layer = target_module.from_native_module(native_sub_module, self.pg_manager.pg_store['tp1d'],
|
||||
**kwargs)
|
||||
|
||||
setattr_(org_layer, suffix, replace_layer)
|
||||
|
|
Loading…
Reference in New Issue