diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py
index f3431c386..fc3e84473 100644
--- a/colossalai/shardformer/policies/bert.py
+++ b/colossalai/shardformer/policies/bert.py
@@ -8,12 +8,6 @@ from ..utils import getattr_, setattr_
 from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
 
 
-class ParallelModule():
-
-    def __init__(self):
-        pass
-
-
 class BertPolicy(Policy):
 
     def preprocess(self, shard_config: ShardConfig = None):
@@ -49,7 +43,27 @@ class BertPolicy(Policy):
                     sub_module_replacement=[
                         SubModuleReplacementDescription(
                             suffix="attention.self.query",
-                            target_module=ParallelModule,
+                            target_module=col_nn.Linear1D_Col,
+                        ),
+                        SubModuleReplacementDescription(
+                            suffix="attention.self.key",
+                            target_module=col_nn.Linear1D_Col,
+                        ),
+                        SubModuleReplacementDescription(
+                            suffix="attention.self.value",
+                            target_module=col_nn.Linear1D_Col,
+                        ),
+                        SubModuleReplacementDescription(
+                            suffix="attention.output.dense",
+                            target_module=col_nn.Linear1D_Row,
+                        ),
+                        SubModuleReplacementDescription(
+                            suffix="intermediate.dense",
+                            target_module=col_nn.Linear1D_Col,
+                        ),
+                        SubModuleReplacementDescription(
+                            suffix="output.dense",
+                            target_module=col_nn.Linear1D_Row,
                         ),
                     ])
         }
diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py
index 8eee3c6a3..eb8300d59 100644
--- a/colossalai/shardformer/shard/sharder.py
+++ b/colossalai/shardformer/shard/sharder.py
@@ -7,8 +7,8 @@ from transformers.pytorch_utils import Conv1D
 from colossalai.cluster.process_group_manager import ProcessGroupManager
 
 from ..policies.autopolicy import get_autopolicy
-from ..policies.basepolicy import Policy
-from ..utils.utils import setattr_
+from ..policies.basepolicy import Policy, SubModuleReplacementDescription
+from ..utils.utils import getattr_, setattr_
 from .shard_config import ShardConfig
 
 __all__ = ['ModelSharder', 'shard_model']
@@ -90,9 +90,7 @@ class ModelSharder(object):
         Args:
             model (:class:`torch.nn.Module`): The model to shard
         """
-        print(self.policy)
         module_descriptions = self.policy.module_policy(self.shard_config)
-        print(f"*******{module_descriptions}")
         for module_description in module_descriptions.items():
             origin_layer_cls = module_description[0]
             attr_replacement = module_description[1].attribute_replacement
@@ -160,7 +158,7 @@ class ModelSharder(object):
     def _replace_sub_module(
         self,
         org_layer: nn.Module,
-        sub_module_replacement: List[Callable],
+        sub_module_replacement: List[SubModuleReplacementDescription],
     ) -> None:
         r"""
         Shard one layer according to the policy, the layer should be the same class as the key in policy's argument_policy return dict
@@ -177,7 +175,8 @@ class ModelSharder(object):
             assert target_module is not None, 'target_module should not be None'
 
-            # TODO: integrate with new layer
-            # replace_layer = target_module.from_native_layer(org_layer, self.pg_manager)
-            replace_layer = None
+            # TODO: support different parallel mode
+            native_sub_module = getattr_(org_layer, suffix)
+            replace_layer = target_module.from_native_module(native_sub_module, self.pg_manager.pg_store['tp1d'])
+
             setattr_(org_layer, suffix, replace_layer)
diff --git a/tests/test_shardformer/test_model/test_shard_bert.py b/tests/test_shardformer/test_model/test_shard_bert.py
index 9b29111ea..05d033436 100644
--- a/tests/test_shardformer/test_model/test_shard_bert.py
+++ b/tests/test_shardformer/test_model/test_shard_bert.py
@@ -17,7 +17,7 @@ from transformers import (
 import colossalai
 from colossalai.logging import disable_existing_loggers
-from colossalai.shardformer.shard import ShardConfig, shard_model
+from colossalai.shardformer import ShardConfig, ShardFormer
 from colossalai.testing import rerun_if_address_is_in_use, spawn
 
 os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
@@ -30,16 +30,21 @@ def build_model(rank, world_size, model):
     config.hidden_dropout_prob = 0
     config.attention_probs_dropout_prob = 0
 
-    org_model = model(config=config)
+    org_model = BertForMaskedLM.from_pretrained('bert-base-uncased', config=config)
     org_model_forshard = copy.deepcopy(org_model)
 
-    org_model = org_model.to('cuda')
-    shardconfig = ShardConfig(
-        rank=rank,
-        world_size=world_size,
-        gather_output=True,
-    )
-    sharded_model = shard_model(org_model_forshard, shardconfig).to('cuda')
+    org_model.to('cuda')
+    # TODO: no need to transfer to cuda
+    org_model_forshard.to('cuda')
+    shard_config = ShardConfig(tensor_parallel_size=2,
+                               data_parallel_size=1,
+                               pipeline_parallel_size=1,
+                               tensor_parallel_mode='1d',
+                               inference_only=True,
+                               gather_output=True)
+    shard_former = ShardFormer(shard_config=shard_config)
+    shard_former.init_distributed()
+    sharded_model = shard_former.shard_model(org_model_forshard).to('cuda')
 
     return org_model, sharded_model
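
For reference, a minimal sanity-check sketch of what the new replacement path should produce, assuming `sharded_model` was built as in the updated `build_model()` above on a 2-GPU '1d' tensor-parallel run. The attribute paths follow the `suffix` entries in the BERT policy, and the expected class names mirror the `col_nn` targets; this snippet is illustrative only and is not part of the diff.

```python
# Hedged sketch (not part of the diff): after ShardFormer.shard_model(), every
# projection listed in the BERT policy should have been swapped in place via
# target_module.from_native_module() inside ModelSharder._replace_sub_module().
layer = sharded_model.bert.encoder.layer[0]

for suffix, expected in [
    ("attention.self.query", "Linear1D_Col"),
    ("attention.self.key", "Linear1D_Col"),
    ("attention.self.value", "Linear1D_Col"),
    ("attention.output.dense", "Linear1D_Row"),
    ("intermediate.dense", "Linear1D_Col"),
    ("output.dense", "Linear1D_Row"),
]:
    sub_module = layer
    for attr in suffix.split("."):    # walk the dotted path, like getattr_()
        sub_module = getattr(sub_module, attr)
    assert type(sub_module).__name__ == expected, \
        f"{suffix}: got {type(sub_module).__name__}, expected {expected}"
```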