mirror of https://github.com/hpcaitech/ColossalAI
[test] add mixtral for sequence classification
parent: f585d4e38e
commit: 229db4bc16
@@ -200,6 +200,9 @@ _POLICY_LIST = {
     "transformers.models.mixtral.modeling_mixtral.MixtralForCausalLM": PolicyLocation(
         file_name="mixtral", class_name="MixtralForCausalLMPolicy"
     ),
+    "transformers.models.mixtral.modeling_mixtral.MixtralForSequenceClassification": PolicyLocation(
+        file_name="mixtral", class_name="MixtralForSequenceClassificationPolicy"
+    ),
     # Qwen2
     "transformers.models.qwen2.modeling_qwen2.Qwen2Model": PolicyLocation(
         file_name="qwen2", class_name="Qwen2ModelPolicy"
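
For context on the hunk above: `_POLICY_LIST` keys are fully qualified transformers class names, and each `PolicyLocation` records which policy module and class shard that model. Below is a minimal sketch of how such a registry could be resolved, assuming `PolicyLocation` behaves like a simple record and that policy modules live under `colossalai.shardformer.policies.<file_name>`; `resolve_policy` is an illustrative helper, not ColossalAI's actual API.

import importlib
from dataclasses import dataclass


# Simplified stand-ins for illustration; the real PolicyLocation and
# _POLICY_LIST live in ColossalAI's shardformer policy registry.
@dataclass
class PolicyLocation:
    file_name: str
    class_name: str


_POLICY_LIST = {
    "transformers.models.mixtral.modeling_mixtral.MixtralForSequenceClassification": PolicyLocation(
        file_name="mixtral", class_name="MixtralForSequenceClassificationPolicy"
    ),
}


def resolve_policy(model):
    """Look up a policy by the model's fully qualified class name (assumed lookup scheme)."""
    key = f"{model.__class__.__module__}.{model.__class__.__qualname__}"
    location = _POLICY_LIST.get(key)
    if location is None:
        raise ValueError(f"No shardformer policy registered for {key}")
    # Assumed package layout: colossalai/shardformer/policies/<file_name>.py
    module = importlib.import_module(f"colossalai.shardformer.policies.{location.file_name}")
    return getattr(module, location.class_name)()
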
@@ -1,3 +1,4 @@
+import warnings
 from functools import partial
 from typing import Callable, Dict, List, Union
 
@@ -39,20 +40,81 @@ class MixtralPolicy(Policy):
             )
 
         if self.shard_config.enable_tensor_parallelism:
-            raise NotImplementedError("Tensor parallelism is not supported for Mixtral model now.")
-        if getattr(self.shard_config, "ep_group", None) is not None:
-            # expert parallel
-            self.append_or_create_submodule_replacement(
-                description=[
-                    SubModuleReplacementDescription(
-                        suffix="block_sparse_moe",
-                        target_module=EPMixtralSparseMoeBlock,
-                        kwargs={"ep_group": self.shard_config.ep_group},
-                    )
-                ],
-                policy=policy,
-                target_key=MixtralDecoderLayer,
-            )
+            raise NotImplementedError
+            # assert (
+            #     self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
+            # ), f"The number of attention heads must be divisible by tensor parallel size."
+            # assert (
+            #     self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0
+            # ), f"The number of key_value heads must be divisible by tensor parallel size."
+            # decoder_attribute_replacement = {
+            #     "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
+            #     "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
+            #     "self_attn.num_key_value_heads": self.model.config.num_key_value_heads
+            #     // self.shard_config.tensor_parallel_size,
+            # }
+
+            # policy[MixtralDecoderLayer] = ModulePolicyDescription(
+            #     attribute_replacement=decoder_attribute_replacement,
+            #     sub_module_replacement=[
+            #         SubModuleReplacementDescription(
+            #             suffix="self_attn.q_proj",
+            #             target_module=Linear1D_Col,
+            #             kwargs={
+            #                 'process_group': self.shard_config.tensor_parallel_process_group,
+            #             }
+            #         ),
+            #         SubModuleReplacementDescription(
+            #             suffix="self_attn.k_proj",
+            #             target_module=Linear1D_Col,
+            #             kwargs={
+            #                 'process_group': self.shard_config.tensor_parallel_process_group,
+            #             }
+            #         ),
+            #         SubModuleReplacementDescription(
+            #             suffix="self_attn.v_proj",
+            #             target_module=Linear1D_Col,
+            #             kwargs={
+            #                 'process_group': self.shard_config.tensor_parallel_process_group,
+            #             }
+            #         ),
+            #         SubModuleReplacementDescription(
+            #             suffix="self_attn.o_proj",
+            #             target_module=Linear1D_Row,
+            #             kwargs={
+            #                 'process_group': self.shard_config.tensor_parallel_process_group,
+            #             }
+            #         ),
+            #         # SubModuleReplacementDescription(
+            #         #     suffix="mlp.gate_proj",
+            #         #     target_module=Linear1D_Col,
+            #         # ),
+            #         # SubModuleReplacementDescription(
+            #         #     suffix="mlp.up_proj",
+            #         #     target_module=Linear1D_Col,
+            #         # ),
+            #         # SubModuleReplacementDescription(
+            #         #     suffix="mlp.down_proj",
+            #         #     target_module=Linear1D_Row,
+            #         # ),
+            #     ],
+            # )
+
+        if getattr(self.shard_config, "ep_group", None) is None:
+            raise ValueError("You must pass in ep_group via shard_config for expert parallel!")
+
+        # expert parallel
+        self.append_or_create_submodule_replacement(
+            description=[
+                SubModuleReplacementDescription(
+                    suffix="block_sparse_moe",
+                    target_module=EPMixtralSparseMoeBlock,
+                    kwargs={"ep_group": self.shard_config.ep_group},
+                )
+            ],
+            policy=policy,
+            target_key=MixtralDecoderLayer,
+        )
 
         # optimization configuration
         if self.shard_config.enable_fused_normalization:
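
The `EPMixtralSparseMoeBlock` replacement above swaps the stock MixtralSparseMoeBlock for an expert-parallel version: the experts of each layer are partitioned across the ranks of `ep_group`, so every rank stores and runs only its local slice of experts. A rough sketch of the partitioning arithmetic, assuming an even split (the helper below is illustrative only; token dispatch/combine across ranks is handled by the real EP block):

from typing import List


def local_expert_indices(num_experts: int, ep_size: int, ep_rank: int) -> List[int]:
    """Return the expert ids owned by one expert-parallel rank (even split assumed)."""
    assert num_experts % ep_size == 0, "num_experts must be divisible by the EP world size"
    experts_per_rank = num_experts // ep_size
    start = ep_rank * experts_per_rank
    return list(range(start, start + experts_per_rank))


# Mixtral-8x7B has 8 experts per layer; with an ep_group of size 4,
# each rank keeps 2 experts and exchanges routed tokens with the others.
for rank in range(4):
    print(rank, local_expert_indices(num_experts=8, ep_size=4, ep_rank=rank))

This is also why the rewritten `module_policy` now raises a ValueError when `shard_config` carries no `ep_group`: without an expert-parallel group there is nothing to partition the experts over.
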
@@ -81,7 +143,7 @@ class MixtralPolicy(Policy):
             )
 
         if self.shard_config.enable_flash_attention:
-            raise NotImplementedError("Flash attention has already been replaced in mixtral.")
+            warnings.warn("Flash attention is natively supported in transformers, will ignore the flag.")
 
         return policy
 
@@ -150,7 +212,7 @@ class MixtralModelPolicy(MixtralPolicy):
         return held_layers
 
     def get_shared_params(self) -> List[Dict[int, Tensor]]:
-        """No shared params in llama model"""
+        """No shared params in mixtral model"""
         return []
 
 
@@ -206,3 +268,40 @@ class MixtralForCausalLMPolicy(MixtralPolicy):
                 }
             ]
         return []
+
+
+class MixtralForSequenceClassificationPolicy(MixtralPolicy):
+    def module_policy(self):
+        from transformers import MixtralForSequenceClassification
+
+        policy = super().module_policy()
+
+        if self.shard_config.enable_tensor_parallelism:
+            # add a new item for sequence classification
+            new_item = {
+                MixtralForSequenceClassification: ModulePolicyDescription(
+                    sub_module_replacement=[
+                        SubModuleReplacementDescription(
+                            suffix="score", target_module=Linear1D_Col, kwargs=dict(gather_output=True)
+                        )
+                    ]
+                )
+            }
+            policy.update(new_item)
+
+        if self.pipeline_stage_manager:
+            raise NotImplementedError
+
+        return policy
+
+    def get_held_layers(self) -> List[Module]:
+        """Get pipeline layers for current stage."""
+        stage_manager = self.pipeline_stage_manager
+        held_layers = super().get_held_layers()
+        if stage_manager.is_last_stage(ignore_chunk=True):
+            held_layers.append(self.model.score)
+        return held_layers
+
+    def get_shared_params(self) -> List[Dict[int, Tensor]]:
+        """No shared params in mixtral for sequence classification model"""
+        return []
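
In the sequence-classification policy above, the `score` head is replaced by `Linear1D_Col` with `gather_output=True`: each tensor-parallel rank holds a column slice of the classifier weight, computes a slice of the logits, and the slices are gathered so every rank ends up with the full `[batch, num_labels]` output. A self-contained sketch of that idea in plain PyTorch, simulating the ranks with a loop instead of using ColossalAI's layer:

import torch

torch.manual_seed(0)
batch, hidden, num_labels, tp_size = 4, 16, 6, 2

x = torch.randn(batch, hidden)
full_score = torch.nn.Linear(hidden, num_labels, bias=False)

# Column parallelism: split the output (label) dimension across TP ranks.
weight_shards = full_score.weight.chunk(tp_size, dim=0)  # each shard: [num_labels/tp, hidden]
partial_logits = [x @ w.t() for w in weight_shards]      # each simulated rank's local slice

# gather_output=True: concatenate the slices so every rank sees all logits.
gathered = torch.cat(partial_logits, dim=-1)
assert torch.allclose(gathered, full_score(x), atol=1e-6)
print(gathered.shape)  # torch.Size([4, 6])

Without the gather, each rank would only see its local slice of the logits, which is why gathering the output of this small classification head is the sensible choice here.
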