mirror of https://github.com/hpcaitech/ColossalAI
[shardformer] fix bert and gpt downstream with new api (#4024)
* fix bert downstream with new api * remove comment linepull/4157/head
parent
e253a07007
commit
74d176c8d8
|
@ -76,6 +76,7 @@ class Policy(ABC):
|
|||
|
||||
def __init__(self) -> None:
|
||||
self.model = None
|
||||
self.shard_config = None
|
||||
|
||||
def set_model(self, model: nn.Module) -> None:
|
||||
r"""
|
||||
|
@ -86,14 +87,23 @@ class Policy(ABC):
|
|||
"""
|
||||
self.model = model
|
||||
|
||||
def set_shard_config(self, shard_config: ShardConfig) -> None:
|
||||
r"""
|
||||
Set shard config as an attribute of the Policy object.
|
||||
|
||||
Args:
|
||||
shard_config (:class:`ShardConfig`): The shard config to be perform
|
||||
"""
|
||||
self.shard_config = shard_config
|
||||
|
||||
@abstractmethod
|
||||
def preprocess(self, shard_config: ShardConfig = None) -> nn.Module:
|
||||
def preprocess(self) -> nn.Module:
|
||||
r"""
|
||||
Perform some preprocessing of the model, like reshaping the embedding layer
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def module_policy(self, shard_config: ShardConfig = None) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
|
||||
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
|
||||
r"""
|
||||
Return the dict for the modify policy, the key is the original layer class and the value is the
|
||||
argument for the modify layer
|
||||
|
|
|
@ -4,41 +4,40 @@ from transformers.models.bert.modeling_bert import BertEmbeddings, BertLayer, Be
|
|||
import colossalai.shardformer.layer.layers as col_nn
|
||||
from colossalai.shardformer.layer.dropout import Dropout1D
|
||||
|
||||
from ..shard.shard_config import ShardConfig
|
||||
from ..utils import getattr_, setattr_
|
||||
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||
|
||||
|
||||
class BertPolicy(Policy):
|
||||
|
||||
def preprocess(self, shard_config: ShardConfig = None):
|
||||
def preprocess(self):
|
||||
# reshape the embedding layer
|
||||
r"""
|
||||
Reshape the Embedding layer to make the embedding dimension divisible by world_size
|
||||
"""
|
||||
# TODO:
|
||||
vocab_size = self.model.config.vocab_size
|
||||
world_size = shard_config.tensor_parallel_size
|
||||
world_size = self.shard_config.tensor_parallel_size
|
||||
if vocab_size % world_size != 0:
|
||||
new_vocab_size = vocab_size + world_size - vocab_size % world_size
|
||||
self.model.resize_token_embeddings(new_vocab_size)
|
||||
return self.model
|
||||
|
||||
def module_policy(self, shard_config: ShardConfig = None):
|
||||
def module_policy(self):
|
||||
return {
|
||||
BertLayer:
|
||||
ModulePolicyDescription(
|
||||
attribute_replacement={
|
||||
# 1. shard hidden size
|
||||
"attention.self.all_head_size":
|
||||
self.model.config.hidden_size // shard_config.tensor_parallel_size,
|
||||
self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
|
||||
"crossattention.self.all_head_size":
|
||||
self.model.config.hidden_size // shard_config.tensor_parallel_size,
|
||||
self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
|
||||
# 2. shard number of heads
|
||||
"attention.self.num_attention_heads":
|
||||
self.model.config.num_attention_heads // shard_config.tensor_parallel_size,
|
||||
self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
|
||||
"crossattention.self.num_attention_heads":
|
||||
self.model.config.num_attention_heads // shard_config.tensor_parallel_size,
|
||||
self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
|
||||
},
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
|
@ -100,13 +99,43 @@ class BertPolicy(Policy):
|
|||
return self.model
|
||||
|
||||
|
||||
# BertModel
|
||||
class BertModelPolicy(BertPolicy):
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
|
||||
# BertForPreTraining
|
||||
class BertForPretrainingPolicy(BertPolicy):
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def module_policy(self):
|
||||
module_policy = super().module_policy()
|
||||
addon_module = {
|
||||
BertLMPredictionHead:
|
||||
ModulePolicyDescription(attribute_replacement={},
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(suffix="decoder",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={"gather_output": True})
|
||||
])
|
||||
}
|
||||
module_policy.update(addon_module)
|
||||
return module_policy
|
||||
|
||||
|
||||
# BertForMaskedLM
|
||||
class BertForMaskedLMPolicy(BertPolicy):
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def module_policy(self, shard_config: ShardConfig = None):
|
||||
module_policy = super().module_policy(shard_config)
|
||||
def module_policy(self):
|
||||
module_policy = super().module_policy()
|
||||
addon_module = {
|
||||
BertLMPredictionHead:
|
||||
ModulePolicyDescription(attribute_replacement={},
|
||||
|
@ -124,16 +153,41 @@ class BertForMaskedLMPolicy(BertPolicy):
|
|||
# BertLMHeadModel
|
||||
class BertLMHeadModelPolicy(BertPolicy):
|
||||
|
||||
@staticmethod
|
||||
def argument_policy(config, world_size):
|
||||
base_argument = BertPolicy.argument_policy(config, world_size)
|
||||
argument = {
|
||||
BertLMPredictionHead: Argument(attr_dict={}, param_funcs=[
|
||||
BertPolicy.unembedding,
|
||||
]),
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def module_policy(self):
|
||||
module_policy = super().module_policy()
|
||||
addon_module = {
|
||||
BertLMPredictionHead:
|
||||
ModulePolicyDescription(attribute_replacement={},
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(suffix="decoder",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={"gather_output": True})
|
||||
])
|
||||
}
|
||||
argument.update(base_argument)
|
||||
return argument
|
||||
module_policy.update(addon_module)
|
||||
return module_policy
|
||||
|
||||
|
||||
# BertForNextSentencePrediction
|
||||
class BertForNextSentencePredictionPolicy(BertPolicy):
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
|
||||
# BertForSequenceClassification
|
||||
class BertForSequenceClassificationPolicy(BertPolicy):
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
|
||||
# BertForMultipleChoice
|
||||
class BertForMultipleChoicePolicy(BertPolicy):
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
|
|
@ -18,10 +18,10 @@ class ShardConfig:
|
|||
will not calculate the loss and just return the output.
|
||||
gather_output (bool): Whether to gather the output of the model of the last layer
|
||||
"""
|
||||
data_parallel_size: int
|
||||
tensor_parallel_size: int
|
||||
|
||||
pipeline_parallel_size: int
|
||||
# TODO: add support for tensor parallel
|
||||
# pipeline_parallel_size: int
|
||||
# data_parallel_size: int
|
||||
tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d']
|
||||
inference_only: bool = True
|
||||
gather_output: bool = True
|
||||
|
|
|
@ -40,6 +40,7 @@ class ModelSharder(object):
|
|||
Shard the model according to the policy
|
||||
"""
|
||||
self.policy.set_model(self.model)
|
||||
self.policy.set_shard_config(self.shard_config)
|
||||
self.preprocess()
|
||||
self.replace_model_class()
|
||||
self.replace_module()
|
||||
|
@ -57,12 +58,12 @@ class ModelSharder(object):
|
|||
self.model_config = self.model.config
|
||||
|
||||
def preprocess(self) -> None:
|
||||
self.model = self.policy.preprocess(self.shard_config)
|
||||
self.model = self.policy.preprocess()
|
||||
|
||||
def postprocess(self) -> None:
|
||||
self.model = self.policy.postprocess()
|
||||
|
||||
def replace_model_class(self,) -> None:
|
||||
def replace_model_class(self) -> None:
|
||||
r"""
|
||||
Replace the model to policy defined model
|
||||
Mainly modify the forward and backward to fit distributed model
|
||||
|
@ -83,14 +84,14 @@ class ModelSharder(object):
|
|||
getattr(new_model_class, key),
|
||||
)
|
||||
|
||||
def replace_module(self,) -> None:
|
||||
def replace_module(self) -> None:
|
||||
r"""
|
||||
Replace the module according to the policy, and replace the module one by one
|
||||
|
||||
Args:
|
||||
model (:class:`torch.nn.Module`): The model to shard
|
||||
"""
|
||||
module_descriptions = self.policy.module_policy(self.shard_config)
|
||||
module_descriptions = self.policy.module_policy()
|
||||
for module_description in module_descriptions.items():
|
||||
origin_layer_cls = module_description[0]
|
||||
attr_replacement = module_description[1].attribute_replacement
|
||||
|
|
|
@ -25,11 +25,7 @@ class ShardFormer:
|
|||
org_model = BertForMaskedLM.from_pretrained('bert-base-uncased')
|
||||
shard_config = ShardConfig(
|
||||
tensor_parallel_size=2,
|
||||
data_parallel_size=1,
|
||||
pipeline_parallel_size=1,
|
||||
tensor_parallel_mode='1d',
|
||||
inference_only=True,
|
||||
gather_output=True
|
||||
)
|
||||
shard_former = ShardFormer(shard_config=shard_config)
|
||||
shard_former.init_distributed()
|
||||
|
|
|
@ -7,7 +7,6 @@ from transformers import (
|
|||
AutoTokenizer,
|
||||
BertConfig,
|
||||
BertForMaskedLM,
|
||||
BertForMultipleChoice,
|
||||
BertForNextSentencePrediction,
|
||||
BertForPreTraining,
|
||||
BertForSequenceClassification,
|
||||
|
@ -36,12 +35,10 @@ def build_model(rank, world_size, model):
|
|||
org_model.to('cuda')
|
||||
# TODO: no need to transfer to cuda
|
||||
org_model_forshard.to('cuda')
|
||||
shard_config = ShardConfig(tensor_parallel_size=2,
|
||||
data_parallel_size=1,
|
||||
pipeline_parallel_size=1,
|
||||
tensor_parallel_mode='1d',
|
||||
inference_only=True,
|
||||
gather_output=True)
|
||||
shard_config = ShardConfig(
|
||||
tensor_parallel_size=2,
|
||||
tensor_parallel_mode='1d',
|
||||
)
|
||||
shard_former = ShardFormer(shard_config=shard_config)
|
||||
shard_former.init_distributed()
|
||||
sharded_model = shard_former.shard_model(org_model_forshard).to('cuda')
|
||||
|
|
Loading…
Reference in New Issue