2023-07-07 07:41:00 +00:00
|
|
|
from functools import partial
|
2023-07-21 02:46:39 +00:00
|
|
|
from typing import Callable, Dict, List
|
2023-07-07 07:41:00 +00:00
|
|
|
|
2023-05-22 07:02:17 +00:00
|
|
|
import torch.nn as nn
|
2023-07-07 07:41:00 +00:00
|
|
|
from torch import Tensor
|
2023-07-21 02:46:39 +00:00
|
|
|
from torch.nn import Module
|
2023-05-24 02:26:46 +00:00
|
|
|
|
2023-06-20 03:45:16 +00:00
|
|
|
import colossalai.shardformer.layer as col_nn
|
2023-05-24 02:26:46 +00:00
|
|
|
|
2023-08-07 08:41:07 +00:00
|
|
|
from .._utils import getattr_, setattr_
|
|
|
|
from ..modeling.bert import (
|
|
|
|
BertPipelineForwards,
|
2023-08-18 10:04:55 +00:00
|
|
|
bert_sequence_parallel_forward_fn,
|
2023-08-07 08:41:07 +00:00
|
|
|
get_bert_flash_attention_forward,
|
|
|
|
get_jit_fused_bert_output_forward,
|
|
|
|
get_jit_fused_bert_self_output_forward,
|
|
|
|
)
|
|
|
|
from ..modeling.jit import get_jit_fused_dropout_add_func
|
2023-07-05 07:13:00 +00:00
|
|
|
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
2023-06-15 09:55:42 +00:00
|
|
|
|
2023-06-30 02:56:29 +00:00
|
|
|
# Public policy classes exported by ``from ... import *``; every entry must
# name a class defined in this module.
__all__ = [
    'BertPolicy', 'BertModelPolicy', 'BertForPreTrainingPolicy', 'BertLMHeadModelPolicy', 'BertForMaskedLMPolicy',
    'BertForNextSentencePredictionPolicy', 'BertForSequenceClassificationPolicy', 'BertForTokenClassificationPolicy',
    'BertForMultipleChoicePolicy', 'BertForQuestionAnsweringPolicy'
]
|
|
|
|
|
2023-06-15 09:55:42 +00:00
|
|
|
|
2023-05-22 07:02:17 +00:00
|
|
|
class BertPolicy(Policy):
    """Shardformer policy shared by every HuggingFace BERT variant.

    :meth:`module_policy` builds the replacements for the common BERT
    backbone. Task-specific subclasses extend that policy with
    :meth:`add_lm_head_policy`, :meth:`add_lm_prediction_policy` and
    :meth:`set_pipeline_forward`.
    """

    def config_sanity_check(self):
        pass

    def preprocess(self):
        r"""
        Resize the token embeddings so that the vocabulary size is divisible by
        the tensor-parallel world size.

        Returns:
            The (possibly resized) ``self.model``.
        """
        if self.shard_config.enable_tensor_parallelism:
            vocab_size = self.model.config.vocab_size
            world_size = self.shard_config.tensor_parallel_size
            if vocab_size % world_size != 0:
                # Round up to the next multiple of world_size (presumably so the
                # vocab-parallel embedding can be split evenly across ranks).
                new_vocab_size = vocab_size + world_size - vocab_size % world_size
                self.model.resize_token_embeddings(new_vocab_size)
        return self.model

    def module_policy(self):
        """Build the replacement policy for the shared BERT backbone.

        Returns a dict mapping HuggingFace module classes to
        ``ModulePolicyDescription`` objects, filled according to
        ``self.shard_config``:

        * tensor parallelism: shard the attention/MLP linears and dropouts of
          ``BertLayer`` (rescaling its head counts) and the word embedding of
          ``BertEmbeddings``;
        * sequence parallelism: replace ``BertModel.forward``;
        * fused normalization: swap LayerNorms for ``col_nn.FusedLayerNorm``;
        * flash attention: replace ``BertSelfAttention.forward``;
        * JIT fusion: replace ``forward``/``dropout_add`` of ``BertSelfOutput``
          and ``BertOutput``.
        """
        from transformers.models.bert.modeling_bert import (
            BertEmbeddings,
            BertLayer,
            BertModel,
            BertOutput,
            BertSelfAttention,
            BertSelfOutput,
        )

        policy = {}
        use_sequence_parallel = self.shard_config.enable_sequence_parallelism
        overlap = self.shard_config.enable_sequence_overlap
        if self.shard_config.enable_tensor_parallelism:
            # Each rank keeps 1/tp_size of the heads after column-sharding
            # query/key/value, so the per-layer head bookkeeping is rescaled
            # to match.
            policy[BertLayer] = ModulePolicyDescription(attribute_replacement={
                "attention.self.all_head_size":
                    self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                "crossattention.self.all_head_size":
                    self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                "attention.self.num_attention_heads":
                    self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
                "crossattention.self.num_attention_heads":
                    self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
            },
                                                        sub_module_replacement=[
                                                            SubModuleReplacementDescription(
                                                                suffix="attention.self.query",
                                                                target_module=col_nn.Linear1D_Col,
                                                                kwargs={
                                                                    "seq_parallel": use_sequence_parallel,
                                                                    "overlap": overlap
                                                                },
                                                            ),
                                                            SubModuleReplacementDescription(
                                                                suffix="attention.self.key",
                                                                target_module=col_nn.Linear1D_Col,
                                                                kwargs={
                                                                    "seq_parallel": use_sequence_parallel,
                                                                    "overlap": overlap
                                                                },
                                                            ),
                                                            SubModuleReplacementDescription(
                                                                suffix="attention.self.value",
                                                                target_module=col_nn.Linear1D_Col,
                                                                kwargs={
                                                                    "seq_parallel": use_sequence_parallel,
                                                                    "overlap": overlap
                                                                },
                                                            ),
                                                            SubModuleReplacementDescription(
                                                                suffix="attention.self.dropout",
                                                                target_module=col_nn.DropoutForParallelInput,
                                                            ),
                                                            SubModuleReplacementDescription(
                                                                suffix="attention.output.dense",
                                                                target_module=col_nn.Linear1D_Row,
                                                                kwargs={"seq_parallel": use_sequence_parallel},
                                                            ),
                                                            SubModuleReplacementDescription(
                                                                suffix="attention.output.dropout",
                                                                target_module=col_nn.DropoutForParallelInput,
                                                            ),
                                                            SubModuleReplacementDescription(
                                                                suffix="intermediate.dense",
                                                                target_module=col_nn.Linear1D_Col,
                                                                kwargs={
                                                                    "seq_parallel": use_sequence_parallel,
                                                                    "overlap": overlap
                                                                },
                                                            ),
                                                            SubModuleReplacementDescription(
                                                                suffix="output.dense",
                                                                target_module=col_nn.Linear1D_Row,
                                                                kwargs={"seq_parallel": use_sequence_parallel},
                                                            ),
                                                            SubModuleReplacementDescription(
                                                                suffix="output.dropout",
                                                                target_module=col_nn.DropoutForParallelInput,
                                                            )
                                                        ])

            policy[BertEmbeddings] = ModulePolicyDescription(sub_module_replacement=[
                SubModuleReplacementDescription(
                    suffix="word_embeddings",
                    target_module=col_nn.VocabParallelEmbedding1D,
                ),
                SubModuleReplacementDescription(
                    suffix="dropout",
                    target_module=col_nn.DropoutForReplicatedInput,
                )
            ])

        if use_sequence_parallel:
            self.append_or_create_method_replacement(
                description={'forward': bert_sequence_parallel_forward_fn(self.shard_config)},
                policy=policy,
                target_key=BertModel)

        # optimization configuration
        if self.shard_config.enable_fused_normalization:
            # Handle bert layer
            self.append_or_create_submodule_replacement(description=[
                SubModuleReplacementDescription(
                    suffix="attention.output.LayerNorm",
                    target_module=col_nn.FusedLayerNorm,
                ),
                SubModuleReplacementDescription(
                    suffix="output.LayerNorm",
                    target_module=col_nn.FusedLayerNorm,
                )
            ],
                                                        policy=policy,
                                                        target_key=BertLayer)
            # handle embedding layer
            self.append_or_create_submodule_replacement(
                description=[SubModuleReplacementDescription(
                    suffix="LayerNorm",
                    target_module=col_nn.FusedLayerNorm,
                )],
                policy=policy,
                target_key=BertEmbeddings)

        # use flash attention
        if self.shard_config.enable_flash_attention:
            self.append_or_create_method_replacement(description={
                'forward': get_bert_flash_attention_forward(),
            },
                                                     policy=policy,
                                                     target_key=BertSelfAttention)

        # use jit operator
        if self.shard_config.enable_jit_fused:
            self.append_or_create_method_replacement(description={
                'forward': get_jit_fused_bert_self_output_forward(),
                'dropout_add': get_jit_fused_dropout_add_func(),
            },
                                                     policy=policy,
                                                     target_key=BertSelfOutput)
            self.append_or_create_method_replacement(description={
                'forward': get_jit_fused_bert_output_forward(),
                'dropout_add': get_jit_fused_dropout_add_func(),
            },
                                                     policy=policy,
                                                     target_key=BertOutput)

        return policy

    def add_lm_head_policy(self, base_policy):
        """Add ``BertLMPredictionHead`` replacements to ``base_policy`` and return it.

        With tensor parallelism the ``decoder`` linear becomes a column-parallel
        layer constructed with ``gather_output=True``. With fused normalization,
        ``transform.LayerNorm`` becomes ``col_nn.FusedLayerNorm``.
        """
        from transformers.models.bert.modeling_bert import BertLMPredictionHead

        # optimize for tensor parallelism
        if self.shard_config.enable_tensor_parallelism:
            self.append_or_create_submodule_replacement(
                description=SubModuleReplacementDescription(suffix="decoder",
                                                            target_module=col_nn.Linear1D_Col,
                                                            kwargs={"gather_output": True}),
                policy=base_policy,
                target_key=BertLMPredictionHead)

        # optimize with fused normalization
        if self.shard_config.enable_fused_normalization:
            # Handle bert lm prediction head
            self.append_or_create_submodule_replacement(
                description=SubModuleReplacementDescription(
                    suffix="transform.LayerNorm",
                    target_module=col_nn.FusedLayerNorm,
                ),
                policy=base_policy,
                target_key=BertLMPredictionHead)
        return base_policy

    def add_lm_prediction_policy(self, base_policy):
        """Give ``BertLMPredictionHead`` the state-dict hooks of ``col_nn.ParallelModule``.

        The ``_save_to_state_dict`` / ``_load_from_state_dict`` methods are
        replaced in ``base_policy``, which is then returned. Presumably this
        lets checkpoints of the (possibly sharded) head round-trip correctly;
        confirm against ParallelModule.
        """
        from transformers.models.bert.modeling_bert import BertLMPredictionHead
        method_replacement = {
            '_save_to_state_dict': col_nn.ParallelModule._save_to_state_dict,
            '_load_from_state_dict': col_nn.ParallelModule._load_from_state_dict,
        }
        self.append_or_create_method_replacement(description=method_replacement,
                                                 policy=base_policy,
                                                 target_key=BertLMPredictionHead)
        return base_policy

    def postprocess(self):
        return self.model

    def _get_bert_module(self):
        """Return the ``BertModel`` backbone: the model itself or its ``.bert`` attribute."""
        if self.model.__class__.__name__ == "BertModel":
            return self.model
        return self.model.bert

    def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
        """If under pipeline parallel setting, replacing the original forward method of huggingface
        to customized forward method, and add this changing to policy.

        Args:
            model_cls: the HuggingFace model *class* whose ``forward`` is replaced.
            new_forward: pipeline forward; it is bound via ``functools.partial``
                with ``stage_manager``, ``stage_index`` and ``shard_config``.
            policy: policy dict updated in place.
        """
        if not self.pipeline_stage_manager:
            return

        stage_manager = self.pipeline_stage_manager
        module = self._get_bert_module()
        # Use the same (overridable) helpers as get_held_layers so the forward
        # and the held layers always agree on this stage's layer slice.
        layers_per_stage = self.distribute_layers(len(module.encoder.layer), stage_manager.num_stages)
        stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
        method_replacement = {
            'forward':
                partial(new_forward,
                        stage_manager=stage_manager,
                        stage_index=stage_index,
                        shard_config=self.shard_config)
        }
        self.append_or_create_method_replacement(description=method_replacement,
                                                 policy=policy,
                                                 target_key=model_cls)

    def get_held_layers(self) -> List[Module]:
        """Get pipeline layers for current stage.

        The first stage also holds the embeddings, every stage holds its slice
        of encoder layers, and the last stage also holds the pooler when the
        backbone has one.
        """
        assert self.pipeline_stage_manager is not None

        module = self._get_bert_module()
        stage_manager = self.pipeline_stage_manager

        held_layers = []
        layers_per_stage = self.distribute_layers(len(module.encoder.layer), stage_manager.num_stages)
        if stage_manager.is_first_stage():
            held_layers.append(module.embeddings)
        start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
        held_layers.extend(module.encoder.layer[start_idx:end_idx])
        # `pooler` is None when HuggingFace builds the backbone with
        # add_pooling_layer=False (e.g. masked-LM / token-classification / QA
        # heads); never report None as a held layer.
        if stage_manager.is_last_stage() and module.pooler is not None:
            held_layers.append(module.pooler)

        return held_layers
|
|
|
|
|
2023-05-24 02:26:46 +00:00
|
|
|
|
2023-06-19 02:47:16 +00:00
|
|
|
# BertModel
|
|
|
|
class BertModelPolicy(BertPolicy):
    """Policy for a bare ``BertModel`` with no task head on top."""

    def __init__(self) -> None:
        super().__init__()

    def module_policy(self):
        """Backbone policy, plus the pipeline forward for ``BertModel`` when pipelining."""
        from transformers.models.bert.modeling_bert import BertModel

        backbone_policy = super().module_policy()
        if self.pipeline_stage_manager:
            self.set_pipeline_forward(
                model_cls=BertModel,
                new_forward=BertPipelineForwards.bert_model_forward,
                policy=backbone_policy,
            )
        return backbone_policy

    def get_held_layers(self) -> List[Module]:
        """Layers this pipeline stage owns: exactly those chosen by the backbone policy."""
        return super().get_held_layers()

    def get_shared_params(self) -> List[Dict[int, Tensor]]:
        """A bare BERT encoder shares no parameters between pipeline stages."""
        return []
|
|
|
|
|
2023-06-19 02:47:16 +00:00
|
|
|
|
|
|
|
# BertForPreTraining
|
2023-07-07 07:41:00 +00:00
|
|
|
class BertForPreTrainingPolicy(BertPolicy):
    """Policy for ``BertForPreTraining`` (backbone plus the ``cls`` pre-training heads)."""

    def __init__(self) -> None:
        super().__init__()

    def module_policy(self):
        """Backbone policy extended with LM-head replacements and the pipeline forward."""
        from transformers.models.bert.modeling_bert import BertForPreTraining

        policy = self.add_lm_prediction_policy(self.add_lm_head_policy(super().module_policy()))
        if self.pipeline_stage_manager:
            self.set_pipeline_forward(
                model_cls=BertForPreTraining,
                new_forward=BertPipelineForwards.bert_for_pretraining_forward,
                policy=policy,
            )
        return policy

    def get_held_layers(self) -> List[Module]:
        """Backbone layers of this stage; the last stage additionally owns ``cls``."""
        layers = super().get_held_layers()
        if self.pipeline_stage_manager.is_last_stage():
            layers.append(self.model.cls)
        return layers

    def get_shared_params(self) -> List[Dict[int, Tensor]]:
        """Describe the word-embedding / decoder weight tie across pipeline stages.

        Returns ``[{0: embedding_weight, last_stage: decoder_weight}]`` when more
        than one pipeline stage is used and both names refer to the very same
        Parameter object; otherwise ``[]``.
        """
        model = self.model
        manager = self.pipeline_stage_manager
        if not manager or manager.num_stages <= 1:
            return []
        embedding_weight = model.bert.embeddings.word_embeddings.weight
        decoder_weight = model.cls.predictions.decoder.weight
        if embedding_weight is not decoder_weight:
            return []
        # tie weights
        return [{0: embedding_weight, manager.num_stages - 1: decoder_weight}]
|
|
|
|
|
2023-06-19 02:47:16 +00:00
|
|
|
|
2023-06-22 02:33:06 +00:00
|
|
|
# BertLMHeadModel
|
|
|
|
class BertLMHeadModelPolicy(BertPolicy):
    """Policy for ``BertLMHeadModel`` (backbone plus the ``cls`` LM head)."""

    def __init__(self) -> None:
        super().__init__()

    def module_policy(self):
        """Backbone policy extended with LM-head replacements and the pipeline forward."""
        from transformers.models.bert.modeling_bert import BertLMHeadModel

        policy = super().module_policy()
        for extend in (self.add_lm_head_policy, self.add_lm_prediction_policy):
            policy = extend(policy)
        if self.pipeline_stage_manager:
            self.set_pipeline_forward(
                model_cls=BertLMHeadModel,
                new_forward=BertPipelineForwards.bert_lm_head_model_forward,
                policy=policy,
            )
        return policy

    def get_held_layers(self) -> List[Module]:
        """Backbone layers of this stage; the last stage additionally owns ``cls``."""
        layers = super().get_held_layers()
        if self.pipeline_stage_manager.is_last_stage():
            layers.append(self.model.cls)
        return layers

    def get_shared_params(self) -> List[Dict[int, Tensor]]:
        """Describe the word-embedding / decoder weight tie across pipeline stages.

        Non-empty only with more than one pipeline stage and when both weights
        are the same Parameter object.
        """
        backbone = self.model.bert
        manager = self.pipeline_stage_manager
        if not manager or manager.num_stages <= 1:
            return []
        embedding_weight = backbone.embeddings.word_embeddings.weight
        decoder_weight = self.model.cls.predictions.decoder.weight
        if embedding_weight is not decoder_weight:
            return []
        # tie weights
        return [{0: embedding_weight, manager.num_stages - 1: decoder_weight}]
|
|
|
|
|
2023-05-24 02:26:46 +00:00
|
|
|
|
2023-06-22 02:33:06 +00:00
|
|
|
# BertForMaskedLM
|
|
|
|
class BertForMaskedLMPolicy(BertPolicy):
    """Policy for ``BertForMaskedLM`` (backbone plus the ``cls`` MLM head)."""

    def __init__(self) -> None:
        super().__init__()

    def module_policy(self):
        """Backbone policy extended with LM-head replacements and the pipeline forward."""
        from transformers.models.bert.modeling_bert import BertForMaskedLM

        policy = super().module_policy()
        policy = self.add_lm_head_policy(policy)
        policy = self.add_lm_prediction_policy(policy)
        if self.pipeline_stage_manager:
            self.set_pipeline_forward(
                model_cls=BertForMaskedLM,
                new_forward=BertPipelineForwards.bert_for_masked_lm_forward,
                policy=policy,
            )
        return policy

    def get_held_layers(self) -> List[Module]:
        """Backbone layers of this stage; the last stage additionally owns ``cls``."""
        layers = super().get_held_layers()
        if self.pipeline_stage_manager.is_last_stage():
            layers.append(self.model.cls)
        return layers

    def get_shared_params(self) -> List[Dict[int, Tensor]]:
        """Describe the word-embedding / decoder weight tie across pipeline stages.

        Non-empty only with more than one pipeline stage and when both weights
        are the same Parameter object.
        """
        backbone = self.model.bert
        manager = self.pipeline_stage_manager
        if not manager or manager.num_stages <= 1:
            return []
        embedding_weight = backbone.embeddings.word_embeddings.weight
        decoder_weight = self.model.cls.predictions.decoder.weight
        if embedding_weight is not decoder_weight:
            return []
        # tie weights
        return [{0: embedding_weight, manager.num_stages - 1: decoder_weight}]
|
2023-06-19 02:47:16 +00:00
|
|
|
|
|
|
|
|
2023-06-22 02:33:06 +00:00
|
|
|
# BertForSequenceClassification
|
|
|
|
class BertForSequenceClassificationPolicy(BertPolicy):
    """Policy for ``BertForSequenceClassification`` (backbone, dropout, classifier)."""

    def __init__(self) -> None:
        super().__init__()

    def module_policy(self):
        """Backbone policy, plus the head dropout swap (TP) and the pipeline forward."""
        from transformers.models.bert.modeling_bert import BertForSequenceClassification

        policy = super().module_policy()

        if self.shard_config.enable_tensor_parallelism:
            # Replace the task-head dropout with ColossalAI's DropoutForParallelInput.
            head_dropout = SubModuleReplacementDescription(
                suffix="dropout",
                target_module=col_nn.DropoutForParallelInput,
            )
            policy[BertForSequenceClassification] = ModulePolicyDescription(
                sub_module_replacement=[head_dropout])

        if self.pipeline_stage_manager:
            self.set_pipeline_forward(
                model_cls=BertForSequenceClassification,
                new_forward=BertPipelineForwards.bert_for_sequence_classification_forward,
                policy=policy,
            )

        return policy

    def get_held_layers(self) -> List[Module]:
        """Backbone layers of this stage; the last stage also owns dropout and classifier."""
        layers = super().get_held_layers()
        if self.pipeline_stage_manager.is_last_stage():
            layers.extend([self.model.dropout, self.model.classifier])
        return layers

    def get_shared_params(self) -> List[Dict[int, Tensor]]:
        """The sequence-classification model ties no parameters across stages."""
        return []
|
2023-06-23 10:00:22 +00:00
|
|
|
|
2023-06-19 02:47:16 +00:00
|
|
|
|
2023-06-22 02:33:06 +00:00
|
|
|
# BertForTokenClassification
|
|
|
|
class BertForTokenClassificationPolicy(BertPolicy):
    """Policy for ``BertForTokenClassification`` (backbone, dropout, per-token classifier)."""

    def __init__(self) -> None:
        super().__init__()

    def module_policy(self):
        """Backbone policy, plus the head dropout swap (TP) and the pipeline forward."""
        from transformers.models.bert.modeling_bert import BertForTokenClassification

        policy = super().module_policy()

        if self.shard_config.enable_tensor_parallelism:
            # Replace the task-head dropout with ColossalAI's DropoutForParallelInput.
            head_dropout = SubModuleReplacementDescription(
                suffix="dropout",
                target_module=col_nn.DropoutForParallelInput,
            )
            policy[BertForTokenClassification] = ModulePolicyDescription(
                sub_module_replacement=[head_dropout])

        if self.pipeline_stage_manager:
            self.set_pipeline_forward(
                model_cls=BertForTokenClassification,
                new_forward=BertPipelineForwards.bert_for_token_classification_forward,
                policy=policy,
            )

        return policy

    def get_held_layers(self) -> List[Module]:
        """Backbone layers of this stage; the last stage also owns dropout and classifier."""
        layers = super().get_held_layers()
        if self.pipeline_stage_manager.is_last_stage():
            layers.extend([self.model.dropout, self.model.classifier])
        return layers

    def get_shared_params(self) -> List[Dict[int, Tensor]]:
        """The token-classification model ties no parameters across stages."""
        return []
|
2023-06-23 10:00:22 +00:00
|
|
|
|
2023-06-22 02:33:06 +00:00
|
|
|
|
|
|
|
# BertForNextSentencePrediction
|
|
|
|
class BertForNextSentencePredictionPolicy(BertPolicy):
    """Policy for ``BertForNextSentencePrediction`` (backbone plus the ``cls`` NSP head)."""

    def __init__(self) -> None:
        super().__init__()

    def module_policy(self):
        """Backbone policy plus the pipeline forward when pipelining."""
        from transformers.models.bert.modeling_bert import BertForNextSentencePrediction

        policy = super().module_policy()
        if self.pipeline_stage_manager:
            self.set_pipeline_forward(
                model_cls=BertForNextSentencePrediction,
                new_forward=BertPipelineForwards.bert_for_next_sentence_prediction_forward,
                policy=policy,
            )
        return policy

    def get_held_layers(self) -> List[Module]:
        """Backbone layers of this stage; the last stage additionally owns ``cls``."""
        layers = super().get_held_layers()
        if self.pipeline_stage_manager.is_last_stage():
            layers.append(self.model.cls)
        return layers

    def get_shared_params(self) -> List[Dict[int, Tensor]]:
        """The next-sentence-prediction model ties no parameters across stages."""
        return []
|
|
|
|
|
2023-06-19 02:47:16 +00:00
|
|
|
|
|
|
|
# BertForMultipleChoice
|
|
|
|
class BertForMultipleChoicePolicy(BertPolicy):
    """Policy for ``BertForMultipleChoice`` (backbone, dropout, classifier)."""

    def __init__(self) -> None:
        super().__init__()

    def module_policy(self):
        """Backbone policy, plus the head dropout swap (TP) and the pipeline forward."""
        from transformers.models.bert.modeling_bert import BertForMultipleChoice

        policy = super().module_policy()

        if self.shard_config.enable_tensor_parallelism:
            # Replace the task-head dropout with ColossalAI's DropoutForParallelInput.
            head_dropout = SubModuleReplacementDescription(
                suffix="dropout",
                target_module=col_nn.DropoutForParallelInput,
            )
            policy[BertForMultipleChoice] = ModulePolicyDescription(sub_module_replacement=[head_dropout])

        if self.pipeline_stage_manager:
            self.set_pipeline_forward(
                model_cls=BertForMultipleChoice,
                new_forward=BertPipelineForwards.bert_for_multiple_choice_forward,
                policy=policy,
            )

        return policy

    def get_held_layers(self) -> List[Module]:
        """Backbone layers of this stage; the last stage also owns dropout and classifier."""
        layers = super().get_held_layers()
        if self.pipeline_stage_manager.is_last_stage():
            layers.extend([self.model.dropout, self.model.classifier])
        return layers

    def get_shared_params(self) -> List[Dict[int, Tensor]]:
        """The multiple-choice model ties no parameters across stages."""
        return []
|
|
|
|
|
|
|
|
|
|
|
|
class BertForQuestionAnsweringPolicy(BertPolicy):
    """Policy for ``BertForQuestionAnswering`` (backbone plus the ``qa_outputs`` head)."""

    def __init__(self) -> None:
        super().__init__()

    def module_policy(self):
        """Backbone policy plus the pipeline forward when pipelining."""
        from transformers.models.bert.modeling_bert import BertForQuestionAnswering

        policy = super().module_policy()
        if self.pipeline_stage_manager:
            self.set_pipeline_forward(
                model_cls=BertForQuestionAnswering,
                new_forward=BertPipelineForwards.bert_for_question_answering_forward,
                policy=policy,
            )
        return policy

    def get_held_layers(self) -> List[Module]:
        """Backbone layers of this stage; the last stage additionally owns ``qa_outputs``."""
        layers = super().get_held_layers()
        if self.pipeline_stage_manager.is_last_stage():
            layers.append(self.model.qa_outputs)
        return layers

    def get_shared_params(self) -> List[Dict[int, Tensor]]:
        """The question-answering model ties no parameters across stages."""
        return []
|