mirror of https://github.com/hpcaitech/ColossalAI
[test] add mixtral for sequence classification
parent f585d4e38e
commit 229db4bc16
@@ -200,6 +200,9 @@ _POLICY_LIST = {
     "transformers.models.mixtral.modeling_mixtral.MixtralForCausalLM": PolicyLocation(
         file_name="mixtral", class_name="MixtralForCausalLMPolicy"
     ),
+    "transformers.models.mixtral.modeling_mixtral.MixtralForSequenceClassification": PolicyLocation(
+        file_name="mixtral", class_name="MixtralForSequenceClassificationPolicy"
+    ),
     # Qwen2
     "transformers.models.qwen2.modeling_qwen2.Qwen2Model": PolicyLocation(
         file_name="qwen2", class_name="Qwen2ModelPolicy"
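
For context: each PolicyLocation entry above maps a model's fully qualified class name to the shardformer policy module (file_name) and class (class_name) that shards it. The sketch below shows how such a registry can be resolved at runtime. It is illustrative only: resolve_policy is a hypothetical helper (ColossalAI's own lookup lives in its auto-policy module and may differ), and it assumes the policy files sit under colossalai.shardformer.policies.

    # Illustrative sketch only: resolve_policy() is hypothetical, not ColossalAI's actual API.
    import importlib
    from dataclasses import dataclass


    @dataclass
    class PolicyLocation:
        file_name: str   # policy module name under colossalai/shardformer/policies/
        class_name: str  # policy class inside that module


    _POLICY_LIST = {
        # same shape as the entry added in the hunk above
        "transformers.models.mixtral.modeling_mixtral.MixtralForSequenceClassification": PolicyLocation(
            file_name="mixtral", class_name="MixtralForSequenceClassificationPolicy"
        ),
    }


    def resolve_policy(model):
        # Build the fully qualified class name of the model instance ...
        key = f"{type(model).__module__}.{type(model).__qualname__}"
        location = _POLICY_LIST[key]
        # ... then import the matching policy module and instantiate the policy class.
        module = importlib.import_module(f"colossalai.shardformer.policies.{location.file_name}")
        return getattr(module, location.class_name)()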
@@ -1,3 +1,4 @@
+import warnings
 from functools import partial
 from typing import Callable, Dict, List, Union

@@ -39,20 +40,81 @@ class MixtralPolicy(Policy):
         )

         if self.shard_config.enable_tensor_parallelism:
-            raise NotImplementedError("Tensor parallelism is not supported for Mixtral model now.")
-        if getattr(self.shard_config, "ep_group", None) is not None:
-            # expert parallel
-            self.append_or_create_submodule_replacement(
-                description=[
-                    SubModuleReplacementDescription(
-                        suffix="block_sparse_moe",
-                        target_module=EPMixtralSparseMoeBlock,
-                        kwargs={"ep_group": self.shard_config.ep_group},
-                    )
-                ],
-                policy=policy,
-                target_key=MixtralDecoderLayer,
-            )
+            raise NotImplementedError
+            # assert (
+            #     self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
+            # ), f"The number of attention heads must be divisible by tensor parallel size."
+            # assert (
+            #     self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0
+            # ), f"The number of key_value heads must be divisible by tensor parallel size."
+            # decoder_attribute_replacement = {
+            #     "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
+            #     "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
+            #     "self_attn.num_key_value_heads": self.model.config.num_key_value_heads
+            #     // self.shard_config.tensor_parallel_size,
+            # }
+
+            # policy[MixtralDecoderLayer] = ModulePolicyDescription(
+            #     attribute_replacement=decoder_attribute_replacement,
+            #     sub_module_replacement=[
+            #         SubModuleReplacementDescription(
+            #             suffix="self_attn.q_proj",
+            #             target_module=Linear1D_Col,
+            #             kwargs={
+            #                 'process_group': self.shard_config.tensor_parallel_process_group,
+            #             }
+            #         ),
+            #         SubModuleReplacementDescription(
+            #             suffix="self_attn.k_proj",
+            #             target_module=Linear1D_Col,
+            #             kwargs={
+            #                 'process_group': self.shard_config.tensor_parallel_process_group,
+            #             }
+            #         ),
+            #         SubModuleReplacementDescription(
+            #             suffix="self_attn.v_proj",
+            #             target_module=Linear1D_Col,
+            #             kwargs={
+            #                 'process_group': self.shard_config.tensor_parallel_process_group,
+            #             }
+            #         ),
+            #         SubModuleReplacementDescription(
+            #             suffix="self_attn.o_proj",
+            #             target_module=Linear1D_Row,
+            #             kwargs={
+            #                 'process_group': self.shard_config.tensor_parallel_process_group,
+            #             }
+            #         ),
+            #         # SubModuleReplacementDescription(
+            #         #     suffix="mlp.gate_proj",
+            #         #     target_module=Linear1D_Col,
+            #         # ),
+            #         # SubModuleReplacementDescription(
+            #         #     suffix="mlp.up_proj",
+            #         #     target_module=Linear1D_Col,
+            #         # ),
+            #         # SubModuleReplacementDescription(
+            #         #     suffix="mlp.down_proj",
+            #         #     target_module=Linear1D_Row,
+            #         # ),
+            #     ],
+            # )
+
+        if getattr(self.shard_config, "ep_group", None) is None:
+            raise ValueError("You must pass in ep_group via shard_config for expert parallel!")
+
+        # expert parallel
+        self.append_or_create_submodule_replacement(
+            description=[
+                SubModuleReplacementDescription(
+                    suffix="block_sparse_moe",
+                    target_module=EPMixtralSparseMoeBlock,
+                    kwargs={"ep_group": self.shard_config.ep_group},
+                )
+            ],
+            policy=policy,
+            target_key=MixtralDecoderLayer,
+        )

         # optimization configuration
         if self.shard_config.enable_fused_normalization:
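
Note: with this version of MixtralPolicy.module_policy(), tensor parallelism is rejected outright and an expert-parallel process group is mandatory: if shard_config carries no ep_group, a ValueError is raised, and every MixtralDecoderLayer has its block_sparse_moe replaced by EPMixtralSparseMoeBlock. The rough sketch below attaches ep_group directly to a ShardConfig before sharding; that attachment and the tiny MixtralConfig values are assumptions made for illustration (ColossalAI's MoE plugin may wire the group differently).

    # Rough sketch, assuming ep_group can be attached directly to ShardConfig;
    # ColossalAI's MoE plugins may plumb this differently.
    import torch.distributed as dist
    from transformers import MixtralConfig, MixtralForCausalLM

    from colossalai.shardformer import ShardConfig, ShardFormer

    dist.init_process_group(backend="nccl")
    # One expert-parallel group spanning all ranks (illustrative choice).
    ep_group = dist.new_group(ranks=list(range(dist.get_world_size())))

    shard_config = ShardConfig(enable_tensor_parallelism=False)
    shard_config.ep_group = ep_group  # read via getattr(shard_config, "ep_group", None) in MixtralPolicy

    # Tiny, made-up config so the example stays cheap to build.
    config = MixtralConfig(
        hidden_size=128,
        intermediate_size=256,
        num_hidden_layers=2,
        num_attention_heads=4,
        num_key_value_heads=2,
        num_local_experts=4,
        num_experts_per_tok=2,
        vocab_size=1024,
    )
    model = MixtralForCausalLM(config)

    shard_former = ShardFormer(shard_config=shard_config)
    model, _ = shard_former.optimize(model)  # block_sparse_moe modules are now EPMixtralSparseMoeBlock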
@@ -81,7 +143,7 @@ class MixtralPolicy(Policy):
         )

         if self.shard_config.enable_flash_attention:
-            raise NotImplementedError("Flash attention has already been replaced in mixtral.")
+            warnings.warn("Flash attention is natively supported in transformers, will ignore the flag.")

         return policy

@@ -150,7 +212,7 @@ class MixtralModelPolicy(MixtralPolicy):
         return held_layers

     def get_shared_params(self) -> List[Dict[int, Tensor]]:
-        """No shared params in llama model"""
+        """No shared params in mixtral model"""
         return []


@@ -206,3 +268,40 @@ class MixtralForCausalLMPolicy(MixtralPolicy):
                     }
                 ]
         return []
+
+
+class MixtralForSequenceClassificationPolicy(MixtralPolicy):
+    def module_policy(self):
+        from transformers import MixtralForSequenceClassification
+
+        policy = super().module_policy()
+
+        if self.shard_config.enable_tensor_parallelism:
+            # add a new item for sequence classification
+            new_item = {
+                MixtralForSequenceClassification: ModulePolicyDescription(
+                    sub_module_replacement=[
+                        SubModuleReplacementDescription(
+                            suffix="score", target_module=Linear1D_Col, kwargs=dict(gather_output=True)
+                        )
+                    ]
+                )
+            }
+            policy.update(new_item)
+
+        if self.pipeline_stage_manager:
+            raise NotImplementedError
+
+        return policy
+
+    def get_held_layers(self) -> List[Module]:
+        """Get pipeline layers for current stage."""
+        stage_manager = self.pipeline_stage_manager
+        held_layers = super().get_held_layers()
+        if stage_manager.is_last_stage(ignore_chunk=True):
+            held_layers.append(self.model.score)
+        return held_layers
+
+    def get_shared_params(self) -> List[Dict[int, Tensor]]:
+        """No shared params in mixtral for sequence classification model"""
+        return []
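
As a reference for the test this commit adds: MixtralForSequenceClassificationPolicy registers a column-sharded score head (Linear1D_Col with gather_output=True, so every rank ends up with the full logits) for the tensor-parallel case, which the base MixtralPolicy currently rejects with NotImplementedError. Below is a minimal single-process sketch of MixtralForSequenceClassification itself, without any sharding; the tiny config values are illustrative only.

    # Minimal single-process sketch (no sharding); config values are made up for illustration.
    import torch
    from transformers import MixtralConfig, MixtralForSequenceClassification

    config = MixtralConfig(
        hidden_size=128,
        intermediate_size=256,
        num_hidden_layers=2,
        num_attention_heads=4,
        num_key_value_heads=2,
        num_local_experts=4,
        num_experts_per_tok=2,
        vocab_size=1024,
        num_labels=2,      # sequence-classification head size
        pad_token_id=0,    # required by the classification head when batch size > 1
    )
    model = MixtralForSequenceClassification(config)

    input_ids = torch.randint(1, config.vocab_size, (2, 16))  # no pad tokens, so the last token feeds the score head
    outputs = model(input_ids=input_ids)
    assert outputs.logits.shape == (2, config.num_labels)  # one logit vector per sequence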