mirror of https://github.com/hpcaitech/ColossalAI
[shardformer] fix bert and gpt downstream with new api (#4024)
* fix bert downstream with new api
* remove comment line
parent e253a07007
commit 74d176c8d8

@@ -76,6 +76,7 @@ class Policy(ABC):

     def __init__(self) -> None:
         self.model = None
+        self.shard_config = None

     def set_model(self, model: nn.Module) -> None:
         r"""

@@ -86,14 +87,23 @@ class Policy(ABC):
         """
         self.model = model

+    def set_shard_config(self, shard_config: ShardConfig) -> None:
+        r"""
+        Set shard config as an attribute of the Policy object.
+
+        Args:
+            shard_config (:class:`ShardConfig`): The shard config to be perform
+        """
+        self.shard_config = shard_config
+
     @abstractmethod
-    def preprocess(self, shard_config: ShardConfig = None) -> nn.Module:
+    def preprocess(self) -> nn.Module:
         r"""
         Perform some preprocessing of the model, like reshaping the embedding layer
         """

     @abstractmethod
-    def module_policy(self, shard_config: ShardConfig = None) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
+    def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
         r"""
         Return the dict for the modify policy, the key is the original layer class and the value is the
         argument for the modify layer

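Note: with this change a policy no longer receives the shard config through method arguments; the sharder injects it once via the new setter. A minimal, self-contained sketch of the intended call order (the Toy* classes below are stand-ins for illustration only, not part of this commit):

# Stand-in classes mimicking the new setter-based Policy flow (illustration only).
from dataclasses import dataclass

@dataclass
class ToyShardConfig:
    tensor_parallel_size: int = 2
    tensor_parallel_mode: str = '1d'

class ToyPolicy:
    def __init__(self) -> None:
        self.model = None
        self.shard_config = None

    def set_model(self, model) -> None:
        self.model = model

    def set_shard_config(self, shard_config) -> None:
        self.shard_config = shard_config

    def preprocess(self):
        # the world size is now read from the stored config, not from an argument
        print("tensor parallel size:", self.shard_config.tensor_parallel_size)
        return self.model

policy = ToyPolicy()
policy.set_model(object())                 # 1. attach the model
policy.set_shard_config(ToyShardConfig())  # 2. attach the shard config
policy.preprocess()                        # 3. no shard_config argument any more
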
@@ -4,41 +4,40 @@ from transformers.models.bert.modeling_bert import BertEmbeddings, BertLayer, Be
 import colossalai.shardformer.layer.layers as col_nn
 from colossalai.shardformer.layer.dropout import Dropout1D

-from ..shard.shard_config import ShardConfig
 from ..utils import getattr_, setattr_
 from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription


 class BertPolicy(Policy):

-    def preprocess(self, shard_config: ShardConfig = None):
+    def preprocess(self):
         # reshape the embedding layer
         r"""
         Reshape the Embedding layer to make the embedding dimension divisible by world_size
         """
         # TODO:
         vocab_size = self.model.config.vocab_size
-        world_size = shard_config.tensor_parallel_size
+        world_size = self.shard_config.tensor_parallel_size
         if vocab_size % world_size != 0:
             new_vocab_size = vocab_size + world_size - vocab_size % world_size
             self.model.resize_token_embeddings(new_vocab_size)
         return self.model

-    def module_policy(self, shard_config: ShardConfig = None):
+    def module_policy(self):
         return {
             BertLayer:
                 ModulePolicyDescription(
                     attribute_replacement={
                         # 1. shard hidden size
                         "attention.self.all_head_size":
-                            self.model.config.hidden_size // shard_config.tensor_parallel_size,
+                            self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                         "crossattention.self.all_head_size":
-                            self.model.config.hidden_size // shard_config.tensor_parallel_size,
+                            self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                         # 2. shard number of heads
                         "attention.self.num_attention_heads":
-                            self.model.config.num_attention_heads // shard_config.tensor_parallel_size,
+                            self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
                         "crossattention.self.num_attention_heads":
-                            self.model.config.num_attention_heads // shard_config.tensor_parallel_size,
+                            self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
                     },
                     param_replacement=[],
                     sub_module_replacement=[

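For a concrete feel of the padding rule in preprocess above, here is the arithmetic with example numbers (bert-base-uncased has a vocabulary of 30522 tokens; the tensor parallel size of 4 is just an example):

# Worked example of the vocabulary padding in BertPolicy.preprocess (illustrative numbers).
vocab_size = 30522   # bert-base-uncased
world_size = 4       # example tensor_parallel_size
if vocab_size % world_size != 0:
    new_vocab_size = vocab_size + world_size - vocab_size % world_size
    print(new_vocab_size)   # 30524, the smallest multiple of 4 that is >= 30522
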
@@ -100,13 +99,43 @@ class BertPolicy(Policy):
         return self.model


+# BertModel
+class BertModelPolicy(BertPolicy):
+
+    def __init__(self) -> None:
+        super().__init__()
+
+
+# BertForPreTraining
+class BertForPretrainingPolicy(BertPolicy):
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    def module_policy(self):
+        module_policy = super().module_policy()
+        addon_module = {
+            BertLMPredictionHead:
+                ModulePolicyDescription(attribute_replacement={},
+                                        param_replacement=[],
+                                        sub_module_replacement=[
+                                            SubModuleReplacementDescription(suffix="decoder",
+                                                                            target_module=col_nn.Linear1D_Col,
+                                                                            kwargs={"gather_output": True})
+                                        ])
+        }
+        module_policy.update(addon_module)
+        return module_policy
+
+
 # BertForMaskedLM
 class BertForMaskedLMPolicy(BertPolicy):

     def __init__(self) -> None:
         super().__init__()

-    def module_policy(self, shard_config: ShardConfig = None):
-        module_policy = super().module_policy(shard_config)
+    def module_policy(self):
+        module_policy = super().module_policy()
         addon_module = {
             BertLMPredictionHead:
                 ModulePolicyDescription(attribute_replacement={},

@@ -124,16 +153,41 @@ class BertForMaskedLMPolicy(BertPolicy):
 # BertLMHeadModel
 class BertLMHeadModelPolicy(BertPolicy):

-    @staticmethod
-    def argument_policy(config, world_size):
-        base_argument = BertPolicy.argument_policy(config, world_size)
-        argument = {
-            BertLMPredictionHead: Argument(attr_dict={}, param_funcs=[
-                BertPolicy.unembedding,
-            ]),
-        }
-        argument.update(base_argument)
-        return argument
+    def __init__(self) -> None:
+        super().__init__()
+
+    def module_policy(self):
+        module_policy = super().module_policy()
+        addon_module = {
+            BertLMPredictionHead:
+                ModulePolicyDescription(attribute_replacement={},
+                                        param_replacement=[],
+                                        sub_module_replacement=[
+                                            SubModuleReplacementDescription(suffix="decoder",
+                                                                            target_module=col_nn.Linear1D_Col,
+                                                                            kwargs={"gather_output": True})
+                                        ])
+        }
+        module_policy.update(addon_module)
+        return module_policy
+
+
+# BertForNextSentencePrediction
+class BertForNextSentencePredictionPolicy(BertPolicy):
+
+    def __init__(self) -> None:
+        super().__init__()
+
+
+# BertForSequenceClassification
+class BertForSequenceClassificationPolicy(BertPolicy):
+
+    def __init__(self) -> None:
+        super().__init__()
+
+
+# BertForMultipleChoice
+class BertForMultipleChoicePolicy(BertPolicy):

     def __init__(self) -> None:
         super().__init__()

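How these per-architecture policies are selected for a given model is outside this diff; a hypothetical lookup keyed by model class name (illustration only, the real dispatch may differ) could look like:

# Hypothetical policy registry (not part of this commit), keyed by model class name.
_BERT_POLICY_MAP = {
    "BertModel": BertModelPolicy,
    "BertForPreTraining": BertForPretrainingPolicy,
    "BertForMaskedLM": BertForMaskedLMPolicy,
    "BertLMHeadModel": BertLMHeadModelPolicy,
    "BertForNextSentencePrediction": BertForNextSentencePredictionPolicy,
    "BertForSequenceClassification": BertForSequenceClassificationPolicy,
    "BertForMultipleChoice": BertForMultipleChoicePolicy,
}

def pick_policy(model):
    # fall back to the generic BertPolicy when no specialised policy exists
    policy_cls = _BERT_POLICY_MAP.get(type(model).__name__, BertPolicy)
    return policy_cls()
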
@@ -18,10 +18,10 @@ class ShardConfig:
         will not calculate the loss and just return the output.
         gather_output (bool): Whether to gather the output of the model of the last layer
     """
-    data_parallel_size: int
     tensor_parallel_size: int
-    pipeline_parallel_size: int
+    # TODO: add support for tensor parallel
+    # pipeline_parallel_size: int
+    # data_parallel_size: int
     tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d']
     inference_only: bool = True
     gather_output: bool = True

@@ -40,6 +40,7 @@ class ModelSharder(object):
         Shard the model according to the policy
         """
         self.policy.set_model(self.model)
+        self.policy.set_shard_config(self.shard_config)
         self.preprocess()
         self.replace_model_class()
         self.replace_module()

@@ -57,12 +58,12 @@ class ModelSharder(object):
         self.model_config = self.model.config

     def preprocess(self) -> None:
-        self.model = self.policy.preprocess(self.shard_config)
+        self.model = self.policy.preprocess()

     def postprocess(self) -> None:
         self.model = self.policy.postprocess()

-    def replace_model_class(self,) -> None:
+    def replace_model_class(self) -> None:
         r"""
         Replace the model to policy defined model
         Mainly modify the forward and backward to fit distributed model

@@ -83,14 +84,14 @@ class ModelSharder(object):
                 getattr(new_model_class, key),
             )

-    def replace_module(self,) -> None:
+    def replace_module(self) -> None:
         r"""
         Replace the module according to the policy, and replace the module one by one

         Args:
             model (:class:`torch.nn.Module`): The model to shard
         """
-        module_descriptions = self.policy.module_policy(self.shard_config)
+        module_descriptions = self.policy.module_policy()
         for module_description in module_descriptions.items():
             origin_layer_cls = module_description[0]
             attr_replacement = module_description[1].attribute_replacement

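The loop above consumes the dict returned by module_policy(); a simplified sketch of how one entry could be applied, assuming setattr_ (imported from ..utils in the BERT policy above) resolves dotted paths such as "attention.self.all_head_size". The real replace_module does more than this:

# Simplified sketch of applying attribute_replacement entries (illustration only,
# not the actual sharder code).
def apply_attribute_replacement(model, module_descriptions):
    for origin_layer_cls, description in module_descriptions.items():
        for module in model.modules():
            if not isinstance(module, origin_layer_cls):
                continue
            for attr_path, new_value in (description.attribute_replacement or {}).items():
                # e.g. attr_path = "attention.self.num_attention_heads"
                setattr_(module, attr_path, new_value)
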
@@ -25,11 +25,7 @@ class ShardFormer:
         org_model = BertForMaskedLM.from_pretrained('bert-base-uncased')
         shard_config = ShardConfig(
             tensor_parallel_size=2,
-            data_parallel_size=1,
-            pipeline_parallel_size=1,
             tensor_parallel_mode='1d',
-            inference_only=True,
-            gather_output=True
         )
         shard_former = ShardFormer(shard_config=shard_config)
         shard_former.init_distributed()

@@ -7,7 +7,6 @@ from transformers import (
     AutoTokenizer,
     BertConfig,
     BertForMaskedLM,
-    BertForMultipleChoice,
     BertForNextSentencePrediction,
     BertForPreTraining,
     BertForSequenceClassification,

@@ -36,12 +35,10 @@ def build_model(rank, world_size, model):
     org_model.to('cuda')
     # TODO: no need to transfer to cuda
     org_model_forshard.to('cuda')
-    shard_config = ShardConfig(tensor_parallel_size=2,
-                               data_parallel_size=1,
-                               pipeline_parallel_size=1,
-                               tensor_parallel_mode='1d',
-                               inference_only=True,
-                               gather_output=True)
+    shard_config = ShardConfig(
+        tensor_parallel_size=2,
+        tensor_parallel_mode='1d',
+    )
     shard_former = ShardFormer(shard_config=shard_config)
     shard_former.init_distributed()
     sharded_model = shard_former.shard_model(org_model_forshard).to('cuda')

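The remainder of the test is not shown in this hunk; an illustrative follow-up check (assuming a tokenizer built with AutoTokenizer and gather_output left at its default) might compare the two models' forward results:

# Illustrative consistency check, not part of the diff above.
import torch

inputs = tokenizer("Hello, my dog is cute", return_tensors="pt").to('cuda')
with torch.no_grad():
    org_logits = org_model(**inputs).logits
    shard_logits = sharded_model(**inputs).logits
# the sharded model should reproduce the original forward pass within tolerance
assert torch.allclose(org_logits, shard_logits, atol=1e-5)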