[shardformer] fix the moe (#5883)

Wang Binluo 2024-07-03 20:02:19 +08:00 committed by GitHub
parent eb24fcd914
commit 6cd4c32be4
2 changed files with 22 additions and 16 deletions


@@ -1,10 +1,18 @@
 from .gemini_plugin import GeminiPlugin
 from .hybrid_parallel_plugin import HybridParallelPlugin
 from .low_level_zero_plugin import LowLevelZeroPlugin
+from .moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
 from .plugin_base import Plugin
 from .torch_ddp_plugin import TorchDDPPlugin
 
-__all__ = ["Plugin", "TorchDDPPlugin", "GeminiPlugin", "LowLevelZeroPlugin", "HybridParallelPlugin"]
+__all__ = [
+    "Plugin",
+    "TorchDDPPlugin",
+    "GeminiPlugin",
+    "LowLevelZeroPlugin",
+    "HybridParallelPlugin",
+    "MoeHybridParallelPlugin",
+]
 
 import torch
 from packaging import version
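
With MoeHybridParallelPlugin re-exported from the booster plugin package, it can now be imported next to the other plugins. A minimal usage sketch, assuming the constructor mirrors HybridParallelPlugin and accepts an ep_size for the expert-parallel degree (argument names may differ between versions):

from colossalai.booster import Booster
from colossalai.booster.plugin import MoeHybridParallelPlugin  # now re-exported from the plugin package

# Distributed initialization (e.g. colossalai.launch_from_torch) must run before this point.
plugin = MoeHybridParallelPlugin(tp_size=1, pp_size=1, ep_size=2)  # ep_size assumed: expert-parallel group size
booster = Booster(plugin=plugin)
# model, optimizer, *_ = booster.boost(model, optimizer)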


@@ -40,21 +40,19 @@ class MixtralPolicy(Policy):
         if self.shard_config.enable_tensor_parallelism:
             raise NotImplementedError("Tensor parallelism is not supported for Mixtral model now.")
-        if getattr(self.shard_config, "ep_group", None) is None:
-            raise ValueError("You must pass in ep_group via shard_config for expert parallel!")
-
-        # expert parallel
-        self.append_or_create_submodule_replacement(
-            description=[
-                SubModuleReplacementDescription(
-                    suffix="block_sparse_moe",
-                    target_module=EPMixtralSparseMoeBlock,
-                    kwargs={"ep_group": self.shard_config.ep_group},
-                )
-            ],
-            policy=policy,
-            target_key=MixtralDecoderLayer,
-        )
+        if getattr(self.shard_config, "ep_group", None) is not None:
+            # expert parallel
+            self.append_or_create_submodule_replacement(
+                description=[
+                    SubModuleReplacementDescription(
+                        suffix="block_sparse_moe",
+                        target_module=EPMixtralSparseMoeBlock,
+                        kwargs={"ep_group": self.shard_config.ep_group},
+                    )
+                ],
+                policy=policy,
+                target_key=MixtralDecoderLayer,
+            )
 
         # optimization configuration
         if self.shard_config.enable_fused_normalization:
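
The behavioral change in the Mixtral policy: a missing ep_group no longer raises ValueError; the expert-parallel replacement is simply skipped and block_sparse_moe is left unchanged. A minimal sketch of the new guard, using a stand-in shard config (SimpleNamespace here is only for illustration, not the library's ShardConfig):

from types import SimpleNamespace

def wants_expert_parallel(shard_config) -> bool:
    # Mirrors the new condition in MixtralPolicy: replace block_sparse_moe with
    # EPMixtralSparseMoeBlock only when an expert-parallel group is provided.
    return getattr(shard_config, "ep_group", None) is not None

dense_cfg = SimpleNamespace(enable_tensor_parallelism=False)                   # no ep_group attribute
moe_cfg = SimpleNamespace(enable_tensor_parallelism=False, ep_group=object())  # stand-in process group

assert not wants_expert_parallel(dense_cfg)  # policy leaves the MoE block untouched
assert wants_expert_parallel(moe_cfg)        # policy appends the EP submodule replacement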