mirror of https://github.com/hpcaitech/ColossalAI
[shardformer] fix the moe (#5883)
parent eb24fcd914
commit 6cd4c32be4
@@ -1,10 +1,18 @@
 from .gemini_plugin import GeminiPlugin
 from .hybrid_parallel_plugin import HybridParallelPlugin
 from .low_level_zero_plugin import LowLevelZeroPlugin
+from .moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
 from .plugin_base import Plugin
 from .torch_ddp_plugin import TorchDDPPlugin
 
-__all__ = ["Plugin", "TorchDDPPlugin", "GeminiPlugin", "LowLevelZeroPlugin", "HybridParallelPlugin"]
+__all__ = [
+    "Plugin",
+    "TorchDDPPlugin",
+    "GeminiPlugin",
+    "LowLevelZeroPlugin",
+    "HybridParallelPlugin",
+    "MoeHybridParallelPlugin",
+]
 
 import torch
 from packaging import version
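This hunk (presumably the booster plugin package's __init__.py) re-exports MoeHybridParallelPlugin alongside the other booster plugins, so it can be imported from colossalai.booster.plugin directly. A minimal usage sketch; the constructor arguments tp_size/pp_size/ep_size and their values are assumptions based on the other hybrid-parallel plugins, not part of this commit:

from colossalai.booster import Booster
from colossalai.booster.plugin import MoeHybridParallelPlugin

# Assumed constructor arguments: a 2-way expert-parallel run with no
# tensor or pipeline parallelism. Check the plugin signature before relying on them.
plugin = MoeHybridParallelPlugin(tp_size=1, pp_size=1, ep_size=2)
booster = Booster(plugin=plugin)
# model, optimizer, etc. would then be wrapped via booster.boost(...) as with any other plugin.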
@@ -40,21 +40,19 @@ class MixtralPolicy(Policy):
 
         if self.shard_config.enable_tensor_parallelism:
             raise NotImplementedError("Tensor parallelism is not supported for Mixtral model now.")
-        if getattr(self.shard_config, "ep_group", None) is None:
-            raise ValueError("You must pass in ep_group via shard_config for expert parallel!")
-
-        # expert parallel
-        self.append_or_create_submodule_replacement(
-            description=[
-                SubModuleReplacementDescription(
-                    suffix="block_sparse_moe",
-                    target_module=EPMixtralSparseMoeBlock,
-                    kwargs={"ep_group": self.shard_config.ep_group},
-                )
-            ],
-            policy=policy,
-            target_key=MixtralDecoderLayer,
-        )
+        if getattr(self.shard_config, "ep_group", None) is not None:
+            # expert parallel
+            self.append_or_create_submodule_replacement(
+                description=[
+                    SubModuleReplacementDescription(
+                        suffix="block_sparse_moe",
+                        target_module=EPMixtralSparseMoeBlock,
+                        kwargs={"ep_group": self.shard_config.ep_group},
+                    )
+                ],
+                policy=policy,
+                target_key=MixtralDecoderLayer,
+            )
 
         # optimization configuration
         if self.shard_config.enable_fused_normalization:
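The behavioral change in this hunk: MixtralPolicy no longer raises ValueError when shard_config carries no ep_group; the replacement of block_sparse_moe with EPMixtralSparseMoeBlock is simply skipped, so the same policy serves both plain and expert-parallel setups. A rough sketch of exercising the policy through ShardFormer, assuming torch.distributed is already initialized, that the Mixtral policy is picked up by the automatic policy lookup, and that ep_group can be set on the ShardConfig (all assumptions, not shown in this diff):

import torch.distributed as dist
from transformers import MixtralConfig, MixtralForCausalLM
from colossalai.shardformer import ShardConfig, ShardFormer

# Tiny Mixtral purely for illustration.
model = MixtralForCausalLM(
    MixtralConfig(num_hidden_layers=2, hidden_size=128, intermediate_size=256,
                  num_attention_heads=4, num_key_value_heads=4)
)

shard_config = ShardConfig(enable_tensor_parallelism=False)
# Without ep_group the MoE blocks are left untouched (previously this raised ValueError).
# With ep_group set, block_sparse_moe is swapped for EPMixtralSparseMoeBlock.
shard_config.ep_group = dist.new_group(ranks=list(range(dist.get_world_size())))

sharded_model, _ = ShardFormer(shard_config=shard_config).optimize(model)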