diff --git a/colossalai/shardformer/policies/auto_policy.py b/colossalai/shardformer/policies/auto_policy.py
index ae9f3603c..1e0af031a 100644
--- a/colossalai/shardformer/policies/auto_policy.py
+++ b/colossalai/shardformer/policies/auto_policy.py
@@ -200,6 +200,9 @@ _POLICY_LIST = {
     "transformers.models.mixtral.modeling_mixtral.MixtralForCausalLM": PolicyLocation(
         file_name="mixtral", class_name="MixtralForCausalLMPolicy"
     ),
+    "transformers.models.mixtral.modeling_mixtral.MixtralForSequenceClassification": PolicyLocation(
+        file_name="mixtral", class_name="MixtralForSequenceClassificationPolicy"
+    ),
     # Qwen2
     "transformers.models.qwen2.modeling_qwen2.Qwen2Model": PolicyLocation(
         file_name="qwen2", class_name="Qwen2ModelPolicy"
diff --git a/colossalai/shardformer/policies/mixtral.py b/colossalai/shardformer/policies/mixtral.py
index ad93e9469..e3cc48043 100644
--- a/colossalai/shardformer/policies/mixtral.py
+++ b/colossalai/shardformer/policies/mixtral.py
@@ -1,3 +1,4 @@
+import warnings
 from functools import partial
 from typing import Callable, Dict, List, Union
 
@@ -39,20 +40,81 @@ class MixtralPolicy(Policy):
             )
 
         if self.shard_config.enable_tensor_parallelism:
-            raise NotImplementedError("Tensor parallelism is not supported for Mixtral model now.")
-        if getattr(self.shard_config, "ep_group", None) is not None:
-            # expert parallel
-            self.append_or_create_submodule_replacement(
-                description=[
-                    SubModuleReplacementDescription(
-                        suffix="block_sparse_moe",
-                        target_module=EPMixtralSparseMoeBlock,
-                        kwargs={"ep_group": self.shard_config.ep_group},
-                    )
-                ],
-                policy=policy,
-                target_key=MixtralDecoderLayer,
-            )
+            raise NotImplementedError
+            # assert (
+            #     self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
+            # ), f"The number of attention heads must be divisible by tensor parallel size."
+            # assert (
+            #     self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0
+            # ), f"The number of key_value heads must be divisible by tensor parallel size."
+            # decoder_attribute_replacement = {
+            #     "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
+            #     "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
+            #     "self_attn.num_key_value_heads": self.model.config.num_key_value_heads
+            #     // self.shard_config.tensor_parallel_size,
+            # }
+
+            # policy[MixtralDecoderLayer] = ModulePolicyDescription(
+            #     attribute_replacement=decoder_attribute_replacement,
+            #     sub_module_replacement=[
+            #         SubModuleReplacementDescription(
+            #             suffix="self_attn.q_proj",
+            #             target_module=Linear1D_Col,
+            #             kwargs={
+            #                 'process_group': self.shard_config.tensor_parallel_process_group,
+            #             }
+            #         ),
+            #         SubModuleReplacementDescription(
+            #             suffix="self_attn.k_proj",
+            #             target_module=Linear1D_Col,
+            #             kwargs={
+            #                 'process_group': self.shard_config.tensor_parallel_process_group,
+            #             }
+            #         ),
+            #         SubModuleReplacementDescription(
+            #             suffix="self_attn.v_proj",
+            #             target_module=Linear1D_Col,
+            #             kwargs={
+            #                 'process_group': self.shard_config.tensor_parallel_process_group,
+            #             }
+            #         ),
+            #         SubModuleReplacementDescription(
+            #             suffix="self_attn.o_proj",
+            #             target_module=Linear1D_Row,
+            #             kwargs={
+            #                 'process_group': self.shard_config.tensor_parallel_process_group,
+            #             }
+            #         ),
+            #         # SubModuleReplacementDescription(
+            #         #     suffix="mlp.gate_proj",
+            #         #     target_module=Linear1D_Col,
+            #         # ),
+            #         # SubModuleReplacementDescription(
+            #         #     suffix="mlp.up_proj",
+            #         #     target_module=Linear1D_Col,
+            #         # ),
+            #         # SubModuleReplacementDescription(
+            #         #     suffix="mlp.down_proj",
+            #         #     target_module=Linear1D_Row,
+            #         # ),
+            #     ],
+            # )
+
+        if getattr(self.shard_config, "ep_group", None) is None:
+            raise ValueError("You must pass in ep_group via shard_config for expert parallel!")
+
+        # expert parallel
+        self.append_or_create_submodule_replacement(
+            description=[
+                SubModuleReplacementDescription(
+                    suffix="block_sparse_moe",
+                    target_module=EPMixtralSparseMoeBlock,
+                    kwargs={"ep_group": self.shard_config.ep_group},
+                )
+            ],
+            policy=policy,
+            target_key=MixtralDecoderLayer,
+        )
 
         # optimization configuration
         if self.shard_config.enable_fused_normalization:
@@ -81,7 +143,7 @@ class MixtralPolicy(Policy):
             )
 
         if self.shard_config.enable_flash_attention:
-            raise NotImplementedError("Flash attention has already been replaced in mixtral.")
+            warnings.warn("Flash attention is natively supported in transformers, will ignore the flag.")
 
         return policy
 
@@ -150,7 +212,7 @@ class MixtralModelPolicy(MixtralPolicy):
         return held_layers
 
     def get_shared_params(self) -> List[Dict[int, Tensor]]:
-        """No shared params in llama model"""
+        """No shared params in mixtral model"""
         return []
 
 
@@ -206,3 +268,40 @@ class MixtralForCausalLMPolicy(MixtralPolicy):
                     }
                 ]
         return []
+
+
+class MixtralForSequenceClassificationPolicy(MixtralPolicy):
+    def module_policy(self):
+        from transformers import MixtralForSequenceClassification
+
+        policy = super().module_policy()
+
+        if self.shard_config.enable_tensor_parallelism:
+            # add a new item for sequence classification
+            new_item = {
+                MixtralForSequenceClassification: ModulePolicyDescription(
+                    sub_module_replacement=[
+                        SubModuleReplacementDescription(
+                            suffix="score", target_module=Linear1D_Col, kwargs=dict(gather_output=True)
+                        )
+                    ]
+                )
+            }
+            policy.update(new_item)
+
+        if self.pipeline_stage_manager:
+            raise NotImplementedError
+
+        return policy
+
+    def get_held_layers(self) -> List[Module]:
+        """Get pipeline layers for current stage."""
+        stage_manager = self.pipeline_stage_manager
+        held_layers = super().get_held_layers()
+        if stage_manager.is_last_stage(ignore_chunk=True):
+            held_layers.append(self.model.score)
+        return held_layers
+
+    def get_shared_params(self) -> List[Dict[int, Tensor]]:
+        """No shared params in mixtral for sequence classification model"""
+        return []
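
For reference, a minimal sketch of how the new auto-policy entry is expected to be picked up. This is illustrative only and not part of the patch: the tiny MixtralConfig values are arbitrary, and get_autopolicy is assumed to take the model instance, as in the current auto_policy.py layout.

from transformers import MixtralConfig, MixtralForSequenceClassification

from colossalai.shardformer.policies.auto_policy import get_autopolicy

# Small, arbitrary config so the model builds quickly; only the class lookup matters here.
config = MixtralConfig(
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=2,
    num_local_experts=4,
    num_experts_per_tok=2,
    num_labels=2,
)
model = MixtralForSequenceClassification(config)

# With the new _POLICY_LIST entry, this lookup should resolve to
# MixtralForSequenceClassificationPolicy.
policy = get_autopolicy(model)
print(type(policy).__name__)

Note that actually sharding the model still requires shard_config.ep_group to be set, since module_policy() now raises a ValueError when it is missing.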