[shardformer] fix bert and gpt downstream with new api (#4024)

* fix bert downstream with new api

* remove comment line
FoolPlayer 2023-06-19 10:47:16 +08:00 committed by Frank Lee
parent e253a07007
commit 74d176c8d8
6 changed files with 97 additions and 39 deletions


@@ -76,6 +76,7 @@ class Policy(ABC):
     def __init__(self) -> None:
         self.model = None
+        self.shard_config = None

     def set_model(self, model: nn.Module) -> None:
         r"""
@@ -86,14 +87,23 @@ class Policy(ABC):
         """
         self.model = model

+    def set_shard_config(self, shard_config: ShardConfig) -> None:
+        r"""
+        Set shard config as an attribute of the Policy object.
+
+        Args:
+            shard_config (:class:`ShardConfig`): The shard config to be perform
+        """
+        self.shard_config = shard_config
+
     @abstractmethod
-    def preprocess(self, shard_config: ShardConfig = None) -> nn.Module:
+    def preprocess(self) -> nn.Module:
         r"""
         Perform some preprocessing of the model, like reshaping the embedding layer
         """

     @abstractmethod
-    def module_policy(self, shard_config: ShardConfig = None) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
+    def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
         r"""
         Return the dict for the modify policy, the key is the original layer class and the value is the
         argument for the modify layer
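With this change a policy never receives the shard config as a method argument: the sharder injects it once through set_shard_config, and every hook reads self.shard_config. Below is a minimal sketch of what a downstream policy looks like under the new API; the class name MyModelPolicy is hypothetical and the absolute import path is assumed (the diff itself only shows the relative import .basepolicy).

import torch.nn as nn

# Import path assumed; the diff only shows `from .basepolicy import ...`.
from colossalai.shardformer.policies.basepolicy import Policy


class MyModelPolicy(Policy):
    # Hypothetical example policy: pads the vocabulary, replaces nothing.

    def preprocess(self) -> nn.Module:
        # self.shard_config is populated by ModelSharder via set_shard_config,
        # so no shard_config argument is passed in anymore.
        vocab_size = self.model.config.vocab_size
        world_size = self.shard_config.tensor_parallel_size
        if vocab_size % world_size != 0:
            self.model.resize_token_embeddings(vocab_size + world_size - vocab_size % world_size)
        return self.model

    def module_policy(self):
        # An empty mapping: no layer classes are rewritten in this toy policy.
        return {}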


@@ -4,41 +4,40 @@ from transformers.models.bert.modeling_bert import BertEmbeddings, BertLayer, Be
 import colossalai.shardformer.layer.layers as col_nn
 from colossalai.shardformer.layer.dropout import Dropout1D

 from ..shard.shard_config import ShardConfig
 from ..utils import getattr_, setattr_
 from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription


 class BertPolicy(Policy):

-    def preprocess(self, shard_config: ShardConfig = None):
+    def preprocess(self):
-        # reshape the embedding layer
         r"""
         Reshape the Embedding layer to make the embedding dimension divisible by world_size
         """
         # TODO:
         vocab_size = self.model.config.vocab_size
-        world_size = shard_config.tensor_parallel_size
+        world_size = self.shard_config.tensor_parallel_size
         if vocab_size % world_size != 0:
             new_vocab_size = vocab_size + world_size - vocab_size % world_size
             self.model.resize_token_embeddings(new_vocab_size)
         return self.model

-    def module_policy(self, shard_config: ShardConfig = None):
+    def module_policy(self):
         return {
             BertLayer:
                 ModulePolicyDescription(
                     attribute_replacement={
                         # 1. shard hidden size
                         "attention.self.all_head_size":
-                            self.model.config.hidden_size // shard_config.tensor_parallel_size,
+                            self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                         "crossattention.self.all_head_size":
-                            self.model.config.hidden_size // shard_config.tensor_parallel_size,
+                            self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                         # 2. shard number of heads
                         "attention.self.num_attention_heads":
-                            self.model.config.num_attention_heads // shard_config.tensor_parallel_size,
+                            self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
                         "crossattention.self.num_attention_heads":
-                            self.model.config.num_attention_heads // shard_config.tensor_parallel_size,
+                            self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
                     },
                     param_replacement=[],
                     sub_module_replacement=[
@@ -100,13 +99,43 @@ class BertPolicy(Policy):
         return self.model


 # BertModel
 class BertModelPolicy(BertPolicy):
+
+    def __init__(self) -> None:
+        super().__init__()
+
+
+# BertForPreTraining
+class BertForPretrainingPolicy(BertPolicy):
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    def module_policy(self):
+        module_policy = super().module_policy()
+        addon_module = {
+            BertLMPredictionHead:
+                ModulePolicyDescription(attribute_replacement={},
+                                        param_replacement=[],
+                                        sub_module_replacement=[
+                                            SubModuleReplacementDescription(suffix="decoder",
+                                                                            target_module=col_nn.Linear1D_Col,
+                                                                            kwargs={"gather_output": True})
+                                        ])
+        }
+        module_policy.update(addon_module)
+        return module_policy
+

 # BertForMaskedLM
 class BertForMaskedLMPolicy(BertPolicy):

+    def __init__(self) -> None:
+        super().__init__()
+
-    def module_policy(self, shard_config: ShardConfig = None):
-        module_policy = super().module_policy(shard_config)
+    def module_policy(self):
+        module_policy = super().module_policy()
         addon_module = {
             BertLMPredictionHead:
                 ModulePolicyDescription(attribute_replacement={},
@@ -124,16 +153,41 @@ class BertForMaskedLMPolicy(BertPolicy):
 # BertLMHeadModel
 class BertLMHeadModelPolicy(BertPolicy):

-    @staticmethod
-    def argument_policy(config, world_size):
-        base_argument = BertPolicy.argument_policy(config, world_size)
-        argument = {
-            BertLMPredictionHead: Argument(attr_dict={}, param_funcs=[
-                BertPolicy.unembedding,
-            ]),
+    def __init__(self) -> None:
+        super().__init__()
+
+    def module_policy(self):
+        module_policy = super().module_policy()
+        addon_module = {
+            BertLMPredictionHead:
+                ModulePolicyDescription(attribute_replacement={},
+                                        param_replacement=[],
+                                        sub_module_replacement=[
+                                            SubModuleReplacementDescription(suffix="decoder",
+                                                                            target_module=col_nn.Linear1D_Col,
+                                                                            kwargs={"gather_output": True})
+                                        ])
         }
-        argument.update(base_argument)
-        return argument
+        module_policy.update(addon_module)
+        return module_policy


 # BertForNextSentencePrediction
 class BertForNextSentencePredictionPolicy(BertPolicy):
+
+    def __init__(self) -> None:
+        super().__init__()
+

 # BertForSequenceClassification
 class BertForSequenceClassificationPolicy(BertPolicy):
+
+    def __init__(self) -> None:
+        super().__init__()
+
+
+# BertForMultipleChoice
+class BertForMultipleChoicePolicy(BertPolicy):
+
+    def __init__(self) -> None:
+        super().__init__()
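As a sanity check on the arithmetic in BertPolicy.preprocess above: the vocabulary is padded up to the next multiple of the tensor parallel size. The helper and numbers below are illustrative only, not part of the commit.

def pad_vocab(vocab_size: int, world_size: int) -> int:
    # Mirrors the rounding used in BertPolicy.preprocess.
    if vocab_size % world_size != 0:
        return vocab_size + world_size - vocab_size % world_size
    return vocab_size

# BERT's 30522-token vocabulary on 4 tensor-parallel ranks: 30522 % 4 == 2,
# so it is padded to 30524; on 2 ranks it is already divisible and unchanged.
assert pad_vocab(30522, 4) == 30524
assert pad_vocab(30522, 2) == 30522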


@@ -18,10 +18,10 @@ class ShardConfig:
             will not calculate the loss and just return the output.
         gather_output (bool): Whether to gather the output of the model of the last layer
     """
-    data_parallel_size: int
     tensor_parallel_size: int
-    pipeline_parallel_size: int
+    # TODO: add support for tensor parallel
+    # pipeline_parallel_size: int
+    # data_parallel_size: int
     tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d']
     inference_only: bool = True
     gather_output: bool = True


@@ -40,6 +40,7 @@ class ModelSharder(object):
         Shard the model according to the policy
         """
         self.policy.set_model(self.model)
+        self.policy.set_shard_config(self.shard_config)
         self.preprocess()
         self.replace_model_class()
         self.replace_module()
@@ -57,12 +58,12 @@ class ModelSharder(object):
         self.model_config = self.model.config

     def preprocess(self) -> None:
-        self.model = self.policy.preprocess(self.shard_config)
+        self.model = self.policy.preprocess()

     def postprocess(self) -> None:
         self.model = self.policy.postprocess()

-    def replace_model_class(self,) -> None:
+    def replace_model_class(self) -> None:
         r"""
         Replace the model to policy defined model
         Mainly modify the forward and backward to fit distributed model
@@ -83,14 +84,14 @@ class ModelSharder(object):
                 getattr(new_model_class, key),
             )

-    def replace_module(self,) -> None:
+    def replace_module(self) -> None:
         r"""
         Replace the module according to the policy, and replace the module one by one

         Args:
             model (:class:`torch.nn.Module`): The model to shard
         """
-        module_descriptions = self.policy.module_policy(self.shard_config)
+        module_descriptions = self.policy.module_policy()
         for module_description in module_descriptions.items():
             origin_layer_cls = module_description[0]
             attr_replacement = module_description[1].attribute_replacement
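The loop above treats the policy as a plain dict from an original layer class to a ModulePolicyDescription. The sketch below shows how such a description can be consumed; setattr_recursive stands in for the repo's setattr_ helper, and apply_policy is a simplification of the sharder's internal replacement logic, not its actual code.

import functools


def setattr_recursive(obj, attr_path: str, value):
    # Stand-in for the repo's setattr_ helper: walk "a.b.c" and set the leaf attribute.
    *parents, leaf = attr_path.split(".")
    parent = functools.reduce(getattr, parents, obj)
    setattr(parent, leaf, value)


def apply_policy(model, module_descriptions):
    # module_descriptions: Dict[type, ModulePolicyDescription], as returned by module_policy().
    for origin_layer_cls, description in module_descriptions.items():
        for module in model.modules():
            if not isinstance(module, origin_layer_cls):
                continue
            # 1. overwrite plain attributes, e.g. the per-rank head count
            for attr_path, value in description.attribute_replacement.items():
                setattr_recursive(module, attr_path, value)
            # 2. param_replacement and sub_module_replacement would be applied here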


@@ -25,11 +25,7 @@ class ShardFormer:
         org_model = BertForMaskedLM.from_pretrained('bert-base-uncased')
         shard_config = ShardConfig(
             tensor_parallel_size=2,
-            data_parallel_size=1,
-            pipeline_parallel_size=1,
             tensor_parallel_mode='1d',
-            inference_only=True,
-            gather_output=True
         )
         shard_former = ShardFormer(shard_config=shard_config)
         shard_former.init_distributed()

@@ -7,7 +7,6 @@ from transformers import (
     AutoTokenizer,
     BertConfig,
     BertForMaskedLM,
-    BertForMultipleChoice,
     BertForNextSentencePrediction,
     BertForPreTraining,
     BertForSequenceClassification,
@@ -36,12 +35,10 @@ def build_model(rank, world_size, model):
     org_model.to('cuda')
     # TODO: no need to transfer to cuda
     org_model_forshard.to('cuda')
-    shard_config = ShardConfig(tensor_parallel_size=2,
-                               data_parallel_size=1,
-                               pipeline_parallel_size=1,
-                               tensor_parallel_mode='1d',
-                               inference_only=True,
-                               gather_output=True)
+    shard_config = ShardConfig(
+        tensor_parallel_size=2,
+        tensor_parallel_mode='1d',
+    )
     shard_former = ShardFormer(shard_config=shard_config)
     shard_former.init_distributed()
     sharded_model = shard_former.shard_model(org_model_forshard).to('cuda')