@@ -1,3 +1,4 @@
+import warnings
 from functools import partial
 from typing import Callable, Dict, List, Union
 
@@ -39,20 +40,81 @@ class MixtralPolicy(Policy):
             )
 
         if self.shard_config.enable_tensor_parallelism:
-            raise NotImplementedError("Tensor parallelism is not supported for Mixtral model now.")
-        if getattr(self.shard_config, "ep_group", None) is not None:
-            # expert parallel
-            self.append_or_create_submodule_replacement(
-                description=[
-                    SubModuleReplacementDescription(
-                        suffix="block_sparse_moe",
-                        target_module=EPMixtralSparseMoeBlock,
-                        kwargs={"ep_group": self.shard_config.ep_group},
-                    )
-                ],
-                policy=policy,
-                target_key=MixtralDecoderLayer,
-            )
+            raise NotImplementedError
+            # assert (
+            #     self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
+            # ), f"The number of attention heads must be divisible by tensor parallel size."
+            # assert (
+            #     self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0
+            # ), f"The number of key_value heads must be divisible by tensor parallel size."
+            # decoder_attribute_replacement = {
+            #     "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
+            #     "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
+            #     "self_attn.num_key_value_heads": self.model.config.num_key_value_heads
+            #     // self.shard_config.tensor_parallel_size,
+            # }
+
+            # policy[MixtralDecoderLayer] = ModulePolicyDescription(
+            #     attribute_replacement=decoder_attribute_replacement,
+            #     sub_module_replacement=[
+            #         SubModuleReplacementDescription(
+            #             suffix="self_attn.q_proj",
+            #             target_module=Linear1D_Col,
+            #             kwargs={
+            #                 'process_group': self.shard_config.tensor_parallel_process_group,
+            #             }
+            #         ),
+            #         SubModuleReplacementDescription(
+            #             suffix="self_attn.k_proj",
+            #             target_module=Linear1D_Col,
+            #             kwargs={
+            #                 'process_group': self.shard_config.tensor_parallel_process_group,
+            #             }
+            #         ),
+            #         SubModuleReplacementDescription(
+            #             suffix="self_attn.v_proj",
+            #             target_module=Linear1D_Col,
+            #             kwargs={
+            #                 'process_group': self.shard_config.tensor_parallel_process_group,
+            #             }
+            #         ),
+            #         SubModuleReplacementDescription(
+            #             suffix="self_attn.o_proj",
+            #             target_module=Linear1D_Row,
+            #             kwargs={
+            #                 'process_group': self.shard_config.tensor_parallel_process_group,
+            #             }
+            #         ),
+            #         # SubModuleReplacementDescription(
+            #         #     suffix="mlp.gate_proj",
+            #         #     target_module=Linear1D_Col,
+            #         # ),
+            #         # SubModuleReplacementDescription(
+            #         #     suffix="mlp.up_proj",
+            #         #     target_module=Linear1D_Col,
+            #         # ),
+            #         # SubModuleReplacementDescription(
+            #         #     suffix="mlp.down_proj",
+            #         #     target_module=Linear1D_Row,
+            #         # ),
+            #     ],
+            # )
+
+        if getattr(self.shard_config, "ep_group", None) is None:
+            raise ValueError("You must pass in ep_group via shard_config for expert parallel!")
+
+        # expert parallel
+        self.append_or_create_submodule_replacement(
+            description=[
+                SubModuleReplacementDescription(
+                    suffix="block_sparse_moe",
+                    target_module=EPMixtralSparseMoeBlock,
+                    kwargs={"ep_group": self.shard_config.ep_group},
+                )
+            ],
+            policy=policy,
+            target_key=MixtralDecoderLayer,
+        )
 
         # optimization configuration
         if self.shard_config.enable_fused_normalization:
@@ -81,7 +143,7 @@ class MixtralPolicy(Policy):
             )
 
         if self.shard_config.enable_flash_attention:
-            raise NotImplementedError("Flash attention has already been replaced in mixtral.")
+            warnings.warn("Flash attention is natively supported in transformers, will ignore the flag.")
 
         return policy
 
@@ -150,7 +212,7 @@ class MixtralModelPolicy(MixtralPolicy):
         return held_layers
 
     def get_shared_params(self) -> List[Dict[int, Tensor]]:
-        """No shared params in llama model"""
+        """No shared params in mixtral model"""
         return []
 
 
@@ -206,3 +268,40 @@ class MixtralForCausalLMPolicy(MixtralPolicy):
                     }
                 ]
         return []
+
+
+class MixtralForSequenceClassificationPolicy(MixtralPolicy):
+    def module_policy(self):
+        from transformers import MixtralForSequenceClassification
+
+        policy = super().module_policy()
+
+        if self.shard_config.enable_tensor_parallelism:
+            # add a new item for sequence classification
+            new_item = {
+                MixtralForSequenceClassification: ModulePolicyDescription(
+                    sub_module_replacement=[
+                        SubModuleReplacementDescription(
+                            suffix="score", target_module=Linear1D_Col, kwargs=dict(gather_output=True)
+                        )
+                    ]
+                )
+            }
+            policy.update(new_item)
+
+        if self.pipeline_stage_manager:
+            raise NotImplementedError
+
+        return policy
+
+    def get_held_layers(self) -> List[Module]:
+        """Get pipeline layers for current stage."""
+        stage_manager = self.pipeline_stage_manager
+        held_layers = super().get_held_layers()
+        if stage_manager.is_last_stage(ignore_chunk=True):
+            held_layers.append(self.model.score)
+        return held_layers
+
+    def get_shared_params(self) -> List[Dict[int, Tensor]]:
+        """No shared params in mixtral for sequence classification model"""
+        return []
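
Usage note (not part of the patch): a minimal sketch, assuming ColossalAI's ShardConfig and an already-initialized torch.distributed, of how a caller could satisfy the new ep_group requirement before applying MixtralPolicy. Building one expert-parallel group over all ranks and attaching ep_group as a plain attribute are illustrative assumptions, not necessarily what the MoE plugin does internally.

import torch.distributed as dist

from colossalai.shardformer import ShardConfig


def build_mixtral_shard_config() -> ShardConfig:
    # Illustrative: a single expert-parallel group spanning every rank; a real
    # setup would usually carve this group out of the plugin's process-group mesh.
    # Assumes dist.init_process_group() has already been called.
    ep_group = dist.new_group(ranks=list(range(dist.get_world_size())))

    # Tensor parallelism still raises NotImplementedError in this policy, so keep it off.
    shard_config = ShardConfig(enable_tensor_parallelism=False)

    # Assumption: ep_group is attached as a plain attribute, mirroring the
    # getattr(self.shard_config, "ep_group", None) lookup added in this patch.
    shard_config.ep_group = ep_group
    return shard_config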