import warnings
from functools import partial
from typing import Callable, Dict, List, Union

import torch.nn as nn
from torch import Tensor
from torch.nn import Module

from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col
from colossalai.shardformer.layer.linear import Linear1D_Row
from colossalai.shardformer.modeling.deepseek import DeepseekPipelineForwards, EPDeepseekMoE
from colossalai.shardformer.policies.base_policy import (
    ModulePolicyDescription,
    Policy,
    SubModuleReplacementDescription,
)

__all__ = ["DeepseekPolicy", "DeepseekModelPolicy", "DeepseekForCausalLMPolicy"]


class DeepseekPolicy(Policy):
    def config_sanity_check(self):
        pass

    def preprocess(self):
        if self.shard_config.enable_tensor_parallelism:
            # Resize the embedding so the vocabulary size is divisible by the tensor
            # parallel size; otherwise the embedding cannot be sharded evenly.
            vocab_size = self.model.config.vocab_size
            world_size = self.shard_config.tensor_parallel_size
            if vocab_size % world_size != 0:
                new_vocab_size = vocab_size + world_size - vocab_size % world_size
                self.model.resize_token_embeddings(new_vocab_size)
        return self.model
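    # A worked example of the padding above, with hypothetical numbers: for
    # vocab_size=102_399 and tensor_parallel_size=8, 102_399 % 8 == 7, so the
    # embedding is resized to 102_399 + 8 - 7 == 102_400, giving each TP rank
    # an equal 12_800-row shard of the vocabulary embedding.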
    def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
        policy = {}

        if self.shard_config.enable_sequence_parallelism:
            if self.pipeline_stage_manager is not None:
                # NOTE: we replace the model forward for both sequence parallelism and
                # pipeline parallelism; if both are enabled, one of them will be ignored
                raise NotImplementedError("Sequence parallelism is not supported with pipeline parallelism.")
            raise NotImplementedError("Deepseek doesn't support sequence parallelism yet.")

        if self.shard_config.enable_tensor_parallelism:
            # tensor parallelism for non-MoE params
            assert (
                self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
            ), "The number of attention heads must be divisible by the tensor parallel size."
            assert (
                self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0
            ), "The number of key_value heads must be divisible by the tensor parallel size."
            decoder_attribute_replacement = {
                "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
                "self_attn.num_key_value_heads": self.model.config.num_key_value_heads
                // self.shard_config.tensor_parallel_size,
            }
            policy["DeepseekDecoderLayer"] = ModulePolicyDescription(
                attribute_replacement=decoder_attribute_replacement,
                sub_module_replacement=[
                    SubModuleReplacementDescription(
                        suffix="self_attn.q_proj",
                        target_module=Linear1D_Col,
                    ),
                    SubModuleReplacementDescription(
                        suffix="self_attn.k_proj",
                        target_module=Linear1D_Col,
                    ),
                    SubModuleReplacementDescription(
                        suffix="self_attn.v_proj",
                        target_module=Linear1D_Col,
                    ),
                    SubModuleReplacementDescription(
                        suffix="self_attn.o_proj",
                        target_module=Linear1D_Row,
                    ),
                ],
            )

        if self.shard_config.ep_group:
            # expert parallelism: replace each dense MoE block with its expert-parallel counterpart
            self.append_or_create_submodule_replacement(
                description=[
                    SubModuleReplacementDescription(
                        suffix="mlp",
                        target_module=EPDeepseekMoE,
                        kwargs={
                            "ep_group": self.shard_config.ep_group,
                            "tp_group": self.shard_config.tensor_parallel_process_group,
                            "moe_dp_group": self.shard_config.moe_dp_group,
                            "moe_tp_group": self.shard_config.moe_tp_group,
                        },
                    )
                ],
                policy=policy,
                target_key="DeepseekDecoderLayer",
            )

        # optimization configuration
        if self.shard_config.enable_fused_normalization:
            self.append_or_create_submodule_replacement(
                description=[
                    SubModuleReplacementDescription(
                        suffix="input_layernorm",
                        target_module=FusedRMSNorm,
                    ),
                    SubModuleReplacementDescription(
                        suffix="post_attention_layernorm",
                        target_module=FusedRMSNorm,
                    ),
                ],
                policy=policy,
                target_key="DeepseekDecoderLayer",
            )
            self.append_or_create_submodule_replacement(
                description=SubModuleReplacementDescription(
                    suffix="norm",
                    target_module=FusedRMSNorm,
                ),
                policy=policy,
                target_key="DeepseekModel",
            )

        if self.shard_config.enable_flash_attention:
            warnings.warn(
                "Flash attention is already integrated in the Deepseek modeling code; "
                "setting enable_flash_attention = False."
            )
            self.shard_config.enable_flash_attention = False

        return policy
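    # Illustration of the attribute replacement above, with hypothetical numbers: for
    # hidden_size=4096, num_attention_heads=32, num_key_value_heads=32 and
    # tensor_parallel_size=4, each rank's attention module sees hidden_size=1024,
    # num_heads=8 and num_key_value_heads=8, matching the column-sharded q/k/v
    # projections (Linear1D_Col) and the row-sharded output projection (Linear1D_Row).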
    def postprocess(self):
        return self.model

    def set_pipeline_forward(self, model_cls: str, new_forward: Callable, policy: Dict) -> None:
        """Under a pipeline-parallel setting, replace the original Hugging Face forward
        method with a customized forward method and record the change in the policy."""
        if self.pipeline_stage_manager:
            if self.shard_config.enable_sequence_parallelism:
                # NOTE: we replace the model forward for both sequence parallelism and
                # pipeline parallelism; if both are enabled, one of them will be ignored
                raise NotImplementedError("Pipeline parallelism is not supported with sequence parallelism.")
            stage_manager = self.pipeline_stage_manager
            if self.model.__class__.__name__ == "DeepseekModel":
                module = self.model
            else:
                module = self.model.model

            layers_per_stage = stage_manager.distribute_layers(len(module.layers))
            stage_index = stage_manager.get_stage_index(layers_per_stage)
            method_replacement = {
                "forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)
            }
            self.append_or_create_method_replacement(
                description=method_replacement, policy=policy, target_key=model_cls
            )

    def get_held_layers(self) -> List[Module]:
        """Get pipeline layers for the current stage."""
        assert self.pipeline_stage_manager is not None

        if self.model.__class__.__name__ == "DeepseekModel":
            module = self.model
        else:
            module = self.model.model
        stage_manager = self.pipeline_stage_manager

        held_layers = []
        layers_per_stage = stage_manager.distribute_layers(len(module.layers))
        if stage_manager.is_first_stage():
            held_layers.append(module.embed_tokens)
        start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
        held_layers.extend(module.layers[start_idx:end_idx])
        if stage_manager.is_last_stage():
            held_layers.append(module.norm)

        return held_layers


class DeepseekModelPolicy(DeepseekPolicy):
    def __init__(self) -> None:
        super().__init__()

    def module_policy(self):
        policy = super().module_policy()
        if self.pipeline_stage_manager:
            self.set_pipeline_forward(
                model_cls="DeepseekModel",
                new_forward=DeepseekPipelineForwards.deepseek_model_forward,
                policy=policy,
            )
        return policy

    def get_held_layers(self) -> List[Module]:
        """Get pipeline layers for the current stage."""
        held_layers = super().get_held_layers()
        return held_layers

    def get_shared_params(self) -> List[Dict[int, Tensor]]:
        """No shared params in the Deepseek model."""
        return []


class DeepseekForCausalLMPolicy(DeepseekPolicy):
    def module_policy(self):
        policy = super().module_policy()
        # TODO: assign pg mesh from plugin to all modules
        if self.shard_config.enable_tensor_parallelism:
            # add a new item for causal lm
            new_item = {
                "DeepseekForCausalLM": ModulePolicyDescription(
                    sub_module_replacement=[
                        SubModuleReplacementDescription(
                            suffix="lm_head",
                            target_module=Linear1D_Col,
                            kwargs=dict(gather_output=True),
                        )
                    ]
                )
            }
            policy.update(new_item)

        if self.pipeline_stage_manager:
            self.set_pipeline_forward(
                model_cls="DeepseekForCausalLM",
                new_forward=DeepseekPipelineForwards.deepseek_for_causal_lm_forward,
                policy=policy,
            )

        return policy

    def get_held_layers(self) -> List[Module]:
        """Get pipeline layers for the current stage."""
        stage_manager = self.pipeline_stage_manager
        held_layers = super().get_held_layers()
        if stage_manager.is_last_stage():
            held_layers.append(self.model.lm_head)
        return held_layers
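    # Example of the stage layout produced by get_held_layers above, with hypothetical
    # numbers: for 28 decoder layers and 4 pipeline stages, distribute_layers assigns
    # 7 layers per stage; stage 0 additionally holds embed_tokens, and stage 3 holds
    # the final norm plus, for this causal-LM policy, lm_head.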
    def get_shared_params(self) -> List[Dict[int, Tensor]]:
        deepseek_model = self.model.model
        if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
            if id(deepseek_model.embed_tokens.weight) == id(self.model.lm_head.weight):
                # tie weights: the embedding on the first stage and the lm_head on the
                # last stage share the same parameter
                return [
                    {
                        0: deepseek_model.embed_tokens.weight,
                        self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight,
                    }
                ]
        return []
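
# Usage sketch (a minimal example, not part of the policy). It assumes a standard
# ColossalAI ShardFormer setup with an already-initialized distributed environment;
# `tp_group` and `model` are hypothetical, user-provided objects.
#
#   from colossalai.shardformer import ShardConfig, ShardFormer
#
#   shard_config = ShardConfig(
#       tensor_parallel_process_group=tp_group,
#       enable_tensor_parallelism=True,
#       enable_fused_normalization=True,
#   )
#   shardformer = ShardFormer(shard_config=shard_config)
#   model, shared_params = shardformer.optimize(model, policy=DeepseekForCausalLMPolicy())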