[test] add mixtral for sequence classification

moe_sp
hxwang 5 months ago
parent f585d4e38e
commit 229db4bc16

@@ -200,6 +200,9 @@ _POLICY_LIST = {
     "transformers.models.mixtral.modeling_mixtral.MixtralForCausalLM": PolicyLocation(
         file_name="mixtral", class_name="MixtralForCausalLMPolicy"
     ),
+    "transformers.models.mixtral.modeling_mixtral.MixtralForSequenceClassification": PolicyLocation(
+        file_name="mixtral", class_name="MixtralForSequenceClassificationPolicy"
+    ),
     # Qwen2
     "transformers.models.qwen2.modeling_qwen2.Qwen2Model": PolicyLocation(
         file_name="qwen2", class_name="Qwen2ModelPolicy"

@@ -1,3 +1,4 @@
+import warnings
 from functools import partial
 from typing import Callable, Dict, List, Union
@@ -39,20 +40,81 @@ class MixtralPolicy(Policy):
             )
         if self.shard_config.enable_tensor_parallelism:
-            raise NotImplementedError("Tensor parallelism is not supported for Mixtral model now.")
-        if getattr(self.shard_config, "ep_group", None) is not None:
-            # expert parallel
-            self.append_or_create_submodule_replacement(
-                description=[
-                    SubModuleReplacementDescription(
-                        suffix="block_sparse_moe",
-                        target_module=EPMixtralSparseMoeBlock,
-                        kwargs={"ep_group": self.shard_config.ep_group},
-                    )
-                ],
-                policy=policy,
-                target_key=MixtralDecoderLayer,
-            )
+            raise NotImplementedError
+            # assert (
+            #     self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
+            # ), f"The number of attention heads must be divisible by tensor parallel size."
+            # assert (
+            #     self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0
+            # ), f"The number of key_value heads must be divisible by tensor parallel size."
+            # decoder_attribute_replacement = {
+            #     "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
+            #     "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
+            #     "self_attn.num_key_value_heads": self.model.config.num_key_value_heads
+            #     // self.shard_config.tensor_parallel_size,
+            # }
+            # policy[MixtralDecoderLayer] = ModulePolicyDescription(
+            #     attribute_replacement=decoder_attribute_replacement,
+            #     sub_module_replacement=[
+            #         SubModuleReplacementDescription(
+            #             suffix="self_attn.q_proj",
+            #             target_module=Linear1D_Col,
+            #             kwargs={
+            #                 'process_group': self.shard_config.tensor_parallel_process_group,
+            #             }
+            #         ),
+            #         SubModuleReplacementDescription(
+            #             suffix="self_attn.k_proj",
+            #             target_module=Linear1D_Col,
+            #             kwargs={
+            #                 'process_group': self.shard_config.tensor_parallel_process_group,
+            #             }
+            #         ),
+            #         SubModuleReplacementDescription(
+            #             suffix="self_attn.v_proj",
+            #             target_module=Linear1D_Col,
+            #             kwargs={
+            #                 'process_group': self.shard_config.tensor_parallel_process_group,
+            #             }
+            #         ),
+            #         SubModuleReplacementDescription(
+            #             suffix="self_attn.o_proj",
+            #             target_module=Linear1D_Row,
+            #             kwargs={
+            #                 'process_group': self.shard_config.tensor_parallel_process_group,
+            #             }
+            #         ),
+            #         # SubModuleReplacementDescription(
+            #         #     suffix="mlp.gate_proj",
+            #         #     target_module=Linear1D_Col,
+            #         # ),
+            #         # SubModuleReplacementDescription(
+            #         #     suffix="mlp.up_proj",
+            #         #     target_module=Linear1D_Col,
+            #         # ),
+            #         # SubModuleReplacementDescription(
+            #         #     suffix="mlp.down_proj",
+            #         #     target_module=Linear1D_Row,
+            #         # ),
+            #     ],
+            # )
+        if getattr(self.shard_config, "ep_group", None) is None:
+            raise ValueError("You must pass in ep_group via shard_config for expert parallel!")
+        # expert parallel
+        self.append_or_create_submodule_replacement(
+            description=[
+                SubModuleReplacementDescription(
+                    suffix="block_sparse_moe",
+                    target_module=EPMixtralSparseMoeBlock,
+                    kwargs={"ep_group": self.shard_config.ep_group},
+                )
+            ],
+            policy=policy,
+            target_key=MixtralDecoderLayer,
+        )
         # optimization configuration
         if self.shard_config.enable_fused_normalization:
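The block_sparse_moe replacement above is what actually enables expert parallelism: each rank in ep_group keeps only its slice of the experts. A rough sketch of that partitioning idea follows; the helper local_expert_indices is illustrative only, and how EPMixtralSparseMoeBlock really assigns experts may differ:

import torch.distributed as dist

def local_expert_indices(num_experts: int, ep_group) -> list:
    # Evenly split expert ids across the expert-parallel group, e.g. 8 experts
    # over ep_size=4 gives rank 0 -> [0, 1], rank 1 -> [2, 3], and so on.
    ep_size = dist.get_world_size(group=ep_group)
    ep_rank = dist.get_rank(group=ep_group)
    assert num_experts % ep_size == 0, "experts must divide evenly across ep_group"
    per_rank = num_experts // ep_size
    return list(range(ep_rank * per_rank, (ep_rank + 1) * per_rank))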
@@ -81,7 +143,7 @@ class MixtralPolicy(Policy):
             )
         if self.shard_config.enable_flash_attention:
-            raise NotImplementedError("Flash attention has already been replaced in mixtral.")
+            warnings.warn("Flash attention is natively supported in transformers, will ignore the flag.")
         return policy
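Because module_policy() now raises a ValueError unless ep_group is present on the shard config (it is read via getattr), a caller only has to make sure the attribute exists. A hedged usage sketch, assuming torch.distributed is already initialized; the real MoE plugin may wire ep_group differently, and the ShardConfig fields shown are assumptions about its defaults:

import torch.distributed as dist
from colossalai.shardformer import ShardConfig

# illustrative only: reuse the default process group as the expert-parallel group
ep_group = dist.group.WORLD
shard_config = ShardConfig(enable_tensor_parallelism=False, enable_fused_normalization=True)
shard_config.ep_group = ep_group  # without this attribute, MixtralPolicy.module_policy() raises ValueError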
@@ -150,7 +212,7 @@ class MixtralModelPolicy(MixtralPolicy):
         return held_layers
 
     def get_shared_params(self) -> List[Dict[int, Tensor]]:
-        """No shared params in llama model"""
+        """No shared params in mixtral model"""
         return []
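For contrast, when weights are tied across pipeline stages (as with input embeddings and the LM head in the causal-LM case), get_shared_params returns one dict per tied group, mapping a stage index to the tensor held by that stage. A hedged sketch of that shape; the stage indices and attribute paths are illustrative:

from typing import Dict, List
from torch import Tensor, nn

def shared_params_example(model: nn.Module, first_stage: int, last_stage: int) -> List[Dict[int, Tensor]]:
    # one entry per group of tied weights: {stage_index: tensor living on that stage}
    return [
        {
            first_stage: model.model.embed_tokens.weight,
            last_stage: model.lm_head.weight,
        }
    ]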
@@ -206,3 +268,40 @@ class MixtralForCausalLMPolicy(MixtralPolicy):
                 }
             ]
         return []
+
+
+class MixtralForSequenceClassificationPolicy(MixtralPolicy):
+    def module_policy(self):
+        from transformers import MixtralForSequenceClassification
+
+        policy = super().module_policy()
+
+        if self.shard_config.enable_tensor_parallelism:
+            # add a new item for sequence classification
+            new_item = {
+                MixtralForSequenceClassification: ModulePolicyDescription(
+                    sub_module_replacement=[
+                        SubModuleReplacementDescription(
+                            suffix="score", target_module=Linear1D_Col, kwargs=dict(gather_output=True)
+                        )
+                    ]
+                )
+            }
+            policy.update(new_item)
+
+        if self.pipeline_stage_manager:
+            raise NotImplementedError
+
+        return policy
+
+    def get_held_layers(self) -> List[Module]:
+        """Get pipeline layers for current stage."""
+        stage_manager = self.pipeline_stage_manager
+        held_layers = super().get_held_layers()
+        if stage_manager.is_last_stage(ignore_chunk=True):
+            held_layers.append(self.model.score)
+        return held_layers
+
+    def get_shared_params(self) -> List[Dict[int, Tensor]]:
+        """No shared params in mixtral for sequence classification model"""
+        return []
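The score head is wrapped in Linear1D_Col with gather_output=True so that, under tensor parallelism, every rank ends up with the full (batch, num_labels) logits rather than a shard of them. A toy illustration of that semantic in plain PyTorch (not the actual Linear1D_Col implementation):

import torch

# pretend world_size=2: each "rank" owns half of the num_labels output columns
full_weight = torch.randn(4, 16)             # (num_labels=4, hidden=16)
shards = full_weight.chunk(2, dim=0)          # column-parallel split of the output dimension
x = torch.randn(3, 16)                        # (batch, hidden)
partial_outputs = [x @ w.t() for w in shards]
logits = torch.cat(partial_outputs, dim=-1)   # "gather_output": reassemble the full logits
assert torch.allclose(logits, x @ full_weight.t())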
