[shardformer] fix the moe (#5883)

pull/5891/head
Wang Binluo 2024-07-03 20:02:19 +08:00 committed by GitHub
parent eb24fcd914
commit 6cd4c32be4
2 changed files with 22 additions and 16 deletions

@@ -1,10 +1,18 @@
 from .gemini_plugin import GeminiPlugin
 from .hybrid_parallel_plugin import HybridParallelPlugin
 from .low_level_zero_plugin import LowLevelZeroPlugin
+from .moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
 from .plugin_base import Plugin
 from .torch_ddp_plugin import TorchDDPPlugin
 
-__all__ = ["Plugin", "TorchDDPPlugin", "GeminiPlugin", "LowLevelZeroPlugin", "HybridParallelPlugin"]
+__all__ = [
+    "Plugin",
+    "TorchDDPPlugin",
+    "GeminiPlugin",
+    "LowLevelZeroPlugin",
+    "HybridParallelPlugin",
+    "MoeHybridParallelPlugin",
+]
 
 import torch
 from packaging import version
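
With this hunk, MoeHybridParallelPlugin is re-exported alongside the other booster plugins, so it can be imported from the package root. A minimal usage sketch, assuming this __init__.py is the colossalai.booster.plugin package (the file path is not shown in the diff):

# Illustrative import only: the package path colossalai.booster.plugin is an
# assumption, and no constructor arguments are shown because they are not
# part of this change.
from colossalai.booster.plugin import HybridParallelPlugin, MoeHybridParallelPlugin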

@@ -40,21 +40,19 @@ class MixtralPolicy(Policy):
         if self.shard_config.enable_tensor_parallelism:
             raise NotImplementedError("Tensor parallelism is not supported for Mixtral model now.")
-        if getattr(self.shard_config, "ep_group", None) is None:
-            raise ValueError("You must pass in ep_group via shard_config for expert parallel!")
-
-        # expert parallel
-        self.append_or_create_submodule_replacement(
-            description=[
-                SubModuleReplacementDescription(
-                    suffix="block_sparse_moe",
-                    target_module=EPMixtralSparseMoeBlock,
-                    kwargs={"ep_group": self.shard_config.ep_group},
-                )
-            ],
-            policy=policy,
-            target_key=MixtralDecoderLayer,
-        )
+        if getattr(self.shard_config, "ep_group", None) is not None:
+            # expert parallel
+            self.append_or_create_submodule_replacement(
+                description=[
+                    SubModuleReplacementDescription(
+                        suffix="block_sparse_moe",
+                        target_module=EPMixtralSparseMoeBlock,
+                        kwargs={"ep_group": self.shard_config.ep_group},
+                    )
+                ],
+                policy=policy,
+                target_key=MixtralDecoderLayer,
+            )
 
         # optimization configuration
         if self.shard_config.enable_fused_normalization:
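
The net effect of the second hunk: a missing ep_group in shard_config no longer raises a ValueError; the Mixtral policy simply skips the expert-parallel submodule replacement and shards without expert parallelism. A minimal sketch of the new guard in isolation, using SimpleNamespace as a stand-in for the real ShardConfig (apply_expert_parallel and replace_block_sparse_moe are hypothetical helpers; only the getattr guard mirrors the diff):

from types import SimpleNamespace

def apply_expert_parallel(shard_config, replace_block_sparse_moe):
    # Hypothetical helper sketching the policy's new guard; not ColossalAI API.
    # Old behavior: a missing ep_group raised
    #   ValueError("You must pass in ep_group via shard_config for expert parallel!")
    # New behavior: the replacement is skipped when ep_group is absent.
    if getattr(shard_config, "ep_group", None) is not None:
        replace_block_sparse_moe(shard_config.ep_group)

apply_expert_parallel(SimpleNamespace(), lambda group: None)       # no ep_group: silently skipped
apply_expert_parallel(SimpleNamespace(ep_group="mock_pg"), print)  # ep_group present: prints 'mock_pg'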