[shardformer] fix the moe (#5883)

Wang Binluo 2024-07-03 20:02:19 +08:00 committed by GitHub
parent eb24fcd914
commit 6cd4c32be4
2 changed files with 22 additions and 16 deletions


@@ -1,10 +1,18 @@
 from .gemini_plugin import GeminiPlugin
 from .hybrid_parallel_plugin import HybridParallelPlugin
 from .low_level_zero_plugin import LowLevelZeroPlugin
+from .moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
 from .plugin_base import Plugin
 from .torch_ddp_plugin import TorchDDPPlugin
 
-__all__ = ["Plugin", "TorchDDPPlugin", "GeminiPlugin", "LowLevelZeroPlugin", "HybridParallelPlugin"]
+__all__ = [
+    "Plugin",
+    "TorchDDPPlugin",
+    "GeminiPlugin",
+    "LowLevelZeroPlugin",
+    "HybridParallelPlugin",
+    "MoeHybridParallelPlugin",
+]
 
 import torch
 from packaging import version
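The hunk above only re-exports the MoE plugin from the public plugin namespace. A minimal usage sketch, assuming the standard Booster workflow; the constructor arguments shown are illustrative assumptions, not the authoritative signature:

# Sketch only: after this change, MoeHybridParallelPlugin is importable from
# colossalai.booster.plugin like the other plugins re-exported above.
from colossalai.booster import Booster
from colossalai.booster.plugin import MoeHybridParallelPlugin

# The keyword arguments below are assumptions for illustration; check the
# plugin class for the real parameters.
plugin = MoeHybridParallelPlugin(tp_size=1, pp_size=1, ep_size=2)
booster = Booster(plugin=plugin)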


@@ -40,9 +40,7 @@ class MixtralPolicy(Policy):
         if self.shard_config.enable_tensor_parallelism:
             raise NotImplementedError("Tensor parallelism is not supported for Mixtral model now.")
-        if getattr(self.shard_config, "ep_group", None) is None:
-            raise ValueError("You must pass in ep_group via shard_config for expert parallel!")
-
-        # expert parallel
-        self.append_or_create_submodule_replacement(
-            description=[
+        if getattr(self.shard_config, "ep_group", None) is not None:
+            # expert parallel
+            self.append_or_create_submodule_replacement(
+                description=[
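The second hunk inverts the ep_group check: instead of raising ValueError when no expert-parallel group is configured, the Mixtral policy now applies the expert-parallel replacement only when ep_group is present and otherwise leaves the module untouched. A small self-contained sketch of that guard, using hypothetical stand-ins (SimpleNamespace, apply_moe_replacement) for the real policy and shard config objects:

# Hypothetical helper mirroring the new control flow; the real logic lives
# inside MixtralPolicy.module_policy.
from types import SimpleNamespace

def apply_moe_replacement(shard_config, replace_fn):
    ep_group = getattr(shard_config, "ep_group", None)
    if ep_group is None:
        # no expert-parallel group configured: skip instead of raising
        return False
    replace_fn(ep_group)  # e.g. append_or_create_submodule_replacement(...)
    return True

# Without ep_group the policy is now a no-op for the MoE block:
assert apply_moe_replacement(SimpleNamespace(), lambda g: None) is False
# With ep_group the expert-parallel replacement is applied:
assert apply_moe_replacement(SimpleNamespace(ep_group=object()), lambda g: None) is True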