From df018fc305c1401c26d00e4e03e0e11b24649a21 Mon Sep 17 00:00:00 2001
From: FoolPlayer <498107402@qq.com>
Date: Fri, 16 Jun 2023 16:12:27 +0800
Subject: [PATCH] support bert with new api

---
 colossalai/shardformer/policies/bert.py | 35 ++++++++++++++++++++++++-
 colossalai/shardformer/shard/sharder.py |  5 ++--
 2 files changed, 37 insertions(+), 3 deletions(-)

diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py
index fc3e84473..fe74f83ca 100644
--- a/colossalai/shardformer/policies/bert.py
+++ b/colossalai/shardformer/policies/bert.py
@@ -2,6 +2,7 @@ import torch.nn as nn
 from transformers.models.bert.modeling_bert import BertEmbeddings, BertLayer, BertLMPredictionHead
 
 import colossalai.shardformer.layer.layers as col_nn
+from colossalai.shardformer.layer.dropout import Dropout1D
 
 from ..shard.shard_config import ShardConfig
 from ..utils import getattr_, setattr_
@@ -65,7 +66,24 @@ class BertPolicy(Policy):
                             suffix="output.dense",
                             target_module=col_nn.Linear1D_Row,
                         ),
-                    ])
+                        SubModuleReplacementDescription(
+                            suffix="attention.self.dropout",
+                            target_module=Dropout1D,
+                        ),
+                        SubModuleReplacementDescription(
+                            suffix="attention.output.dropout",
+                            target_module=Dropout1D,
+                        )
+                    ]),
+            BertEmbeddings:
+                ModulePolicyDescription(attribute_replacement={},
+                                        param_replacement=[],
+                                        sub_module_replacement=[
+                                            SubModuleReplacementDescription(
+                                                suffix="word_embeddings",
+                                                target_module=col_nn.VocabParallelEmbedding1D,
+                                            )
+                                        ])
         }
 
     def new_model_class(self):
@@ -87,6 +105,21 @@ class BertForMaskedLMPolicy(BertPolicy):
     def __init__(self) -> None:
         super().__init__()
 
+    def module_policy(self, shard_config: ShardConfig = None):
+        module_policy = super().module_policy(shard_config)
+        addon_module = {
+            BertLMPredictionHead:
+                ModulePolicyDescription(attribute_replacement={},
+                                        param_replacement=[],
+                                        sub_module_replacement=[
+                                            SubModuleReplacementDescription(suffix="decoder",
+                                                                            target_module=col_nn.Linear1D_Col,
+                                                                            kwargs={"gather_output": True})
+                                        ])
+        }
+        module_policy.update(addon_module)
+        return module_policy
+
 
 # BertLMHeadModel
 class BertLMHeadModelPolicy(BertPolicy):
diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py
index eb8300d59..5c8584595 100644
--- a/colossalai/shardformer/shard/sharder.py
+++ b/colossalai/shardformer/shard/sharder.py
@@ -171,12 +171,13 @@ class ModelSharder(object):
         for description in sub_module_replacement:
             suffix = description.suffix
             target_module = description.target_module
-            kwargs = description.kwargs
+            kwargs = {} if description.kwargs is None else description.kwargs
 
             assert target_module is not None, 'target_module should not be None'
 
             # TODO: support different parallel mode
             native_sub_module = getattr_(org_layer, suffix)
-            replace_layer = target_module.from_native_module(native_sub_module, self.pg_manager.pg_store['tp1d'])
+            replace_layer = target_module.from_native_module(native_sub_module, self.pg_manager.pg_store['tp1d'],
+                                                             **kwargs)
 
             setattr_(org_layer, suffix, replace_layer)