mirror of https://github.com/hpcaitech/ColossalAI
support bert with new api
parent 507c0ad368
commit df018fc305
@@ -2,6 +2,7 @@ import torch.nn as nn
 from transformers.models.bert.modeling_bert import BertEmbeddings, BertLayer, BertLMPredictionHead

 import colossalai.shardformer.layer.layers as col_nn
+from colossalai.shardformer.layer.dropout import Dropout1D

 from ..shard.shard_config import ShardConfig
 from ..utils import getattr_, setattr_
@@ -65,7 +66,24 @@ class BertPolicy(Policy):
                     suffix="output.dense",
                     target_module=col_nn.Linear1D_Row,
                 ),
-            ])
+                SubModuleReplacementDescription(
+                    suffix="attention.self.dropout",
+                    target_module=Dropout1D,
+                ),
+                SubModuleReplacementDescription(
+                    suffix="attention.output.dropout",
+                    target_module=Dropout1D,
+                )
+            ]),
+        BertEmbeddings:
+            ModulePolicyDescription(attribute_replacement={},
+                                    param_replacement=[],
+                                    sub_module_replacement=[
+                                        SubModuleReplacementDescription(
+                                            suffix="word_embeddings",
+                                            target_module=col_nn.VocabParallelEmbedding1D,
+                                        )
+                                    ])
     }

     def new_model_class(self):
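For orientation: the hunk above extends the BertLayer entry of BertPolicy with Dropout1D replacements for both attention dropouts, and adds a BertEmbeddings entry that swaps word_embeddings for col_nn.VocabParallelEmbedding1D. Each suffix is a dotted attribute path relative to the module class that keys the entry. A minimal sketch of how such a path can be resolved and swapped, with illustrative helper names (not the getattr_/setattr_ utilities imported above):

    import torch.nn as nn

    def get_by_suffix(module: nn.Module, suffix: str) -> nn.Module:
        # Walk a dotted path such as "attention.self.dropout" one attribute at a time.
        for name in suffix.split("."):
            module = getattr(module, name)
        return module

    def set_by_suffix(module: nn.Module, suffix: str, new_child: nn.Module) -> None:
        # Resolve the parent of the last path component, then swap the child in place.
        *parents, last = suffix.split(".")
        for name in parents:
            module = getattr(module, name)
        setattr(module, last, new_child)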
@@ -87,6 +105,21 @@ class BertForMaskedLMPolicy(BertPolicy):
     def __init__(self) -> None:
         super().__init__()

+    def module_policy(self, shard_config: ShardConfig = None):
+        module_policy = super().module_policy(shard_config)
+        addon_module = {
+            BertLMPredictionHead:
+                ModulePolicyDescription(attribute_replacement={},
+                                        param_replacement=[],
+                                        sub_module_replacement=[
+                                            SubModuleReplacementDescription(suffix="decoder",
+                                                                            target_module=col_nn.Linear1D_Col,
+                                                                            kwargs={"gather_output": True})
+                                        ])
+        }
+        module_policy.update(addon_module)
+        return module_policy
+

 # BertLMHeadModel
 class BertLMHeadModelPolicy(BertPolicy):
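The module_policy override added above merges an extra entry for BertLMPredictionHead on top of the base BertPolicy table, turning the masked-LM head's decoder into a column-parallel linear layer; kwargs={"gather_output": True} is forwarded to the replacement layer (see the sharder change below) so the head still returns full-vocabulary logits. A self-contained illustration on plain tensors of what gathering a column-parallel output means; this is an analogy, not the parallel implementation:

    import torch

    hidden, vocab, world_size = 8, 6, 2
    x = torch.randn(4, hidden)              # a batch of hidden states
    w = torch.randn(vocab, hidden)          # the full decoder weight

    shards = w.chunk(world_size, dim=0)     # column-parallel split of the output dim
    partial = [x @ s.t() for s in shards]   # each "rank" computes a slice of the logits
    gathered = torch.cat(partial, dim=-1)   # what gather_output=True reassembles

    assert torch.allclose(gathered, x @ w.t(), atol=1e-5)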
@@ -171,12 +171,13 @@ class ModelSharder(object):
         for description in sub_module_replacement:
             suffix = description.suffix
             target_module = description.target_module
-            kwargs = description.kwargs
+            kwargs = {} if description.kwargs is None else description.kwargs

             assert target_module is not None, 'target_module should not be None'

             # TODO: support different parallel mode
             native_sub_module = getattr_(org_layer, suffix)
-            replace_layer = target_module.from_native_module(native_sub_module, self.pg_manager.pg_store['tp1d'])
+            replace_layer = target_module.from_native_module(native_sub_module, self.pg_manager.pg_store['tp1d'],
+                                                             **kwargs)

             setattr_(org_layer, suffix, replace_layer)
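In ModelSharder, description.kwargs may be None for descriptions that set no options, so the change defaults it to {} before splatting it into from_native_module; this is what carries gather_output=True through to Linear1D_Col. A minimal sketch of the same default-and-forward pattern, with stand-in names (FakeLinear is hypothetical, not a ColossalAI class):

    from typing import Optional

    def build_replacement(factory, native, process_group, kwargs: Optional[dict] = None):
        # Splatting None would raise a TypeError, so fall back to an empty dict.
        kwargs = {} if kwargs is None else kwargs
        return factory(native, process_group, **kwargs)

    class FakeLinear:
        @staticmethod
        def from_native_module(module, pg, gather_output=False):
            return ("replaced", module, pg, gather_output)

    print(build_replacement(FakeLinear.from_native_module, "linear", "tp1d"))
    print(build_replacement(FakeLinear.from_native_module, "decoder", "tp1d",
                            {"gather_output": True}))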