ColossalAI/colossalai/shardformer/policies/deepseek_v3.py
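"""Shardformer policies for DeepSeek-V3 (descriptive summary of this module).

DeepseekV3Policy replaces each decoder layer's MoE block with EpDeepseekV3MoE when
expert parallelism is enabled, optionally swaps the RMSNorm layers for FusedRMSNorm,
and wires up pipeline-parallel forward replacement and per-stage layer ownership.
DeepseekV3ModelPolicy and DeepseekV3ForCausalLMPolicy specialize it for the bare
model and the causal-LM head, respectively.
"""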

from functools import partial
from typing import Callable, Dict, List, Union

import torch.nn as nn

from colossalai.shardformer.layer import FusedRMSNorm
from colossalai.shardformer.modeling.deepseek_v3 import (
    EpDeepseekV3MoE,
    deepseek_v3_for_causal_lm_forward,
    deepseek_v3_model_forward,
)
from colossalai.shardformer.policies.base_policy import (
    ModulePolicyDescription,
    Policy,
    SubModuleReplacementDescription,
)

__all__ = ["DeepseekV3Policy", "DeepseekV3ForCausalLMPolicy"]


class DeepseekV3Policy(Policy):
    def config_sanity_check(self):
        assert not self.shard_config.enable_tensor_parallelism, "DeepSeekV3 does not support tensor parallelism"
        assert not self.shard_config.enable_sequence_parallelism, "DeepSeekV3 does not support sequence parallelism"
        if self.shard_config.pipeline_stage_manager:
            assert not self.shard_config.pipeline_stage_manager.use_zbv, "DeepSeekV3 does not support ZBV"

    def preprocess(self):
        return self.model

    def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
        policy = {}

        # support gradient checkpointing
        if self.shard_config.pipeline_stage_manager is None:
            policy["DeepseekV3Model"] = ModulePolicyDescription(
                method_replacement={"forward": deepseek_v3_model_forward}
            )

        if self.shard_config.expert_parallel_size > 1:
            # expert parallel
            self.append_or_create_submodule_replacement(
                description=[
                    SubModuleReplacementDescription(
                        suffix="mlp",
                        target_module=EpDeepseekV3MoE,
                        kwargs={
                            "ep_group": self.shard_config.ep_group,
                            "moe_dp_group": self.shard_config.moe_dp_group,
                        },
                    )
                ],
                policy=policy,
                target_key="DeepseekV3DecoderLayer",
            )

        # optimization configuration
        if self.shard_config.enable_fused_normalization:
            # TODO: prevent casting to fp32
            self.append_or_create_submodule_replacement(
                description=[
                    SubModuleReplacementDescription(
                        suffix="input_layernorm",
                        target_module=FusedRMSNorm,
                    ),
                    SubModuleReplacementDescription(
                        suffix="post_attention_layernorm",
                        target_module=FusedRMSNorm,
                    ),
                ],
                policy=policy,
                target_key="DeepseekV3DecoderLayer",
            )

            self.append_or_create_submodule_replacement(
                description=SubModuleReplacementDescription(
                    suffix="norm",
                    target_module=FusedRMSNorm,
                ),
                policy=policy,
                target_key="DeepseekV3Model",
            )

        return policy

    def postprocess(self):
        return self.model

    def set_pipeline_forward(self, model_cls: str, new_forward: Callable, policy: Dict) -> None:
        """If under a pipeline-parallel setting, replace the original HuggingFace forward method
        with the customized forward method and record this replacement in the policy."""
        if self.pipeline_stage_manager:
            num_layers = self.model.config.num_hidden_layers
            stage_manager = self.pipeline_stage_manager
            layers_per_stage = stage_manager.distribute_layers(num_layers)
            stage_index = stage_manager.get_stage_index(layers_per_stage)
            method_replacement = {
                "forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)
            }
            self.append_or_create_method_replacement(
                description=method_replacement, policy=policy, target_key=model_cls
            )

        return

    def get_held_layers(self) -> List[nn.Module]:
        """Get pipeline layers for current stage."""
        assert self.pipeline_stage_manager is not None

        module = self.model
        if module.__class__.__name__.startswith("PeftModel"):
            module = module.get_base_model()
        if module.__class__.__name__ != "DeepseekV3Model":
            module = module.model

        stage_manager = self.pipeline_stage_manager
        held_layers = []
        if stage_manager.is_interleave:
            assert stage_manager.num_model_chunks is not None
            layers_per_stage = stage_manager.distribute_layers(len(module.layers))
            stage_indices = stage_manager.get_stage_index(layers_per_stage)
            stage_manager.stage_indices = stage_indices
            if stage_manager.is_first_stage(ignore_chunk=True):
                held_layers.append(module.embed_tokens)
            for start_idx, end_idx in stage_indices:
                held_layers.extend(module.layers[start_idx:end_idx])
            if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
                not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
            ):
                # for zbv, when is_first_stage (last fwd), we append norm
                # for interleaved, when is_last_stage (last fwd), we also append norm
                held_layers.append(module.norm)
        else:
            layers_per_stage = stage_manager.distribute_layers(len(module.layers))
            if stage_manager.is_first_stage():
                held_layers.append(module.embed_tokens)
            start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
            held_layers.extend(module.layers[start_idx:end_idx])
            if stage_manager.is_last_stage():
                held_layers.append(module.norm)

        return held_layers


class DeepseekV3ModelPolicy(DeepseekV3Policy):
    def module_policy(self):
        policy = super().module_policy()
        if self.shard_config.pipeline_stage_manager:
            self.set_pipeline_forward("DeepseekV3Model", deepseek_v3_model_forward, policy)
        return policy


class DeepseekV3ForCausalLMPolicy(DeepseekV3Policy):
    def module_policy(self):
        policy = super().module_policy()
        if self.shard_config.pipeline_stage_manager:
            self.set_pipeline_forward("DeepseekV3ForCausalLM", deepseek_v3_for_causal_lm_forward, policy)
        return policy

    def get_held_layers(self):
        stage_manager = self.pipeline_stage_manager
        held_layers = super().get_held_layers()
        # under ZBV the last model chunk runs on the first stage, so lm_head is held there;
        # otherwise the last pipeline stage owns lm_head
        if stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True):
            held_layers.append(self.model.lm_head)
        elif stage_manager.is_last_stage(ignore_chunk=True):
            held_layers.append(self.model.lm_head)
        return held_layers
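

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the upstream file): ShardFormer
# normally resolves a model's policy automatically, but a policy instance can
# also be passed explicitly to `ShardFormer.optimize`. The config values below
# are assumptions chosen for illustration; DeepseekV3Policy itself requires
# tensor and sequence parallelism to be disabled.
#
#   from colossalai.shardformer import ShardConfig, ShardFormer
#
#   shard_config = ShardConfig(
#       enable_tensor_parallelism=False,   # asserted off in config_sanity_check
#       enable_fused_normalization=True,   # enables the FusedRMSNorm replacement
#   )
#   shard_former = ShardFormer(shard_config=shard_config)
#   model, shared_params = shard_former.optimize(model, policy=DeepseekV3ForCausalLMPolicy())
# ---------------------------------------------------------------------------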