# ColossalAI/colossalai/shardformer/policies/deepseek_v3.py
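"""Shardformer policies for DeepSeek-V3.

``DeepseekV3Policy`` is the shared base: it asserts that tensor parallelism, sequence
parallelism and ZBV are not enabled (see ``config_sanity_check``), enables expert
parallelism for the MoE blocks via ``EpDeepseekV3MoE``, optionally fuses RMSNorm, and
computes which layers each pipeline stage holds. ``DeepseekV3ModelPolicy`` and
``DeepseekV3ForCausalLMPolicy`` additionally install the pipeline-parallel forward
methods for their respective model classes.
"""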

from functools import partial
from typing import Callable, Dict, List, Union
import torch.nn as nn
from colossalai.shardformer.layer import FusedRMSNorm
from colossalai.shardformer.modeling.deepseek_v3 import (
    EpDeepseekV3MoE,
    deepseek_v3_for_causal_lm_forward,
    deepseek_v3_model_forward,
)
from colossalai.shardformer.policies.base_policy import (
    ModulePolicyDescription,
    Policy,
    SubModuleReplacementDescription,
)

__all__ = ["DeepseekV3Policy", "DeepseekV3ModelPolicy", "DeepseekV3ForCausalLMPolicy"]


class DeepseekV3Policy(Policy):
    def config_sanity_check(self):
        assert not self.shard_config.enable_tensor_parallelism, "DeepSeekV3 does not support tensor parallelism"
        assert not self.shard_config.enable_sequence_parallelism, "DeepSeekV3 does not support sequence parallelism"
        if self.shard_config.pipeline_stage_manager:
            assert not self.shard_config.pipeline_stage_manager.use_zbv, "DeepSeekV3 does not support ZBV"

    def preprocess(self):
        return self.model

    def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
        policy = {}

        # support gradient checkpointing
        if self.shard_config.pipeline_stage_manager is None:
            policy["DeepseekV3Model"] = ModulePolicyDescription(
                method_replacement={"forward": deepseek_v3_model_forward}
            )
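
        # Note: EpDeepseekV3MoE (from colossalai.shardformer.modeling.deepseek_v3) swaps
        # each decoder layer's `mlp` for an expert-parallel MoE implementation; the experts
        # are presumably partitioned across the ranks of `ep_group`, with `moe_dp_group`
        # serving as the data-parallel group for the expert parameters (inferred from the
        # kwargs below, not documented in this file).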
        if self.shard_config.expert_parallel_size > 1:
            # expert parallel
            self.append_or_create_submodule_replacement(
                description=[
                    SubModuleReplacementDescription(
                        suffix="mlp",
                        target_module=EpDeepseekV3MoE,
                        kwargs={
                            "ep_group": self.shard_config.ep_group,
                            "moe_dp_group": self.shard_config.moe_dp_group,
                        },
                    )
                ],
                policy=policy,
                target_key="DeepseekV3DecoderLayer",
            )

        # optimization configuration
        if self.shard_config.enable_fused_normalization:
            # TODO: prevent casting to fp32
            self.append_or_create_submodule_replacement(
                description=[
                    SubModuleReplacementDescription(
                        suffix="input_layernorm",
                        target_module=FusedRMSNorm,
                    ),
                    SubModuleReplacementDescription(
                        suffix="post_attention_layernorm",
                        target_module=FusedRMSNorm,
                    ),
                ],
                policy=policy,
                target_key="DeepseekV3DecoderLayer",
            )
            self.append_or_create_submodule_replacement(
                description=SubModuleReplacementDescription(
                    suffix="norm",
                    target_module=FusedRMSNorm,
                ),
                policy=policy,
                target_key="DeepseekV3Model",
            )

        return policy

    def postprocess(self):
        return self.model

    def set_pipeline_forward(self, model_cls: str, new_forward: Callable, policy: Dict) -> None:
        """If under a pipeline-parallel setting, replace the original HuggingFace forward
        method with the customized forward method and record this change in the policy."""
        if self.pipeline_stage_manager:
            num_layers = self.model.config.num_hidden_layers
            stage_manager = self.pipeline_stage_manager
            layers_per_stage = stage_manager.distribute_layers(num_layers)
            stage_index = stage_manager.get_stage_index(layers_per_stage)
            method_replacement = {
                "forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)
            }
            self.append_or_create_method_replacement(
                description=method_replacement, policy=policy, target_key=model_cls
            )
        return
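
    # Illustrative example (assumed numbers, not taken from this file): with
    # `num_hidden_layers=61` and 4 pipeline stages, `distribute_layers(61)` could yield
    # a per-stage layer count such as [16, 15, 15, 15]; `get_stage_index(...)` then
    # returns the (start, end) slice of decoder layers owned by the current stage,
    # e.g. (16, 31) for stage 1, and the bound forward only runs that slice.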

    def get_held_layers(self) -> List[nn.Module]:
        """Get pipeline layers for current stage."""
        assert self.pipeline_stage_manager is not None

        module = self.model
        if module.__class__.__name__.startswith("PeftModel"):
            module = module.get_base_model()
        if module.__class__.__name__ != "DeepseekV3Model":
            module = module.model
        stage_manager = self.pipeline_stage_manager
        held_layers = []

        if stage_manager.is_interleave:
            assert stage_manager.num_model_chunks is not None
            layers_per_stage = stage_manager.distribute_layers(len(module.layers))
            stage_indices = stage_manager.get_stage_index(layers_per_stage)
            stage_manager.stage_indices = stage_indices
            if stage_manager.is_first_stage(ignore_chunk=True):
                held_layers.append(module.embed_tokens)
            for start_idx, end_idx in stage_indices:
                held_layers.extend(module.layers[start_idx:end_idx])
            if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
                not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
            ):
                # for zbv, when is_first_stage (last fwd), we append norm
                # for interleaved, when is_last_stage (last fwd), we also append norm
                held_layers.append(module.norm)
        else:
            layers_per_stage = stage_manager.distribute_layers(len(module.layers))
            if stage_manager.is_first_stage():
                held_layers.append(module.embed_tokens)
            start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
            held_layers.extend(module.layers[start_idx:end_idx])
            if stage_manager.is_last_stage():
                held_layers.append(module.norm)

        return held_layers


class DeepseekV3ModelPolicy(DeepseekV3Policy):
    def module_policy(self):
        policy = super().module_policy()
        if self.shard_config.pipeline_stage_manager:
            self.set_pipeline_forward("DeepseekV3Model", deepseek_v3_model_forward, policy)
        return policy


class DeepseekV3ForCausalLMPolicy(DeepseekV3Policy):
    def module_policy(self):
        policy = super().module_policy()
        if self.shard_config.pipeline_stage_manager:
            self.set_pipeline_forward("DeepseekV3ForCausalLM", deepseek_v3_for_causal_lm_forward, policy)
        return policy

    def get_held_layers(self):
        stage_manager = self.pipeline_stage_manager
        held_layers = super().get_held_layers()
        if stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True):
            held_layers.append(self.model.lm_head)
        elif stage_manager.is_last_stage(ignore_chunk=True):
            held_layers.append(self.model.lm_head)
        return held_layers
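

# A minimal usage sketch (illustrative; the plugin/booster workflow below is an assumption
# based on the wider ColossalAI API and is not defined in this file):
#
#     import colossalai
#     from colossalai.booster import Booster
#     from colossalai.booster.plugin import MoeHybridParallelPlugin
#
#     colossalai.launch_from_torch()
#     plugin = MoeHybridParallelPlugin(tp_size=1, pp_size=2, ep_size=2)
#     booster = Booster(plugin=plugin)
#     model, optimizer, *_ = booster.boost(model, optimizer)
#
# During boosting, shardformer resolves the policy from the model class name
# (e.g. DeepseekV3ForCausalLM -> DeepseekV3ForCausalLMPolicy) and applies the
# module/method replacements declared above.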