mirror of https://github.com/hpcaitech/ColossalAI
[Shardformer] Downstream bert (#3979)
* add dist dropout in model
* update docstring and bert policy with dropout
* refactor basepolicy and sharded, update bert
* update format
* update gpt2 policy
* update bert policy
* remove unused code
* update readme for new policy usage
* add downstream model of bert
* remove unused code
parent c1c672d0f0
commit f7774ec0f3
@ -10,11 +10,31 @@ def build_policies():
    """
    auto_policy_dict = {}

    from transformers import BertModel
    from .bert import BertModelPolicy
    auto_policy_dict[BertModel] = BertModelPolicy

    from transformers import BertForPreTraining
    from .bert import BertForPretrainingPolicy
    auto_policy_dict[BertForPreTraining] = BertForPretrainingPolicy

    from transformers import BertLMHeadModel
    from .bert import BertLMHeadModelPolicy
    auto_policy_dict[BertLMHeadModel] = BertLMHeadModelPolicy

    from transformers import BertForMaskedLM
    from .bert import BertForMaskedLMPolicy
    auto_policy_dict[BertForMaskedLM] = BertForMaskedLMPolicy

    from transformers import BertForNextSentencePrediction
    from .bert import BertForNextSentencePredictionPolicy
    auto_policy_dict[BertForNextSentencePrediction] = BertForNextSentencePredictionPolicy

    from transformers import BertForSequenceClassification
    from .bert import BertForSequenceClassificationPolicy
@ -34,6 +54,11 @@ def build_policies():
    from .llama import LlamaForCausalLMPolicy
    auto_policy_dict[LlamaForCausalLM] = LlamaForCausalLMPolicy

    from transformers import BertForMultipleChoice
    from .bert import BertForMultipleChoicePolicy
    auto_policy_dict[BertForMultipleChoice] = BertForMultipleChoicePolicy

    from transformers import GPT2Model
    from .gpt2 import GPT2Policy
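The two hunks above extend the auto policy table so each downstream BERT class resolves to its own policy. Below is a small illustrative sketch (not part of the diff) of how such a table is typically consumed: the model's concrete class is the dictionary key, with None as the fallback when nothing is registered. build_policies_sketch and get_policy are hypothetical names used only for this example.

# Illustrative sketch only: a stand-in table and lookup, not the library's code.
from transformers import BertConfig, BertForMaskedLM, BertModel


def build_policies_sketch():
    # The real table maps model classes to Policy subclasses; strings stand in here.
    return {
        BertModel: "BertModelPolicy",
        BertForMaskedLM: "BertForMaskedLMPolicy",
    }


def get_policy(model):
    # Lookup keyed on the concrete class of the instantiated model.
    return build_policies_sketch().get(model.__class__, None)


model = BertForMaskedLM(BertConfig())
print(get_policy(model))    # -> "BertForMaskedLMPolicy"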
@ -35,12 +35,6 @@ class BertPolicy(Policy):
            ]),
        }

    @staticmethod
    def binding_policy():
        return {
            "bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight",
        }

    @staticmethod
    def attn_in():
        return [
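The binding map is dropped from the base BertPolicy here, presumably because a bare BertModel has no cls.predictions head to tie; as the later hunk shows, it reappears only on the LM-head policies. A quick hypothetical check of that structural difference (not part of the diff):

# Hypothetical check: only the LM-head variants carry the cls.predictions head
# whose decoder weight gets tied to the word embeddings.
from transformers import BertConfig, BertForMaskedLM, BertModel

cfg = BertConfig()
assert not hasattr(BertModel(cfg), "cls")      # no prediction head to bind
assert hasattr(BertForMaskedLM(cfg), "cls")    # has cls.predictions.decoder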
@ -148,30 +142,6 @@ class BertPolicy(Policy):
            replace_layer=col_nn.VocabParallelEmbedding1D,
        )]


from transformers import BertForMaskedLM

from colossalai.shardformer.model.modeling_bert import BertForMaskedLM_


class BertForMaskedLMPolicy(BertPolicy):

    @staticmethod
    def argument_policy(config, world_size):
        base_argument = BertPolicy.argument_policy(config, world_size)
        argument = {
            BertLMPredictionHead: Argument(attr_dict={}, param_funcs=[
                BertForMaskedLMPolicy.unembedding,
            ]),
        }
        argument.update(base_argument)
        return argument

    @staticmethod
    def inject_policy():
        # return (BertForMaskedLM, BertForMaskedLM_)
        return None

    @staticmethod
    def unembedding():
        return [
@ -185,8 +155,112 @@ class BertForMaskedLMPolicy(BertPolicy):
        ]


class BertForSequenceClassificationPolicy(BertPolicy):


# BertModel
class BertModelPolicy(BertPolicy):

    @staticmethod
    def argument_policy(config, world_size):
        return BertPolicy.argument_policy(config, world_size)


# BertForPretraining
class BertForPretrainingPolicy(BertPolicy):

    @staticmethod
    def argument_policy(config, world_size):
        base_argument = BertPolicy.argument_policy(config, world_size)
        argument = {
            BertLMPredictionHead: Argument(attr_dict={}, param_funcs=[
                BertPolicy.unembedding,
            ]),
        }
        argument.update(base_argument)
        return argument

    @staticmethod
    def inject_policy():
        return None

    @staticmethod
    def binding_policy():
        return {
            "bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight",
        }


# BertForMaskedLM
from colossalai.shardformer.model.modeling_bert import BertForMaskedLM_


class BertForMaskedLMPolicy(BertPolicy):

    @staticmethod
    def argument_policy(config, world_size):
        base_argument = BertPolicy.argument_policy(config, world_size)
        argument = {
            BertLMPredictionHead: Argument(attr_dict={}, param_funcs=[
                BertPolicy.unembedding,
            ]),
        }
        argument.update(base_argument)
        return argument

    @staticmethod
    def inject_policy():
        # return (BertForMaskedLM, BertForMaskedLM_)
        return None

    @staticmethod
    def binding_policy():
        return {
            "bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight",
        }


# BertLMHeadModel
class BertLMHeadModelPolicy(BertPolicy):

    @staticmethod
    def argument_policy(config, world_size):
        base_argument = BertPolicy.argument_policy(config, world_size)
        argument = {
            BertLMPredictionHead: Argument(attr_dict={}, param_funcs=[
                BertPolicy.unembedding,
            ]),
        }
        argument.update(base_argument)
        return argument

    @staticmethod
    def inject_policy():
        return None

    @staticmethod
    def binding_policy():
        return {
            "bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight",
        }


# BertForNextSentencePrediction
class BertForNextSentencePredictionPolicy(BertPolicy):

    @staticmethod
    def argument_policy(config, world_size):
        return BertPolicy.argument_policy(config, world_size)


# BertForSequenceClassification
class BertForSequenceClassificationPolicy(BertPolicy):

    @staticmethod
    def argument_policy(config, world_size):
        return BertPolicy.argument_policy(config, world_size)


# BertForMultipleChoice
class BertForMultipleChoicePolicy(BertPolicy):

    @staticmethod
    def argument_policy(config, world_size):
        return BertPolicy.argument_policy(config, world_size)
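Several of the new policies return the same binding map, which declares that the word-embedding weight and the prediction-head decoder weight should stay tied after sharding. Below is only a rough sketch of what applying such a map could look like; apply_binding and _get_attr are hypothetical helpers, not ColossalAI APIs.

# Rough sketch (hypothetical helpers) of applying a binding map such as
#   {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"}
# so that both dotted paths end up pointing at the same (sharded) parameter.
import torch.nn as nn


def _get_attr(root, dotted: str):
    # Walk a dotted attribute path starting from the model root.
    for name in dotted.split('.'):
        root = getattr(root, name)
    return root


def apply_binding(model: nn.Module, binding: dict) -> None:
    for src_path, dst_path in binding.items():
        src_param = _get_attr(model, src_path)
        parent_path, _, leaf = dst_path.rpartition('.')
        # Rebind the destination leaf to the source parameter (weight tying).
        setattr(_get_attr(model, parent_path), leaf, src_param)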
@ -13,6 +13,6 @@ class ShardConfig:
        world_size (int): The world size of the distributed process
        gather_output (bool): Whether to gather the output of the model of the last layer
    """
    rank: int
    world_size: int = 2
    rank: int = None
    world_size: int = None
    gather_output: bool = True
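rank and world_size now default to None instead of being required or hard-coded to 2, which leaves room for the TODO in the next hunk about initializing the shard config automatically. A hedged sketch of how those fields could be inferred from the distributed environment; ShardConfigSketch and infer_shard_config are hypothetical names, not the real API.

# Hedged sketch: a stand-in dataclass mirroring the fields above, filled from
# torch.distributed when a process group is already initialized.
from dataclasses import dataclass

import torch.distributed as dist


@dataclass
class ShardConfigSketch:            # stand-in for the real ShardConfig
    rank: int = None
    world_size: int = None
    gather_output: bool = True


def infer_shard_config() -> ShardConfigSketch:
    if dist.is_available() and dist.is_initialized():
        return ShardConfigSketch(rank=dist.get_rank(), world_size=dist.get_world_size())
    return ShardConfigSketch(rank=0, world_size=1)   # single-process fallback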
@ -276,6 +276,7 @@ def shard_model(model: nn.Module, shard_config: ShardConfig = None, policy: Policy
        shard_config (`ShardConfig`): the config for distribution information
        policy (`Policy`): the custom policy for sharding
    """
    # TODO: init shard_config automatically
    sharder = ModelSharder(model=model, shard_config=shard_config, policy=policy)
    sharder.shard()
    return model
@ -1,9 +1,19 @@
import copy
import os
import random

import pytest
import torch
from transformers import AutoTokenizer, BertConfig, BertForMaskedLM
from transformers import (
    AutoTokenizer,
    BertConfig,
    BertForMaskedLM,
    BertForMultipleChoice,
    BertForNextSentencePrediction,
    BertForPreTraining,
    BertForSequenceClassification,
    BertLMHeadModel,
    BertModel,
)

import colossalai
from colossalai.logging import disable_existing_loggers
@ -15,20 +25,21 @@ CONFIG = dict(parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d')),
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")


def build_model(rank, world_size):
def build_model(rank, world_size, model):
    config = BertConfig.from_pretrained('bert-base-uncased')
    config.hidden_dropout_prob = 0
    config.attention_probs_dropout_prob = 0

    org_model = BertForMaskedLM.from_pretrained('bert-base-uncased', config=config).to('cuda')
    org_model = model(config=config)
    org_model_forshard = copy.deepcopy(org_model)

    org_model = org_model.to('cuda')
    shardconfig = ShardConfig(
        rank=rank,
        world_size=world_size,
        gather_output=True,
    )
    sharded_model = shard_model(BertForMaskedLM.from_pretrained('bert-base-uncased', config=config),
                                shardconfig).to('cuda')
    sharded_model = shard_model(org_model_forshard, shardconfig).to('cuda')

    return org_model, sharded_model
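Dropout is zeroed in the config above so the original and sharded models produce deterministic, comparable outputs. A rough sketch of the kind of comparison check_forward performs; compare_forward is a hypothetical name and the test's actual inputs and tolerances are not shown in this diff.

# Rough sketch of an output comparison between the original and sharded models;
# not the test's actual implementation.
import torch


def compare_forward(org_model, sharded_model, tokenizer):
    inputs = tokenizer("Hello, my dog is cute", return_tensors="pt").to('cuda')
    with torch.no_grad():
        org_out = org_model(**inputs)[0]
        shard_out = sharded_model(**inputs)[0]
    assert torch.allclose(org_out, shard_out, atol=1e-5), "forward outputs diverge"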
@ -85,12 +96,19 @@ def check_backward(org_model, sharded_model):
def check_bert(rank, world_size, port):
    disable_existing_loggers()
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    forward_list = [
        BertModel, BertForPreTraining, BertForMaskedLM, BertLMHeadModel, BertForNextSentencePrediction,
        BertForSequenceClassification
    ]
    backward_list = [BertForMaskedLM, BertLMHeadModel]

    org_model, sharded_model = build_model(rank, world_size)
    check_forward(org_model, sharded_model)
    check_backward(org_model, sharded_model)
    for model in forward_list:
        org_model, sharded_model = build_model(rank, world_size, model)
        check_forward(org_model, sharded_model)
        if model in backward_list:
            check_backward(org_model, sharded_model)

        torch.cuda.empty_cache()
    torch.cuda.empty_cache()


@pytest.mark.dist