# mirror of https://github.com/hpcaitech/ColossalAI
from functools import partial
from typing import Callable, Dict, List, Union

import torch.nn as nn
from torch import Tensor
from torch.nn import Module

from colossalai.shardformer.layer import (
    FusedRMSNorm,
    Linear1D_Col,
    Linear1D_Row,
    PaddingEmbedding,
    RMSNorm,
    VocabParallelEmbedding1D,
)

from ..modeling.qwen2 import (
    Qwen2PipelineForwards,
    get_lm_forward_with_dist_cross_entropy,
    get_qwen2_flash_attention_forward,
    get_qwen2_model_forward_for_flash_attn,
)

try:
    from transformers.models.qwen2.modeling_qwen2 import (
        Qwen2Attention,
        Qwen2DecoderLayer,
        Qwen2FlashAttention2,
        Qwen2ForCausalLM,
        Qwen2ForSequenceClassification,
        Qwen2Model,
        Qwen2SdpaAttention,
    )
except ImportError:
    Qwen2ForCausalLM = "Qwen2ForCausalLM"
    Qwen2ForSequenceClassification = "Qwen2ForSequenceClassification"
    Qwen2Attention = "Qwen2Attention"
    Qwen2FlashAttention2 = "Qwen2FlashAttention2"
    Qwen2SdpaAttention = "Qwen2SdpaAttention"
    Qwen2DecoderLayer = "Qwen2DecoderLayer"
    Qwen2Model = "Qwen2Model"

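# If the installed transformers release does not ship Qwen2, the import above fails and
# the class names fall back to plain strings. This keeps the module importable and turns
# the policy keys into harmless placeholders; the version assert in Qwen2Policy.__init__
# below is what actually rejects unsupported transformers versions.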

from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

__all__ = ["Qwen2Policy", "Qwen2ForCausalLMPolicy", "Qwen2ForSequenceClassificationPolicy"]


class Qwen2Policy(Policy):
    def __init__(self) -> None:
        super().__init__()
        import transformers
        from packaging.version import Version

        assert Version(transformers.__version__) >= Version(
            "4.39.1"
        ), "The Qwen2 model requires a transformers version of 4.39.1 or later."

    def config_sanity_check(self):
        pass

    def preprocess(self):
        self.tie_weight = self.tie_weight_check()
        self.origin_attn_implement = self.model.config._attn_implementation
        return self.model

    def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
        ATTN_IMPLEMENTATION = {
            "eager": Qwen2Attention,
            "flash_attention_2": Qwen2FlashAttention2,
            "sdpa": Qwen2SdpaAttention,
        }

        policy = {}

        attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement]
        embedding_cls = None
        if self.shard_config.enable_tensor_parallelism:
            embedding_cls = VocabParallelEmbedding1D
        else:
            if self.tie_weight:
                embedding_cls = PaddingEmbedding
        norm_cls = FusedRMSNorm if self.shard_config.enable_fused_normalization else RMSNorm

        sp_mode = self.shard_config.sequence_parallelism_mode or None
        sp_size = self.shard_config.sequence_parallel_size or None
        sp_group = self.shard_config.sequence_parallel_process_group or None
        sp_partial_derived = sp_mode in ["split_gather", "ring"]
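        # With "all_to_all" sequence parallelism the attention heads are split across the
        # sequence-parallel group: num_heads (and num_key_value_heads for GQA configs) is
        # divided by sp_size, e.g. 32 heads with sp_size=4 leaves 8 heads per rank
        # (illustrative numbers, not taken from any particular Qwen2 config).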
        if sp_mode == "all_to_all":
            decoder_attribute_replacement = {
                "num_heads": self.model.config.num_attention_heads // sp_size,
            }
            if getattr(self.model.config, "num_key_value_heads", False):
                decoder_attribute_replacement["num_key_value_heads"] = self.model.config.num_key_value_heads // sp_size

            policy[attn_cls] = ModulePolicyDescription(
                attribute_replacement=decoder_attribute_replacement,
            )

        if self.shard_config.enable_tensor_parallelism:
            assert (
                self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
            ), "The number of attention heads must be divisible by tensor parallel size."
            if hasattr(self.model.config, "num_key_value_heads"):
                assert (
                    self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0
                ), "The number of key_value heads must be divisible by tensor parallel size."
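            # Illustrative example (made-up numbers, not from a real checkpoint): with
            # hidden_size=4096, 32 attention heads, 8 KV heads and tensor_parallel_size=4,
            # each rank keeps hidden_size 1024, 8 query heads and 2 KV heads after the
            # attribute replacement below.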
            decoder_attribute_replacement = {
                "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
            }
            if getattr(self.model.config, "num_key_value_heads", False):
                decoder_attribute_replacement["self_attn.num_key_value_heads"] = (
                    self.model.config.num_key_value_heads // self.shard_config.tensor_parallel_size
                )

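            # q/k/v, gate and up projections are sharded column-wise (Linear1D_Col) while the
            # o and down projections are sharded row-wise (Linear1D_Row): the usual
            # Megatron-style pairing, which keeps activations local inside each attention and
            # MLP block and (without sequence parallelism) needs a single all-reduce on each
            # block's output.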
            policy[Qwen2DecoderLayer] = ModulePolicyDescription(
                attribute_replacement=decoder_attribute_replacement,
                sub_module_replacement=[
                    SubModuleReplacementDescription(
                        suffix="self_attn.q_proj",
                        target_module=Linear1D_Col,
                        kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
                    ),
                    SubModuleReplacementDescription(
                        suffix="self_attn.k_proj",
                        target_module=Linear1D_Col,
                        kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
                    ),
                    SubModuleReplacementDescription(
                        suffix="self_attn.v_proj",
                        target_module=Linear1D_Col,
                        kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
                    ),
                    SubModuleReplacementDescription(
                        suffix="self_attn.o_proj",
                        target_module=Linear1D_Row,
                        kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
                    ),
                    SubModuleReplacementDescription(
                        suffix="mlp.gate_proj",
                        target_module=Linear1D_Col,
                        kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
                    ),
                    SubModuleReplacementDescription(
                        suffix="mlp.up_proj",
                        target_module=Linear1D_Col,
                        kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
                    ),
                    SubModuleReplacementDescription(
                        suffix="mlp.down_proj",
                        target_module=Linear1D_Row,
                        kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
                    ),
                ],
            )

        if embedding_cls is not None:
            self.append_or_create_submodule_replacement(
                description=SubModuleReplacementDescription(
                    suffix="embed_tokens",
                    target_module=embedding_cls,
                    kwargs=(
                        {
                            "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
                            "fp8_communication": self.shard_config.fp8_communication,
                        }
                        if self.shard_config.enable_tensor_parallelism
                        else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
                    ),
                ),
                policy=policy,
                target_key=Qwen2Model,
            )

        # optimization configuration
        self.append_or_create_submodule_replacement(
            description=[
                SubModuleReplacementDescription(
                    suffix="input_layernorm",
                    target_module=norm_cls,
                    kwargs={"sp_partial_derived": sp_partial_derived},
                ),
                SubModuleReplacementDescription(
                    suffix="post_attention_layernorm",
                    target_module=norm_cls,
                    kwargs={"sp_partial_derived": sp_partial_derived},
                ),
            ],
            policy=policy,
            target_key=Qwen2DecoderLayer,
        )

        self.append_or_create_submodule_replacement(
            description=SubModuleReplacementDescription(
                suffix="norm",
                target_module=norm_cls,
                kwargs={"sp_partial_derived": sp_partial_derived},
            ),
            policy=policy,
            target_key=Qwen2Model,
        )

        if self.shard_config.enable_flash_attention or self.shard_config.enable_sequence_parallelism:
            self.append_or_create_method_replacement(
                description={
                    "forward": get_qwen2_flash_attention_forward(self.shard_config, sp_mode, sp_size, sp_group),
                },
                policy=policy,
                target_key=attn_cls,
            )
            if self.pipeline_stage_manager is None:
                # replace qwen2 model forward method
                self.append_or_create_method_replacement(
                    description={
                        "forward": get_qwen2_model_forward_for_flash_attn(
                            self.shard_config, sp_mode, sp_size, sp_group
                        ),
                    },
                    policy=policy,
                    target_key=Qwen2Model,
                )

        return policy

    def postprocess(self):
        return self.model

    def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
        """If under a pipeline parallel setting, replace the original forward method of the
        HuggingFace model with a customized one and record this change in the policy."""
        if self.pipeline_stage_manager is None:
            return

        stage_manager = self.pipeline_stage_manager
        if self.model.__class__.__name__ == "Qwen2Model":
            module = self.model
        else:
            module = self.model.model

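        # distribute_layers() splits the decoder layers across the pipeline stages and
        # get_stage_index() returns the slice owned by this stage (a list of slices under
        # interleaved scheduling). Illustrative example: 24 layers on 4 stages gives 6 layers
        # per stage, so stage 1 runs layers 6..11 through the forward bound below.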
        if stage_manager.is_interleave:
            layers_per_stage = stage_manager.distribute_layers(len(module.layers))
            stage_manager.stage_indices = stage_manager.get_stage_index(layers_per_stage)
            method_replacement = {
                "forward": partial(new_forward, stage_manager=stage_manager, shard_config=self.shard_config)
            }
        else:
            layers_per_stage = stage_manager.distribute_layers(len(module.layers))
            stage_index = stage_manager.get_stage_index(layers_per_stage)
            method_replacement = {
                "forward": partial(
                    new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config
                )
            }

        # register the stage-aware forward once, for both scheduling modes
        self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)

    def get_held_layers(self) -> List[Module]:
        """Get pipeline layers for current stage."""
        assert self.pipeline_stage_manager is not None

        if self.model.__class__.__name__ == "Qwen2Model":
            module = self.model
        else:
            module = self.model.model

        stage_manager = self.pipeline_stage_manager

        held_layers = []
        if stage_manager.is_interleave:
            assert stage_manager.num_model_chunks is not None
            layers_per_stage = stage_manager.distribute_layers(len(module.layers))
            stage_indices = stage_manager.get_stage_index(layers_per_stage)
            if stage_manager.is_first_stage(ignore_chunk=True):
                held_layers.append(module.embed_tokens)
            for start_idx, end_idx in stage_indices:
                held_layers.extend(module.layers[start_idx:end_idx])
            if stage_manager.is_last_stage(ignore_chunk=True):
                held_layers.append(module.norm)
        else:
            layers_per_stage = stage_manager.distribute_layers(len(module.layers))
            if stage_manager.is_first_stage():
                held_layers.append(module.embed_tokens)
            start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
            held_layers.extend(module.layers[start_idx:end_idx])
            if stage_manager.is_last_stage():
                held_layers.append(module.norm)

        return held_layers


class Qwen2ModelPolicy(Qwen2Policy):
    def module_policy(self):
        policy = super().module_policy()

        if self.pipeline_stage_manager:
            # replace the forward with a stage-aware pipeline forward
            self.set_pipeline_forward(
                model_cls=Qwen2Model, new_forward=Qwen2PipelineForwards.qwen2_model_forward, policy=policy
            )
        return policy

    def get_held_layers(self) -> List[Module]:
        """Get pipeline layers for current stage."""
        held_layers = super().get_held_layers()
        return held_layers

    def get_shared_params(self) -> List[Dict[int, Tensor]]:
        """No shared params in Qwen2 model"""
        return []


class Qwen2ForCausalLMPolicy(Qwen2Policy):
    def module_policy(self):
        policy = super().module_policy()
        setattr(self.shard_config, "causal_lm", True)

        if self.shard_config.enable_tensor_parallelism:
            # add a new item for causal lm
            new_item = {
                Qwen2ForCausalLM: ModulePolicyDescription(
                    sub_module_replacement=[
                        SubModuleReplacementDescription(
                            suffix="lm_head",
                            target_module=Linear1D_Col,
                            kwargs=dict(fp8_communication=self.shard_config.fp8_communication),
                        )
                    ],
                    method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)},
                )
            }
            policy.update(new_item)

        if self.pipeline_stage_manager:
            # replace the forward with a stage-aware pipeline forward
            self.set_pipeline_forward(
                model_cls=Qwen2ForCausalLM, new_forward=Qwen2PipelineForwards.qwen2_for_causal_lm_forward, policy=policy
            )

        return policy

    def get_held_layers(self) -> List[Module]:
        """Get pipeline layers for current stage."""
        stage_manager = self.pipeline_stage_manager
        held_layers = super().get_held_layers()
        if stage_manager.is_last_stage(ignore_chunk=True):
            held_layers.append(self.model.lm_head)
        return held_layers

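    # Qwen2 can tie embed_tokens.weight and lm_head.weight. When those two parameters live on
    # different pipeline stages (first and last), get_shared_params below reports them as a
    # {stage_index: parameter} mapping, which downstream pipeline code can use to keep the
    # tied copies synchronized; with a single stage or untied weights there is nothing to share.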
    def get_shared_params(self) -> List[Dict[int, Tensor]]:
        qwen2_model = self.model.model
        if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
            if (
                id(qwen2_model.embed_tokens.weight) == id(self.model.lm_head.weight)
                and self.pipeline_stage_manager.num_stages > 1
            ):
                # tie weights
                return [
                    {
                        0: qwen2_model.embed_tokens.weight,
                        self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight,
                    }
                ]
        return []


class Qwen2ForSequenceClassificationPolicy(Qwen2Policy):
    def module_policy(self):
        policy = super().module_policy()
        if self.shard_config.enable_tensor_parallelism:
            # add a new item for sequence classification
            new_item = {
                Qwen2ForSequenceClassification: ModulePolicyDescription(
                    sub_module_replacement=[
                        SubModuleReplacementDescription(
                            suffix="score",
                            target_module=Linear1D_Col,
                            kwargs=dict(gather_output=True, fp8_communication=self.shard_config.fp8_communication),
                        )
                    ]
                )
            }
            policy.update(new_item)
        # to be confirmed
        if self.pipeline_stage_manager:
            # replace the forward with a stage-aware pipeline forward
            self.set_pipeline_forward(
                model_cls=Qwen2ForSequenceClassification,
                new_forward=Qwen2PipelineForwards.qwen2_for_sequence_classification_forward,
                policy=policy,
            )
        return policy

    def get_held_layers(self) -> List[Module]:
        """Get pipeline layers for current stage."""
        stage_manager = self.pipeline_stage_manager
        held_layers = super().get_held_layers()
        if stage_manager.is_last_stage(ignore_chunk=True):
            held_layers.append(self.model.score)
        return held_layers

    def get_shared_params(self) -> List[Dict[int, Tensor]]:
        """No shared params in Qwen2 for sequence classification model"""
        return []
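

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the upstream file). It assumes
# the public colossalai.shardformer API (ShardConfig / ShardFormer.optimize)
# and a Hugging Face Qwen2 checkpoint name; both are assumptions made for the
# sake of the example, and in practice the policy is usually resolved
# automatically from the model class rather than passed in by hand.
#
#   from transformers import AutoModelForCausalLM
#   from colossalai.shardformer import ShardConfig, ShardFormer
#
#   model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B")
#   shard_config = ShardConfig(enable_tensor_parallelism=True)
#   shard_former = ShardFormer(shard_config=shard_config)
#   sharded_model, shared_params = shard_former.optimize(model, policy=Qwen2ForCausalLMPolicy())
# ---------------------------------------------------------------------------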