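"""Shardformer policies for the DeepSeek MoE model (the remote-code ``DeepseekModel`` and
``DeepseekForCausalLM`` classes), covering tensor, expert, and pipeline parallelism plus
fused-RMSNorm replacement.

Rough usage sketch (illustrative only; in practice these policies are normally picked up by
ColossalAI's plugin/auto-policy machinery rather than constructed by hand, and ``tp_group``
below stands for a process group created by the caller):

    from colossalai.shardformer import ShardConfig, ShardFormer

    shard_config = ShardConfig(tensor_parallel_process_group=tp_group, enable_tensor_parallelism=True)
    model, _ = ShardFormer(shard_config=shard_config).optimize(model, policy=DeepseekForCausalLMPolicy())
"""
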
import warnings
from functools import partial
from typing import Callable, Dict, List, Union

import torch.nn as nn
from torch import Tensor
from torch.nn import Module

from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col
from colossalai.shardformer.layer.linear import Linear1D_Row
from colossalai.shardformer.modeling.deepseek import DeepseekPipelineForwards, EPDeepseekMoE
from colossalai.shardformer.policies.base_policy import (
    ModulePolicyDescription,
    Policy,
    SubModuleReplacementDescription,
)

__all__ = ["DeepseekPolicy", "DeepseekModelPolicy", "DeepseekForCausalLMPolicy"]


class DeepseekPolicy(Policy):
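    """Shardformer policy for the DeepSeek MoE decoder stack: tensor parallelism for the
    attention projections, expert parallelism for the MoE MLP, optional fused RMSNorm,
    and pipeline-parallel stage handling."""
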
    def config_sanity_check(self):
        pass

    def preprocess(self):
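        """Pad the vocabulary so the token embedding can be evenly sharded across
        tensor-parallel ranks, e.g. vocab_size=32001 with tensor_parallel_size=4 is padded to 32004."""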
        if self.shard_config.enable_tensor_parallelism:
            # Resize embedding
            vocab_size = self.model.config.vocab_size
            world_size = self.shard_config.tensor_parallel_size

            if vocab_size % world_size != 0:
                new_vocab_size = vocab_size + world_size - vocab_size % world_size
                self.model.resize_token_embeddings(new_vocab_size)

        return self.model

    def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
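        """Build the sharding policy for the decoder layers: attention tensor parallelism,
        MoE expert parallelism, fused RMSNorm, and sanitation of unsupported flags."""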
        policy = {}

        if self.shard_config.enable_sequence_parallelism:
            # Sequence parallelism is not supported for Deepseek yet; clear the flag and warn instead of failing.
            self.shard_config.enable_sequence_parallelism = False
            warnings.warn(
                "Deepseek doesn't support sequence parallelism now, will ignore the sequence parallelism flag."
            )

        if self.shard_config.enable_tensor_parallelism:
            # tensor parallelism for non-moe params
            assert (
                self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
            ), "The number of attention heads must be divisible by tensor parallel size."
            assert (
                self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0
            ), "The number of key_value heads must be divisible by tensor parallel size."
            decoder_attribute_replacement = {
                "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
                "self_attn.num_key_value_heads": self.model.config.num_key_value_heads
                // self.shard_config.tensor_parallel_size,
            }

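            # q/k/v projections are column-sharded and the output projection is row-sharded,
            # so each attention block needs only a single all-reduce to reassemble activations.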
            policy["DeepseekDecoderLayer"] = ModulePolicyDescription(
                attribute_replacement=decoder_attribute_replacement,
                sub_module_replacement=[
                    SubModuleReplacementDescription(
                        suffix="self_attn.q_proj",
                        target_module=Linear1D_Col,
                    ),
                    SubModuleReplacementDescription(
                        suffix="self_attn.k_proj",
                        target_module=Linear1D_Col,
                    ),
                    SubModuleReplacementDescription(
                        suffix="self_attn.v_proj",
                        target_module=Linear1D_Col,
                    ),
                    SubModuleReplacementDescription(
                        suffix="self_attn.o_proj",
                        target_module=Linear1D_Row,
                    ),
                ],
            )

        if self.shard_config.ep_group:
            # expert parallel: scatter each decoder layer's MoE experts across the expert-parallel group
            self.append_or_create_submodule_replacement(
                description=[
                    SubModuleReplacementDescription(
                        suffix="mlp",
                        target_module=EPDeepseekMoE,
                        kwargs={
                            "ep_group": self.shard_config.ep_group,
                            "tp_group": self.shard_config.tensor_parallel_process_group,
                            "moe_dp_group": self.shard_config.moe_dp_group,
                            "moe_tp_group": self.shard_config.moe_tp_group,
                        },
                    )
                ],
                policy=policy,
                target_key="DeepseekDecoderLayer",
            )

        # optimization configuration
        if self.shard_config.enable_fused_normalization:
            self.append_or_create_submodule_replacement(
                description=[
                    SubModuleReplacementDescription(
                        suffix="input_layernorm",
                        target_module=FusedRMSNorm,
                    ),
                    SubModuleReplacementDescription(
                        suffix="post_attention_layernorm",
                        target_module=FusedRMSNorm,
                    ),
                ],
                policy=policy,
                target_key="DeepseekDecoderLayer",
            )

            self.append_or_create_submodule_replacement(
                description=SubModuleReplacementDescription(
                    suffix="norm",
                    target_module=FusedRMSNorm,
                ),
                policy=policy,
                target_key="DeepseekModel",
            )

        if self.shard_config.enable_flash_attention:
            warnings.warn(
                "Flash attention has already been replaced in deepseek; setting enable_flash_attention = False."
            )
            self.shard_config.enable_flash_attention = False

        return policy

    def postprocess(self):
        return self.model

    def set_pipeline_forward(self, model_cls: Union[str, nn.Module], new_forward: Callable, policy: Dict) -> None:
        """If under pipeline parallel setting, replace the original forward method of the
        HuggingFace module with the customized forward method and register the change in the policy."""
        if self.pipeline_stage_manager:
            stage_manager = self.pipeline_stage_manager
            if self.model.__class__.__name__ == "DeepseekModel":
                module = self.model
            else:
                module = self.model.model

            layers_per_stage = stage_manager.distribute_layers(len(module.layers))
            stage_index = stage_manager.get_stage_index(layers_per_stage)
            method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
            self.append_or_create_method_replacement(
                description=method_replacement, policy=policy, target_key=model_cls
            )

        return

    def get_held_layers(self) -> List[Module]:
        """Get pipeline layers for current stage."""
        assert self.pipeline_stage_manager is not None

        if self.model.__class__.__name__ == "DeepseekModel":
            module = self.model
        else:
            module = self.model.model
        stage_manager = self.pipeline_stage_manager

        held_layers = []
        layers_per_stage = stage_manager.distribute_layers(len(module.layers))
        if stage_manager.is_first_stage():
            held_layers.append(module.embed_tokens)
        start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
        held_layers.extend(module.layers[start_idx:end_idx])
        if stage_manager.is_last_stage():
            held_layers.append(module.norm)

        return held_layers


class DeepseekModelPolicy(DeepseekPolicy):
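    """Policy for the bare DeepseekModel (decoder stack without the LM head)."""
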
    def __init__(self) -> None:
        super().__init__()

    def module_policy(self):
        policy = super().module_policy()
        if self.pipeline_stage_manager:
            # replace the HuggingFace forward with the pipeline-aware forward
            self.set_pipeline_forward(
                model_cls="DeepseekModel",
                new_forward=DeepseekPipelineForwards.deepseek_model_forward,
                policy=policy,
            )
        return policy

    def get_held_layers(self) -> List[Module]:
        """Get pipeline layers for current stage."""
        held_layers = super().get_held_layers()
        return held_layers

    def get_shared_params(self) -> List[Dict[int, Tensor]]:
        """No shared params in deepseek model"""
        return []


class DeepseekForCausalLMPolicy(DeepseekPolicy):
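    """Policy for DeepseekForCausalLM: adds a column-sharded, gathered lm_head on top of
    DeepseekPolicy and reports embedding/lm_head weight sharing for pipeline parallelism."""
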
    def module_policy(self):
        policy = super().module_policy()
        # TODO: assign pg mesh from plugin to all modules
        if self.shard_config.enable_tensor_parallelism:
            # add a new item for causal lm
            new_item = {
                "DeepseekForCausalLM": ModulePolicyDescription(
                    sub_module_replacement=[
                        SubModuleReplacementDescription(
                            suffix="lm_head",
                            target_module=Linear1D_Col,
                            kwargs=dict(gather_output=True),
                        )
                    ]
                )
            }
            policy.update(new_item)

        if self.pipeline_stage_manager:
            # replace the HuggingFace forward with the pipeline-aware forward
            self.set_pipeline_forward(
                model_cls="DeepseekForCausalLM",
                new_forward=DeepseekPipelineForwards.deepseek_for_causal_lm_forward,
                policy=policy,
            )

        return policy

    def get_held_layers(self) -> List[Module]:
        """Get pipeline layers for current stage."""
        stage_manager = self.pipeline_stage_manager
        held_layers = super().get_held_layers()
        if stage_manager.is_last_stage():
            held_layers.append(self.model.lm_head)
        return held_layers

    def get_shared_params(self) -> List[Dict[int, Tensor]]:
        """Return the tied embedding/lm_head weights shared by the first and last pipeline stages."""
        deepseek_model = self.model.model
        if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
            if id(deepseek_model.embed_tokens.weight) == id(self.model.lm_head.weight):
                # tie weights
                return [
                    {
                        0: deepseek_model.embed_tokens.weight,
                        self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight,
                    }
                ]
        return []