diff --git a/colossalai/shardformer/policies/autopolicy.py b/colossalai/shardformer/policies/autopolicy.py
index d4425497b..e864719ac 100644
--- a/colossalai/shardformer/policies/autopolicy.py
+++ b/colossalai/shardformer/policies/autopolicy.py
@@ -10,11 +10,31 @@ def build_policies():
     """
     auto_policy_dict = {}
 
+    from transformers import BertModel
+
+    from .bert import BertModelPolicy
+    auto_policy_dict[BertModel] = BertModelPolicy
+
+    from transformers import BertForPreTraining
+
+    from .bert import BertForPretrainingPolicy
+    auto_policy_dict[BertForPreTraining] = BertForPretrainingPolicy
+
+    from transformers import BertLMHeadModel
+
+    from .bert import BertLMHeadModelPolicy
+    auto_policy_dict[BertLMHeadModel] = BertLMHeadModelPolicy
+
     from transformers import BertForMaskedLM
 
     from .bert import BertForMaskedLMPolicy
     auto_policy_dict[BertForMaskedLM] = BertForMaskedLMPolicy
 
+    from transformers import BertForNextSentencePrediction
+
+    from .bert import BertForNextSentencePredictionPolicy
+    auto_policy_dict[BertForNextSentencePrediction] = BertForNextSentencePredictionPolicy
+
     from transformers import BertForSequenceClassification
 
     from .bert import BertForSequenceClassificationPolicy
@@ -34,6 +54,11 @@ def build_policies():
     from .llama import LlamaForCausalLMPolicy
     auto_policy_dict[LlamaForCausalLM] = LlamaForCausalLMPolicy
 
+    from transformers import BertForMultipleChoice
+
+    from .bert import BertForMultipleChoicePolicy
+    auto_policy_dict[BertForMultipleChoice] = BertForMultipleChoicePolicy
+
     from transformers import GPT2Model
 
     from .gpt2 import GPT2Policy
diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py
index 67e910d52..ba2266353 100644
--- a/colossalai/shardformer/policies/bert.py
+++ b/colossalai/shardformer/policies/bert.py
@@ -35,12 +35,6 @@ class BertPolicy(Policy):
             ]),
         }
 
-    @staticmethod
-    def binding_policy():
-        return {
-            "bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight",
-        }
-
     @staticmethod
     def attn_in():
         return [
@@ -148,30 +142,6 @@ class BertPolicy(Policy):
                 replace_layer=col_nn.VocabParallelEmbedding1D,
             )]
 
-
-from transformers import BertForMaskedLM
-
-from colossalai.shardformer.model.modeling_bert import BertForMaskedLM_
-
-
-class BertForMaskedLMPolicy(BertPolicy):
-
-    @staticmethod
-    def argument_policy(config, world_size):
-        base_argument = BertPolicy.argument_policy(config, world_size)
-        argument = {
-            BertLMPredictionHead: Argument(attr_dict={}, param_funcs=[
-                BertForMaskedLMPolicy.unembedding,
-            ]),
-        }
-        argument.update(base_argument)
-        return argument
-
-    @staticmethod
-    def inject_policy():
-        # return (BertForMaskedLM, BertForMaskedLM_)
-        return None
-
     @staticmethod
     def unembedding():
         return [
@@ -185,8 +155,112 @@ class BertForMaskedLMPolicy(BertPolicy):
         ]
 
 
-class BertForSequenceClassificationPolicy(BertPolicy):
+# BertModel
+class BertModelPolicy(BertPolicy):
+
+    @staticmethod
+    def argument_policy(config, world_size):
+        return BertPolicy.argument_policy(config, world_size)
+
+
+# BertForPretraining
+class BertForPretrainingPolicy(BertPolicy):
+
+    @staticmethod
+    def argument_policy(config, world_size):
+        base_argument = BertPolicy.argument_policy(config, world_size)
+        argument = {
+            BertLMPredictionHead: Argument(attr_dict={}, param_funcs=[
+                BertPolicy.unembedding,
+            ]),
+        }
+        argument.update(base_argument)
+        return argument
 
     @staticmethod
     def inject_policy():
         return None
+
+    @staticmethod
+    def binding_policy():
+        return {
+            "bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight",
+        }
+
+
+# BertForMaskedLM
+from colossalai.shardformer.model.modeling_bert import BertForMaskedLM_
+
+
+class BertForMaskedLMPolicy(BertPolicy):
+
+    @staticmethod
+    def argument_policy(config, world_size):
+        base_argument = BertPolicy.argument_policy(config, world_size)
+        argument = {
+            BertLMPredictionHead: Argument(attr_dict={}, param_funcs=[
+                BertPolicy.unembedding,
+            ]),
+        }
+        argument.update(base_argument)
+        return argument
+
+    @staticmethod
+    def inject_policy():
+        # return (BertForMaskedLM, BertForMaskedLM_)
+        return None
+
+    @staticmethod
+    def binding_policy():
+        return {
+            "bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight",
+        }
+
+
+# BertLMHeadModel
+class BertLMHeadModelPolicy(BertPolicy):
+
+    @staticmethod
+    def argument_policy(config, world_size):
+        base_argument = BertPolicy.argument_policy(config, world_size)
+        argument = {
+            BertLMPredictionHead: Argument(attr_dict={}, param_funcs=[
+                BertPolicy.unembedding,
+            ]),
+        }
+        argument.update(base_argument)
+        return argument
+
+    @staticmethod
+    def inject_policy():
+        return None
+
+    @staticmethod
+    def binding_policy():
+        return {
+            "bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight",
+        }
+
+
+# BertForNextSentencePrediction
+class BertForNextSentencePredictionPolicy(BertPolicy):
+
+    @staticmethod
+    def argument_policy(config, world_size):
+        return BertPolicy.argument_policy(config, world_size)
+
+
+# BertForSequenceClassification
+class BertForSequenceClassificationPolicy(BertPolicy):
+
+    @staticmethod
+    def argument_policy(config, world_size):
+        return BertPolicy.argument_policy(config, world_size)
+
+
+# BertForMultipleChoice
+class BertForMultipleChoicePolicy(BertPolicy):
+
+    @staticmethod
+    def argument_policy(config, world_size):
+        return BertPolicy.argument_policy(config, world_size)
diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py
index e8d6f3408..96c287577 100644
--- a/colossalai/shardformer/shard/shard_config.py
+++ b/colossalai/shardformer/shard/shard_config.py
@@ -13,6 +13,6 @@ class ShardConfig:
         world_size (int): The world size of the distributed process
         gather_output (bool): Whether to gather the output of the model of the last layer
     """
-    rank: int
-    world_size: int = 2
+    rank: int = None
+    world_size: int = None
     gather_output: bool = True
diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py
index 8f6514cb4..7ef0c37a4 100644
--- a/colossalai/shardformer/shard/sharder.py
+++ b/colossalai/shardformer/shard/sharder.py
@@ -276,6 +276,7 @@ def shard_model(model: nn.Module, shard_config: ShardConfig = None, policy: Poli
         shard_config (`ShardConfig`): the config for distribute information
         policy (`Policy`): the custom policy for sharding
     """
+    # TODO: init shard_config automatically
     sharder = ModelSharder(model=model, shard_config=shard_config, policy=policy)
     sharder.shard()
     return model
diff --git a/tests/test_shardformer/test_model/test_shard_bert.py b/tests/test_shardformer/test_model/test_shard_bert.py
index 55b78d040..9b29111ea 100644
--- a/tests/test_shardformer/test_model/test_shard_bert.py
+++ b/tests/test_shardformer/test_model/test_shard_bert.py
@@ -1,9 +1,19 @@
+import copy
 import os
-import random
 
 import pytest
 import torch
-from transformers import AutoTokenizer, BertConfig, BertForMaskedLM
+from transformers import (
+    AutoTokenizer,
+    BertConfig,
+    BertForMaskedLM,
+    BertForMultipleChoice,
+    BertForNextSentencePrediction,
+    BertForPreTraining,
+    BertForSequenceClassification,
+    BertLMHeadModel,
+    BertModel,
+)
 
 import colossalai
 from colossalai.logging import disable_existing_loggers
@@ -15,20 +25,21 @@ CONFIG = dict(parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d')),
 tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
 
 
-def build_model(rank, world_size):
+def build_model(rank, world_size, model):
     config = BertConfig.from_pretrained('bert-base-uncased')
     config.hidden_dropout_prob = 0
     config.attention_probs_dropout_prob = 0
 
-    org_model = BertForMaskedLM.from_pretrained('bert-base-uncased', config=config).to('cuda')
+    org_model = model(config=config)
+    org_model_forshard = copy.deepcopy(org_model)
+
+    org_model = org_model.to('cuda')
     shardconfig = ShardConfig(
         rank=rank,
         world_size=world_size,
         gather_output=True,
    )
-    sharded_model = shard_model(BertForMaskedLM.from_pretrained('bert-base-uncased', config=config),
-                                shardconfig).to('cuda')
+    sharded_model = shard_model(org_model_forshard, shardconfig).to('cuda')
     return org_model, sharded_model
 
 
@@ -85,12 +96,19 @@ def check_backward(org_model, sharded_model):
 def check_bert(rank, world_size, port):
     disable_existing_loggers()
     colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    forward_list = [
+        BertModel, BertForPreTraining, BertForMaskedLM, BertLMHeadModel, BertForNextSentencePrediction,
+        BertForSequenceClassification
+    ]
+    backward_list = [BertForMaskedLM, BertLMHeadModel]
 
-    org_model, sharded_model = build_model(rank, world_size)
-    check_forward(org_model, sharded_model)
-    check_backward(org_model, sharded_model)
+    for model in forward_list:
+        org_model, sharded_model = build_model(rank, world_size, model)
+        check_forward(org_model, sharded_model)
+        if model in backward_list:
+            check_backward(org_model, sharded_model)
 
-    torch.cuda.empty_cache()
+        torch.cuda.empty_cache()
 
 
 @pytest.mark.dist
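
As a quick orientation for reviewers, the sketch below shows how the expanded policy registry and the relaxed `ShardConfig` defaults would be exercised for one of the newly supported classes. It is a minimal sketch, not part of this PR: it assumes the same two-process 1D tensor-parallel launch that `test_shard_bert.py` sets up via `colossalai.launch`, and the import path for `ShardConfig`/`shard_model` is inferred from the module layout in this diff.

from transformers import BertConfig, BertForSequenceClassification

# Assumed import path, based on the colossalai/shardformer/shard/ layout in this diff.
from colossalai.shardformer.shard import ShardConfig, shard_model

# rank/world_size would normally come from the distributed launcher
# (colossalai.launch in the test). They now default to None in ShardConfig,
# so callers still have to fill them in until the TODO in shard_model()
# initializes the config automatically.
shard_config = ShardConfig(rank=0, world_size=2, gather_output=True)

config = BertConfig.from_pretrained('bert-base-uncased')
model = BertForSequenceClassification(config)

# shard_model() resolves BertForSequenceClassificationPolicy through
# build_policies(), since the class is now registered in autopolicy.py.
sharded_model = shard_model(model, shard_config).to('cuda')

Note the per-class subclasses in bert.py stay deliberately thin: only the heads that tie the word embedding to an LM decoder (BertForPreTraining, BertForMaskedLM, BertLMHeadModel) override `binding_policy`, while the classification-style heads reuse `BertPolicy.argument_policy` unchanged.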