[shardformer] fix bert and gpt downstream with new api (#4024)

* fix bert downstream with new api

* remove comment line
pull/4157/head
FoolPlayer 2023-06-19 10:47:16 +08:00 committed by Frank Lee
parent e253a07007
commit 74d176c8d8
6 changed files with 97 additions and 39 deletions

View File

@@ -76,6 +76,7 @@ class Policy(ABC):

     def __init__(self) -> None:
         self.model = None
+        self.shard_config = None

     def set_model(self, model: nn.Module) -> None:
         r"""
@@ -86,14 +87,23 @@ class Policy(ABC):
         """
         self.model = model

+    def set_shard_config(self, shard_config: ShardConfig) -> None:
+        r"""
+        Set shard config as an attribute of the Policy object.
+
+        Args:
+            shard_config (:class:`ShardConfig`): The shard config to be set
+        """
+        self.shard_config = shard_config
+
     @abstractmethod
-    def preprocess(self, shard_config: ShardConfig = None) -> nn.Module:
+    def preprocess(self) -> nn.Module:
         r"""
         Perform some preprocessing of the model, like reshaping the embedding layer
         """

     @abstractmethod
-    def module_policy(self, shard_config: ShardConfig = None) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
+    def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
         r"""
         Return the dict for the modify policy, the key is the original layer class and the value is the
         argument for the modify layer

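The net effect of the change above is that shard_config stops being a per-call argument and becomes policy state. Below is a minimal sketch of the resulting call order, mirroring what ModelSharder.shard() does further down in this commit; the import paths and the checkpoint name are assumptions, not spelled out in this diff.

from transformers import BertForMaskedLM

from colossalai.shardformer.policies.bert import BertForMaskedLMPolicy    # import path assumed
from colossalai.shardformer.shard.shard_config import ShardConfig         # import path assumed

model = BertForMaskedLM.from_pretrained('bert-base-uncased')
shard_config = ShardConfig(tensor_parallel_size=2, tensor_parallel_mode='1d')

policy = BertForMaskedLMPolicy()
policy.set_model(model)                   # unchanged
policy.set_shard_config(shard_config)     # new: the config is attached once
model = policy.preprocess()               # no shard_config argument any more
descriptions = policy.module_policy()     # reads self.shard_config internally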
View File

@@ -4,41 +4,40 @@ from transformers.models.bert.modeling_bert import BertEmbeddings, BertLayer, Be
 import colossalai.shardformer.layer.layers as col_nn
 from colossalai.shardformer.layer.dropout import Dropout1D

-from ..shard.shard_config import ShardConfig
 from ..utils import getattr_, setattr_
 from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription


 class BertPolicy(Policy):

-    def preprocess(self, shard_config: ShardConfig = None):
+    def preprocess(self):
         # reshape the embedding layer
         r"""
         Reshape the Embedding layer to make the embedding dimension divisible by world_size
         """
         # TODO:
         vocab_size = self.model.config.vocab_size
-        world_size = shard_config.tensor_parallel_size
+        world_size = self.shard_config.tensor_parallel_size
         if vocab_size % world_size != 0:
             new_vocab_size = vocab_size + world_size - vocab_size % world_size
             self.model.resize_token_embeddings(new_vocab_size)
         return self.model

-    def module_policy(self, shard_config: ShardConfig = None):
+    def module_policy(self):
         return {
             BertLayer:
                 ModulePolicyDescription(
                     attribute_replacement={
                         # 1. shard hidden size
                         "attention.self.all_head_size":
-                            self.model.config.hidden_size // shard_config.tensor_parallel_size,
+                            self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                         "crossattention.self.all_head_size":
-                            self.model.config.hidden_size // shard_config.tensor_parallel_size,
+                            self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                         # 2. shard number of heads
                         "attention.self.num_attention_heads":
-                            self.model.config.num_attention_heads // shard_config.tensor_parallel_size,
+                            self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
                         "crossattention.self.num_attention_heads":
-                            self.model.config.num_attention_heads // shard_config.tensor_parallel_size,
+                            self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
                     },
                     param_replacement=[],
                     sub_module_replacement=[
@@ -100,13 +99,43 @@ class BertPolicy(Policy):
         return self.model


+# BertModel
+class BertModelPolicy(BertPolicy):
+
+    def __init__(self) -> None:
+        super().__init__()
+
+
+# BertForPreTraining
+class BertForPretrainingPolicy(BertPolicy):
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    def module_policy(self):
+        module_policy = super().module_policy()
+        addon_module = {
+            BertLMPredictionHead:
+                ModulePolicyDescription(attribute_replacement={},
+                                        param_replacement=[],
+                                        sub_module_replacement=[
+                                            SubModuleReplacementDescription(suffix="decoder",
+                                                                            target_module=col_nn.Linear1D_Col,
+                                                                            kwargs={"gather_output": True})
+                                        ])
+        }
+        module_policy.update(addon_module)
+        return module_policy
+
+
+# BertForMaskedLM
 class BertForMaskedLMPolicy(BertPolicy):

     def __init__(self) -> None:
         super().__init__()

-    def module_policy(self, shard_config: ShardConfig = None):
-        module_policy = super().module_policy(shard_config)
+    def module_policy(self):
+        module_policy = super().module_policy()
         addon_module = {
             BertLMPredictionHead:
                 ModulePolicyDescription(attribute_replacement={},
@@ -124,16 +153,41 @@ class BertForMaskedLMPolicy(BertPolicy):

 # BertLMHeadModel
 class BertLMHeadModelPolicy(BertPolicy):

-    @staticmethod
-    def argument_policy(config, world_size):
-        base_argument = BertPolicy.argument_policy(config, world_size)
-        argument = {
-            BertLMPredictionHead: Argument(attr_dict={}, param_funcs=[
-                BertPolicy.unembedding,
-            ]),
+    def __init__(self) -> None:
+        super().__init__()
+
+    def module_policy(self):
+        module_policy = super().module_policy()
+        addon_module = {
+            BertLMPredictionHead:
+                ModulePolicyDescription(attribute_replacement={},
+                                        param_replacement=[],
+                                        sub_module_replacement=[
+                                            SubModuleReplacementDescription(suffix="decoder",
+                                                                            target_module=col_nn.Linear1D_Col,
+                                                                            kwargs={"gather_output": True})
+                                        ])
         }
-        argument.update(base_argument)
-        return argument
+        module_policy.update(addon_module)
+        return module_policy
+
+
+# BertForNextSentencePrediction
+class BertForNextSentencePredictionPolicy(BertPolicy):
+
+    def __init__(self) -> None:
+        super().__init__()
+
+
+# BertForSequenceClassification
+class BertForSequenceClassificationPolicy(BertPolicy):
+
+    def __init__(self) -> None:
+        super().__init__()
+
+
+# BertForMultipleChoice
+class BertForMultipleChoicePolicy(BertPolicy):

     def __init__(self) -> None:
         super().__init__()

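With the per-architecture policies added above, the other Bert heads can go through the same sharding entry point that the docstring and test below exercise for BertForMaskedLM. A hedged sketch, assuming the top-level import path and that shard_model() resolves BertForPreTraining to BertForPretrainingPolicy automatically (this diff only shows the policy classes, not the lookup):

from transformers import BertForPreTraining

from colossalai.shardformer import ShardConfig, ShardFormer    # import path assumed

# assumes the script is launched with 2 ranks (e.g. via colossalai run / torchrun)
org_model = BertForPreTraining.from_pretrained('bert-base-uncased')
shard_config = ShardConfig(tensor_parallel_size=2, tensor_parallel_mode='1d')

shard_former = ShardFormer(shard_config=shard_config)
shard_former.init_distributed()
sharded_model = shard_former.shard_model(org_model).to('cuda')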
View File

@@ -18,10 +18,10 @@ class ShardConfig:
         will not calculate the loss and just return the output.
         gather_output (bool): Whether to gather the output of the model of the last layer
     """
-    data_parallel_size: int
     tensor_parallel_size: int
-    pipeline_parallel_size: int
+    # TODO: add support for tensor parallel
+    # pipeline_parallel_size: int
+    # data_parallel_size: int
     tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d']
     inference_only: bool = True
     gather_output: bool = True

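After this trim only tensor_parallel_size and tensor_parallel_mode have to be supplied; inference_only and gather_output keep their True defaults, and the data/pipeline parallel fields are commented out until they are supported. A small sketch of the arithmetic the Bert policy above derives from this config (the import path is an assumption; 768 and 12 are simply bert-base-uncased's hidden size and head count):

from colossalai.shardformer.shard.shard_config import ShardConfig    # import path assumed

cfg = ShardConfig(tensor_parallel_size=2, tensor_parallel_mode='1d')

hidden_size, num_heads = 768, 12                  # bert-base-uncased values
print(hidden_size // cfg.tensor_parallel_size)    # 384 -> attention.self.all_head_size per rank
print(num_heads // cfg.tensor_parallel_size)      # 6   -> attention.self.num_attention_heads per rank
print(cfg.inference_only, cfg.gather_output)      # True True (defaults, unchanged by this commit)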
View File

@@ -40,6 +40,7 @@ class ModelSharder(object):
         Shard the model according to the policy
         """
         self.policy.set_model(self.model)
+        self.policy.set_shard_config(self.shard_config)
         self.preprocess()
         self.replace_model_class()
         self.replace_module()
@@ -57,12 +58,12 @@ class ModelSharder(object):
         self.model_config = self.model.config

     def preprocess(self) -> None:
-        self.model = self.policy.preprocess(self.shard_config)
+        self.model = self.policy.preprocess()

     def postprocess(self) -> None:
         self.model = self.policy.postprocess()

-    def replace_model_class(self,) -> None:
+    def replace_model_class(self) -> None:
         r"""
         Replace the model to policy defined model
         Mainly modify the forward and backward to fit distributed model
@@ -83,14 +84,14 @@ class ModelSharder(object):
                 getattr(new_model_class, key),
             )

-    def replace_module(self,) -> None:
+    def replace_module(self) -> None:
         r"""
         Replace the module according to the policy, and replace the module one by one

         Args:
             model (:class:`torch.nn.Module`): The model to shard
         """
-        module_descriptions = self.policy.module_policy(self.shard_config)
+        module_descriptions = self.policy.module_policy()
         for module_description in module_descriptions.items():
             origin_layer_cls = module_description[0]
             attr_replacement = module_description[1].attribute_replacement

View File

@@ -25,11 +25,7 @@ class ShardFormer:
         org_model = BertForMaskedLM.from_pretrained('bert-base-uncased')
         shard_config = ShardConfig(
             tensor_parallel_size=2,
-            data_parallel_size=1,
-            pipeline_parallel_size=1,
             tensor_parallel_mode='1d',
-            inference_only=True,
-            gather_output=True
         )
         shard_former = ShardFormer(shard_config=shard_config)
         shard_former.init_distributed()

View File

@@ -7,7 +7,6 @@ from transformers import (
     AutoTokenizer,
     BertConfig,
     BertForMaskedLM,
-    BertForMultipleChoice,
     BertForNextSentencePrediction,
     BertForPreTraining,
     BertForSequenceClassification,
@@ -36,12 +35,10 @@ def build_model(rank, world_size, model):
     org_model.to('cuda')
     # TODO: no need to transfer to cuda
     org_model_forshard.to('cuda')
-    shard_config = ShardConfig(tensor_parallel_size=2,
-                               data_parallel_size=1,
-                               pipeline_parallel_size=1,
-                               tensor_parallel_mode='1d',
-                               inference_only=True,
-                               gather_output=True)
+    shard_config = ShardConfig(
+        tensor_parallel_size=2,
+        tensor_parallel_mode='1d',
+    )
     shard_former = ShardFormer(shard_config=shard_config)
     shard_former.init_distributed()
     sharded_model = shard_former.shard_model(org_model_forshard).to('cuda')