[Shardformer] Downstream bert (#3979)

* add dist dropout in model

* update docstring and bert policy with dropout

* refactor basepolicy and sharded, update bert

* update format

* update gpt2 policy

* update bert policy

* remove unused code

* update readme for new policy usage

* add downstream model of bert

* remove unused code
pull/4157/head
FoolPlayer 2023-06-15 17:56:51 +08:00 committed by Frank Lee
parent c1c672d0f0
commit f7774ec0f3
5 changed files with 161 additions and 43 deletions


@@ -10,11 +10,31 @@ def build_policies():
    """
    auto_policy_dict = {}
    from transformers import BertModel
    from .bert import BertModelPolicy
    auto_policy_dict[BertModel] = BertModelPolicy
    from transformers import BertForPreTraining
    from .bert import BertForPretrainingPolicy
    auto_policy_dict[BertForPreTraining] = BertForPretrainingPolicy
    from transformers import BertLMHeadModel
    from .bert import BertLMHeadModelPolicy
    auto_policy_dict[BertLMHeadModel] = BertLMHeadModelPolicy
    from transformers import BertForMaskedLM
    from .bert import BertForMaskedLMPolicy
    auto_policy_dict[BertForMaskedLM] = BertForMaskedLMPolicy
    from transformers import BertForNextSentencePrediction
    from .bert import BertForNextSentencePredictionPolicy
    auto_policy_dict[BertForNextSentencePrediction] = BertForNextSentencePredictionPolicy
    from transformers import BertForSequenceClassification
    from .bert import BertForSequenceClassificationPolicy

@@ -34,6 +54,11 @@ def build_policies():
    from .llama import LlamaForCausalLMPolicy
    auto_policy_dict[LlamaForCausalLM] = LlamaForCausalLMPolicy
    from transformers import BertForMultipleChoice
    from .bert import BertForMultipleChoicePolicy
    auto_policy_dict[BertForMultipleChoice] = BertForMultipleChoicePolicy
    from transformers import GPT2Model
    from .gpt2 import GPT2Policy

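The hunks above register each supported transformers class against its Shardformer policy by exact type. Below is a minimal sketch of how such a table is typically consumed, using stand-in policy classes so the snippet runs on its own; lookup_policy is an illustrative helper, not an API of this repo.

# Illustrative only: a tiny registry in the spirit of build_policies(), plus a
# lookup by the model's concrete class. The policy classes here are stand-ins.
from transformers import BertConfig, BertForSequenceClassification, BertModel


class BertModelPolicy:                      # stand-in for the real policy
    pass


class BertForSequenceClassificationPolicy:  # stand-in for the real policy
    pass


auto_policy_dict = {
    BertModel: BertModelPolicy,
    BertForSequenceClassification: BertForSequenceClassificationPolicy,
}


def lookup_policy(model):
    # Exact-type lookup: subclasses are not matched, mirroring the table above.
    return auto_policy_dict.get(type(model))


model = BertForSequenceClassification(BertConfig())
assert lookup_policy(model) is BertForSequenceClassificationPolicy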

@@ -35,12 +35,6 @@ class BertPolicy(Policy):
            ]),
        }

    @staticmethod
    def binding_policy():
        return {
            "bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight",
        }

    @staticmethod
    def attn_in():
        return [

@@ -148,30 +142,6 @@ class BertPolicy(Policy):
            replace_layer=col_nn.VocabParallelEmbedding1D,
        )]


from transformers import BertForMaskedLM

from colossalai.shardformer.model.modeling_bert import BertForMaskedLM_


class BertForMaskedLMPolicy(BertPolicy):

    @staticmethod
    def argument_policy(config, world_size):
        base_argument = BertPolicy.argument_policy(config, world_size)
        argument = {
            BertLMPredictionHead: Argument(attr_dict={}, param_funcs=[
                BertForMaskedLMPolicy.unembedding,
            ]),
        }
        argument.update(base_argument)
        return argument

    @staticmethod
    def inject_policy():
        # return (BertForMaskedLM, BertForMaskedLM_)
        return None

    @staticmethod
    def unembedding():
        return [

@@ -185,8 +155,112 @@ class BertForMaskedLMPolicy(BertPolicy):
        ]
class BertForSequenceClassificationPolicy(BertPolicy):


# BertModel
class BertModelPolicy(BertPolicy):

    @staticmethod
    def argument_policy(config, world_size):
        return BertPolicy.argument_policy(config, world_size)


# BertForPretraining
class BertForPretrainingPolicy(BertPolicy):

    @staticmethod
    def argument_policy(config, world_size):
        base_argument = BertPolicy.argument_policy(config, world_size)
        argument = {
            BertLMPredictionHead: Argument(attr_dict={}, param_funcs=[
                BertPolicy.unembedding,
            ]),
        }
        argument.update(base_argument)
        return argument

    @staticmethod
    def inject_policy():
        return None

    @staticmethod
    def binding_policy():
        return {
            "bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight",
        }


# BertForMaskedLM
from colossalai.shardformer.model.modeling_bert import BertForMaskedLM_


class BertForMaskedLMPolicy(BertPolicy):

    @staticmethod
    def argument_policy(config, world_size):
        base_argument = BertPolicy.argument_policy(config, world_size)
        argument = {
            BertLMPredictionHead: Argument(attr_dict={}, param_funcs=[
                BertPolicy.unembedding,
            ]),
        }
        argument.update(base_argument)
        return argument

    @staticmethod
    def inject_policy():
        # return (BertForMaskedLM, BertForMaskedLM_)
        return None

    @staticmethod
    def binding_policy():
        return {
            "bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight",
        }


# BertLMHeadModel
class BertLMHeadModelPolicy(BertPolicy):

    @staticmethod
    def argument_policy(config, world_size):
        base_argument = BertPolicy.argument_policy(config, world_size)
        argument = {
            BertLMPredictionHead: Argument(attr_dict={}, param_funcs=[
                BertPolicy.unembedding,
            ]),
        }
        argument.update(base_argument)
        return argument

    @staticmethod
    def inject_policy():
        return None

    @staticmethod
    def binding_policy():
        return {
            "bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight",
        }


# BertForNextSentencePrediction
class BertForNextSentencePredictionPolicy(BertPolicy):

    @staticmethod
    def argument_policy(config, world_size):
        return BertPolicy.argument_policy(config, world_size)


# BertForSequenceClassification
class BertForSequenceClassificationPolicy(BertPolicy):

    @staticmethod
    def argument_policy(config, world_size):
        return BertPolicy.argument_policy(config, world_size)


# BertForMultipleChoice
class BertForMultipleChoicePolicy(BertPolicy):

    @staticmethod
    def argument_policy(config, world_size):
        return BertPolicy.argument_policy(config, world_size)

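Every downstream policy above follows the same recipe: subclass BertPolicy, reuse BertPolicy.argument_policy unchanged for plain classification heads, and for LM heads additionally register BertLMPredictionHead with BertPolicy.unembedding plus a binding_policy that ties cls.predictions.decoder.weight to the word embedding. As a sketch of how one more head would slot in, here is a hypothetical BertForTokenClassificationPolicy (not part of this commit), written as if it lived in the same module as BertPolicy:

# Hypothetical example only, not part of this commit. It assumes it is appended
# to the same module as BertPolicy above and follows the plain-head pattern.

# BertForTokenClassification
class BertForTokenClassificationPolicy(BertPolicy):

    @staticmethod
    def argument_policy(config, world_size):
        # The token-classification head has no tied LM decoder, so only the
        # shared encoder/embedding arguments from the base policy are needed.
        return BertPolicy.argument_policy(config, world_size)

# It would also need a matching entry in build_policies():
#     from transformers import BertForTokenClassification
#     auto_policy_dict[BertForTokenClassification] = BertForTokenClassificationPolicy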

@@ -13,6 +13,6 @@ class ShardConfig:
        world_size (int): The world size of the distributed process group
        gather_output (bool): Whether to gather the output of the last layer of the model
    """
    rank: int
    world_size: int = 2
    rank: int = None
    world_size: int = None
    gather_output: bool = True

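With rank and world_size now defaulting to None, callers are expected to supply them explicitly (the test below passes them through from the launcher). A minimal sketch, assuming torch.distributed is already initialised and assuming the import path of ShardConfig, of deriving both values from the active process group:

# Illustrative only: populate the now-None rank/world_size from the process
# group. Requires torch.distributed to be initialised (e.g. via colossalai.launch).
import torch.distributed as dist

from colossalai.shardformer.shard.shardconfig import ShardConfig  # import path assumed


def make_shard_config(gather_output: bool = True) -> ShardConfig:
    return ShardConfig(
        rank=dist.get_rank(),
        world_size=dist.get_world_size(),
        gather_output=gather_output,
    )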

@@ -276,6 +276,7 @@ def shard_model(model: nn.Module, shard_config: ShardConfig = None, policy: Policy = None):
        shard_config (`ShardConfig`): the config for the distributed setup
        policy (`Policy`): the custom policy for sharding
    """
    # TODO: init shard_config automatically
    sharder = ModelSharder(model=model, shard_config=shard_config, policy=policy)
    sharder.shard()
    return model

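Putting the pieces together, a hedged end-to-end sketch in the spirit of the test below: build a Hugging Face model, describe the distributed setup with ShardConfig, and call shard_model with policy=None so the auto-policy table above can pick the matching BERT policy. The import paths are assumptions; only the call pattern comes from this diff.

# Sketch only: mirrors the call pattern used in the test file below.
import torch.distributed as dist
from transformers import BertConfig, BertForSequenceClassification

from colossalai.shardformer.shard import ShardConfig, shard_model  # import paths assumed

config = BertConfig.from_pretrained("bert-base-uncased")
model = BertForSequenceClassification(config)

shard_config = ShardConfig(
    rank=dist.get_rank(),
    world_size=dist.get_world_size(),
    gather_output=True,
)

# With policy=None, the auto-policy table is expected to map
# BertForSequenceClassification to BertForSequenceClassificationPolicy.
sharded = shard_model(model, shard_config).to("cuda")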

@@ -1,9 +1,19 @@
import copy
import os
import random

import pytest
import torch
from transformers import AutoTokenizer, BertConfig, BertForMaskedLM
from transformers import (
    AutoTokenizer,
    BertConfig,
    BertForMaskedLM,
    BertForMultipleChoice,
    BertForNextSentencePrediction,
    BertForPreTraining,
    BertForSequenceClassification,
    BertLMHeadModel,
    BertModel,
)

import colossalai
from colossalai.logging import disable_existing_loggers

@@ -15,20 +25,21 @@ CONFIG = dict(parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d')),
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")


def build_model(rank, world_size):
def build_model(rank, world_size, model):
    config = BertConfig.from_pretrained('bert-base-uncased')
    config.hidden_dropout_prob = 0
    config.attention_probs_dropout_prob = 0

    org_model = BertForMaskedLM.from_pretrained('bert-base-uncased', config=config).to('cuda')
    org_model = model(config=config)
    org_model_forshard = copy.deepcopy(org_model)
    org_model = org_model.to('cuda')

    shardconfig = ShardConfig(
        rank=rank,
        world_size=world_size,
        gather_output=True,
    )
    sharded_model = shard_model(BertForMaskedLM.from_pretrained('bert-base-uncased', config=config),
                                shardconfig).to('cuda')
    sharded_model = shard_model(org_model_forshard, shardconfig).to('cuda')

    return org_model, sharded_model

@@ -85,12 +96,19 @@ def check_backward(org_model, sharded_model):
def check_bert(rank, world_size, port):
    disable_existing_loggers()
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    forward_list = [
        BertModel, BertForPreTraining, BertForMaskedLM, BertLMHeadModel, BertForNextSentencePrediction,
        BertForSequenceClassification
    ]
    backward_list = [BertForMaskedLM, BertLMHeadModel]

    org_model, sharded_model = build_model(rank, world_size)
    check_forward(org_model, sharded_model)
    check_backward(org_model, sharded_model)
    for model in forward_list:
        org_model, sharded_model = build_model(rank, world_size, model)
        check_forward(org_model, sharded_model)
        if model in backward_list:
            check_backward(org_model, sharded_model)
        torch.cuda.empty_cache()

    torch.cuda.empty_cache()
@pytest.mark.dist
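The bodies of check_forward and check_backward sit outside this hunk. For context, a hedged sketch of the kind of equivalence check such helpers usually perform: run the original and the sharded model on the same batch and compare outputs within a tolerance. The function name, inputs and tolerance below are assumptions, not the test's actual code.

# Sketch only: the real check_forward/check_backward are not shown in this diff.
import torch


def assert_forward_equal(org_model, sharded_model, tokenizer):
    inputs = tokenizer("Hello, my dog is cute", return_tensors="pt").to("cuda")
    with torch.no_grad():
        org_out = org_model(**inputs)[0]
        shard_out = sharded_model(**inputs)[0]
    # gather_output=True means the sharded model returns full-size outputs, so an
    # element-wise comparison against the unsharded reference is meaningful.
    assert torch.allclose(org_out, shard_out, atol=1e-5), "forward outputs diverge"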