# mirror of https://github.com/hpcaitech/ColossalAI
import warnings
from functools import partial
from typing import Callable, Dict, List

import torch.nn as nn
from torch import Tensor
from torch.nn import Module

import colossalai.shardformer.layer as col_nn

from ..modeling.bloom import (
    BloomPipelineForwards,
    build_bloom_alibi_tensor_fn,
    get_bloom_sequence_parallel_forward_fn,
    get_jit_fused_bloom_attention_forward,
    get_jit_fused_bloom_gelu_forward,
    get_jit_fused_bloom_mlp_forward,
    get_lm_forward_with_dist_cross_entropy,
)
from ..modeling.jit import get_jit_fused_dropout_add_func, get_jit_fused_gelu_forward_func
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription


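# A shardformer policy maps each target module class to a ModulePolicyDescription:
# attributes to patch in place, submodules to swap for parallelized counterparts,
# and methods to replace. ShardFormer walks the wrapped model and applies every
# matching description from the dict returned by module_policy().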
class BloomPolicy(Policy):
    def __init__(self) -> None:
        super().__init__()

    def config_sanity_check(self):
        pass

    def preprocess(self):
        self.tie_weight = self.tie_weight_check()
        return self.model

    def module_policy(self):
        from transformers.models.bloom.modeling_bloom import BloomAttention, BloomBlock, BloomGelu, BloomMLP, BloomModel

        policy = {}

        embedding_cls = None
        if self.shard_config.enable_tensor_parallelism:
            embedding_cls = col_nn.VocabParallelEmbedding1D
        else:
            if self.tie_weight:
                embedding_cls = col_nn.PaddingEmbedding

        if self.shard_config.enable_fused_normalization:
            norm_cls = col_nn.FusedLayerNorm
        else:
            norm_cls = col_nn.LayerNorm

        sp_mode = self.shard_config.sequence_parallelism_mode or None
        assert sp_mode != "all_to_all", "all_to_all sequence parallelism is not supported for BLOOM"
        if sp_mode == "ring":
            warnings.warn(
                f"For BLOOM, sequence parallelism mode {sp_mode} is currently not supported; falling back to split_gather"
            )
            sp_mode = "split_gather"

        overlap = self.shard_config.enable_sequence_overlap
        sp_partial_derived = sp_mode == "split_gather"

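        # Illustrative arithmetic (figures assumed, not checked here): BLOOM-176B
        # has n_head = 112 and hidden_size = 14336, so with tensor_parallel_size = 8
        # each rank keeps 112 // 8 = 14 heads and a 14336 // 8 = 1792-wide slice,
        # which is what the attribute_replacement entries below compute.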
        if self.shard_config.enable_tensor_parallelism:
            assert (
                self.model.config.n_head % self.shard_config.tensor_parallel_size == 0
            ), f"The number of attention heads ({self.model.config.n_head}) must be divisible by the tensor parallel size ({self.shard_config.tensor_parallel_size})."
            policy[BloomBlock] = ModulePolicyDescription(
                attribute_replacement={
                    "self_attention.hidden_size": self.model.config.hidden_size
                    // self.shard_config.tensor_parallel_size,
                    "self_attention.split_size": self.model.config.hidden_size
                    // self.shard_config.tensor_parallel_size,
                    "self_attention.num_heads": self.model.config.n_head // self.shard_config.tensor_parallel_size,
                },
                sub_module_replacement=[
                    SubModuleReplacementDescription(
                        suffix="self_attention.query_key_value",
                        target_module=col_nn.Linear1D_Col,
                        kwargs={"seq_parallel_mode": sp_mode, "overlap": overlap},
                    ),
                    SubModuleReplacementDescription(
                        suffix="self_attention.dense",
                        target_module=col_nn.Linear1D_Row,
                        kwargs={"seq_parallel_mode": sp_mode},
                    ),
                    SubModuleReplacementDescription(
                        suffix="self_attention.attention_dropout",
                        target_module=col_nn.DropoutForParallelInput,
                    ),
                    SubModuleReplacementDescription(
                        suffix="mlp.dense_h_to_4h",
                        target_module=col_nn.Linear1D_Col,
                        kwargs={"seq_parallel_mode": sp_mode, "overlap": overlap},
                    ),
                    SubModuleReplacementDescription(
                        suffix="mlp.dense_4h_to_h",
                        target_module=col_nn.Linear1D_Row,
                        kwargs={"seq_parallel_mode": sp_mode},
                    ),
                ],
            )

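            # ALiBi biases depend on the head count, so the model-level num_heads
            # attribute is shrunk to the local shard and build_alibi_tensor is
            # swapped for a variant that is aware of the tensor parallel group.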
            policy[BloomModel] = ModulePolicyDescription(
                attribute_replacement={
                    "num_heads": self.model.config.n_head // self.shard_config.tensor_parallel_size,
                },
                method_replacement={
                    "build_alibi_tensor": build_bloom_alibi_tensor_fn(self.shard_config.tensor_parallel_process_group)
                },
            )

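        # The embedding is padded so the vocabulary size becomes a multiple of
        # make_vocab_size_divisible_by, letting the replacement module split or
        # pad the embedding rows evenly.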
        if embedding_cls is not None:
            self.append_or_create_submodule_replacement(
                description=[
                    SubModuleReplacementDescription(
                        suffix="word_embeddings",
                        target_module=embedding_cls,
                        kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
                    ),
                ],
                policy=policy,
                target_key=BloomModel,
            )

        # optimization configuration
        # handle bloom model
        self.append_or_create_submodule_replacement(
            description=[
                SubModuleReplacementDescription(
                    suffix="ln_f",
                    target_module=norm_cls,
                ),
                SubModuleReplacementDescription(
                    suffix="word_embeddings_layernorm",
                    target_module=norm_cls,
                ),
            ],
            policy=policy,
            target_key=BloomModel,
        )

        # handle bloom block
        self.append_or_create_submodule_replacement(
            description=[
                SubModuleReplacementDescription(
                    suffix="input_layernorm",
                    target_module=norm_cls,
                    kwargs={"sp_partial_derived": sp_partial_derived},
                ),
                SubModuleReplacementDescription(
                    suffix="post_attention_layernorm",
                    target_module=norm_cls,
                    kwargs={"sp_partial_derived": sp_partial_derived},
                ),
            ],
            policy=policy,
            target_key=BloomBlock,
        )

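        # Under split_gather sequence parallelism, activations flow through the
        # block split along the sequence dimension and are gathered inside the
        # parallel linear layers; sp_partial_derived above flags the layernorms
        # whose weight gradients are partial sums that still need a reduction.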
        if sp_mode == "split_gather":
            self.append_or_create_method_replacement(
                description={"forward": get_bloom_sequence_parallel_forward_fn(self.shard_config)},
                policy=policy,
                target_key=BloomModel,
            )

        # enable jit fused operator
        if self.shard_config.enable_jit_fused:
            self.append_or_create_method_replacement(
                description={
                    "forward": get_jit_fused_bloom_attention_forward(),
                    "dropout_add": get_jit_fused_dropout_add_func(),
                },
                policy=policy,
                target_key=BloomAttention,
            )
            self.append_or_create_method_replacement(
                description={
                    "forward": get_jit_fused_bloom_mlp_forward(),
                    "dropout_add": get_jit_fused_dropout_add_func(),
                },
                policy=policy,
                target_key=BloomMLP,
            )
            self.append_or_create_method_replacement(
                description={
                    "forward": get_jit_fused_bloom_gelu_forward(),
                    "bloom_gelu_forward": get_jit_fused_gelu_forward_func(),
                },
                policy=policy,
                target_key=BloomGelu,
            )

        return policy

    def postprocess(self):
        return self.model

    def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
        """If under a pipeline parallel setting, replace the original forward method of the
        HuggingFace module with a customized forward method and register the change in the policy."""
        if self.pipeline_stage_manager:
            stage_manager = self.pipeline_stage_manager
            if self.model.__class__.__name__ == "BloomModel":
                module = self.model
            else:
                module = self.model.transformer

            layers_per_stage = stage_manager.distribute_layers(len(module.h))
            stage_index = stage_manager.get_stage_index(layers_per_stage)
            method_replacement = {
                "forward": partial(
                    new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config
                )
            }
            self.append_or_create_method_replacement(
                description=method_replacement, policy=policy, target_key=model_cls
            )
        return

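    # Stage layout illustration (hypothetical figures): with 24 transformer
    # blocks and 4 pipeline stages, distribute_layers yields [6, 6, 6, 6] and
    # get_stage_index maps stage 2 to blocks h[12:18]; the first stage also
    # holds the embeddings and the last stage holds ln_f, as below.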
    def get_held_layers(self) -> List[Module]:
        """Get pipeline layers for current stage."""
        assert self.pipeline_stage_manager is not None

        if self.model.__class__.__name__ == "BloomModel":
            module = self.model
        else:
            module = self.model.transformer
        stage_manager = self.pipeline_stage_manager

        held_layers = []
        layers_per_stage = stage_manager.distribute_layers(len(module.h))
        if stage_manager.is_first_stage():
            held_layers.append(module.word_embeddings)
            held_layers.append(module.word_embeddings_layernorm)
        start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
        held_layers.extend(module.h[start_idx:end_idx])
        if stage_manager.is_last_stage():
            held_layers.append(module.ln_f)

        return held_layers


class BloomModelPolicy(BloomPolicy):
    def module_policy(self):
        policy = super().module_policy()
        from transformers.models.bloom.modeling_bloom import BloomModel

        if self.pipeline_stage_manager:
            self.set_pipeline_forward(
                model_cls=BloomModel, new_forward=BloomPipelineForwards.bloom_model_forward, policy=policy
            )
        return policy

    def get_held_layers(self) -> List[Module]:
        """Get pipeline layers for current stage."""
        held_layers = super().get_held_layers()
        return held_layers

    def get_shared_params(self) -> List[Dict[int, Tensor]]:
        """No shared params in bloom model."""
        return []


class BloomForCausalLMPolicy(BloomPolicy):
    def module_policy(self):
        from transformers.models.bloom.modeling_bloom import BloomForCausalLM

        policy = super().module_policy()

        # handle tensor parallelism
        if self.shard_config.enable_tensor_parallelism:
            self.append_or_create_submodule_replacement(
                description=SubModuleReplacementDescription(
                    suffix="lm_head",
                    target_module=col_nn.VocabParallelLMHead1D,
                    kwargs=dict(
                        gather_output=not self.shard_config.parallel_output,
                        make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by,
                    ),
                ),
                policy=policy,
                target_key=BloomForCausalLM,
            )
            if self.shard_config.parallel_output:
                method_replacement = {"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)}
                self.append_or_create_method_replacement(
                    description=method_replacement, policy=policy, target_key=BloomForCausalLM
                )
        else:
            self.append_or_create_submodule_replacement(
                description=SubModuleReplacementDescription(
                    suffix="lm_head",
                    target_module=col_nn.PaddingLMHead,
                    kwargs=dict(make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by),
                ),
                policy=policy,
                target_key=BloomForCausalLM,
            )
        if self.pipeline_stage_manager:
            self.set_pipeline_forward(
                model_cls=BloomForCausalLM, new_forward=BloomPipelineForwards.bloom_for_causal_lm_forward, policy=policy
            )
        return policy

    def get_held_layers(self) -> List[Module]:
        """Get pipeline layers for current stage."""
        stage_manager = self.pipeline_stage_manager
        held_layers = super().get_held_layers()
        if stage_manager.is_last_stage():
            held_layers.append(self.model.lm_head)
        return held_layers

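    # With a split pipeline the input embedding sits on the first stage and
    # lm_head on the last; when their weights are tied, the mapping returned
    # below lets the runtime keep the two copies synchronized.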
    def get_shared_params(self) -> List[Dict[int, Tensor]]:
        bloom_model = self.model
        if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
            if id(bloom_model.transformer.word_embeddings.weight) == id(bloom_model.lm_head.weight):
                # tie weights
                return [
                    {
                        0: bloom_model.transformer.word_embeddings.weight,
                        self.pipeline_stage_manager.num_stages - 1: bloom_model.lm_head.weight,
                    }
                ]
        return []


class BloomForSequenceClassificationPolicy(BloomPolicy):
    def module_policy(self):
        from transformers.models.bloom.modeling_bloom import BloomForSequenceClassification

        policy = super().module_policy()

        # handle tensor parallelism
        if self.shard_config.enable_tensor_parallelism:
            self.append_or_create_submodule_replacement(
                description=SubModuleReplacementDescription(
                    suffix="score", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True)
                ),
                policy=policy,
                target_key=BloomForSequenceClassification,
            )
        if self.pipeline_stage_manager:
            self.set_pipeline_forward(
                model_cls=BloomForSequenceClassification,
                new_forward=BloomPipelineForwards.bloom_for_sequence_classification_forward,
                policy=policy,
            )
        return policy

    def get_held_layers(self) -> List[Module]:
        """Get pipeline layers for current stage."""
        stage_manager = self.pipeline_stage_manager
        held_layers = super().get_held_layers()
        if stage_manager.is_last_stage():
            held_layers.append(self.model.score)
        return held_layers

    def get_shared_params(self) -> List[Dict[int, Tensor]]:
        """No shared params in bloom for sequence classification model."""
        return []


class BloomForTokenClassificationPolicy(BloomPolicy):
    def module_policy(self):
        from transformers.models.bloom.modeling_bloom import BloomForTokenClassification

        policy = super().module_policy()

        # handle tensor parallelism
        if self.shard_config.enable_tensor_parallelism:
            self.append_or_create_submodule_replacement(
                description=[
                    SubModuleReplacementDescription(
                        suffix="classifier", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True)
                    ),
                    SubModuleReplacementDescription(
                        suffix="dropout",
                        target_module=col_nn.DropoutForReplicatedInput,
                    ),
                ],
                policy=policy,
                target_key=BloomForTokenClassification,
            )
        if self.pipeline_stage_manager:
            self.set_pipeline_forward(
                model_cls=BloomForTokenClassification,
                new_forward=BloomPipelineForwards.bloom_for_token_classification_forward,
                policy=policy,
            )

        return policy

    def get_held_layers(self) -> List[Module]:
        """Get pipeline layers for current stage."""
        stage_manager = self.pipeline_stage_manager
        held_layers = super().get_held_layers()
        if stage_manager.is_last_stage():
            held_layers.append(self.model.dropout)
            held_layers.append(self.model.classifier)
        return held_layers

    def get_shared_params(self) -> List[Dict[int, Tensor]]:
        """No shared params in bloom for token classification model."""
        return []


class BloomForQuestionAnsweringPolicy(BloomPolicy):
    # No head sharding as the output feature size is only 2
    def module_policy(self):
        from transformers.models.bloom.modeling_bloom import BloomForQuestionAnswering

        policy = super().module_policy()
        if self.pipeline_stage_manager:
            self.set_pipeline_forward(
                model_cls=BloomForQuestionAnswering,
                new_forward=BloomPipelineForwards.bloom_for_question_answering_forward,
                policy=policy,
            )
        return policy

    def get_held_layers(self) -> List[Module]:
        """Get pipeline layers for current stage."""
        held_layers = super().get_held_layers()
        stage_manager = self.pipeline_stage_manager
        if stage_manager.is_last_stage():
            held_layers.append(self.model.qa_outputs)
        return held_layers

    def get_shared_params(self) -> List[Dict[int, Tensor]]:
        """No shared params in bloom for question answering model."""
        return []
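

# ---------------------------------------------------------------------------
# Usage sketch (added for illustration; not part of the upstream file). It
# shows how one of these policies can be handed to ShardFormer manually; in
# normal use the policy is resolved from the auto-policy registry instead.
# The tiny BloomConfig values are made up so the demo stays CPU-friendly.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import BloomConfig, BloomForCausalLM

    from colossalai.shardformer import ShardConfig, ShardFormer

    # Four heads so the head count stays divisible by any small TP size.
    config = BloomConfig(hidden_size=64, n_head=4, n_layer=2)
    model = BloomForCausalLM(config)

    # With tensor parallelism disabled no process group is required, so only
    # single-device rewrites (padded LM head, fused/jit ops if enabled) apply.
    shard_config = ShardConfig(enable_tensor_parallelism=False)
    sharded_model, shared_params = ShardFormer(shard_config=shard_config).optimize(
        model, policy=BloomForCausalLMPolicy()
    )
    print(type(sharded_model).__name__, len(shared_params))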