mirror of https://github.com/hpcaitech/ColossalAI
[shardformer] fix the moe (#5883)
parent eb24fcd914
commit 6cd4c32be4
colossalai/booster/plugin/__init__.py
@@ -1,10 +1,18 @@
 from .gemini_plugin import GeminiPlugin
 from .hybrid_parallel_plugin import HybridParallelPlugin
 from .low_level_zero_plugin import LowLevelZeroPlugin
+from .moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
 from .plugin_base import Plugin
 from .torch_ddp_plugin import TorchDDPPlugin
 
-__all__ = ["Plugin", "TorchDDPPlugin", "GeminiPlugin", "LowLevelZeroPlugin", "HybridParallelPlugin"]
+__all__ = [
+    "Plugin",
+    "TorchDDPPlugin",
+    "GeminiPlugin",
+    "LowLevelZeroPlugin",
+    "HybridParallelPlugin",
+    "MoeHybridParallelPlugin",
+]
 
 import torch
 from packaging import version
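Note: the hunk above only re-exports MoeHybridParallelPlugin from the plugin package. As a rough usage sketch (not part of this commit), the plugin is handed to the Booster API like any other plugin; the constructor arguments shown here (tp_size, pp_size, ep_size, precision) and the launch call are assumptions that may differ between ColossalAI versions.

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import MoeHybridParallelPlugin

# Assumes the script is started with torchrun so the process-group environment
# variables are already set; the launch helper's signature has changed across releases.
colossalai.launch_from_torch()

# Argument names below are illustrative assumptions, not taken from this diff.
plugin = MoeHybridParallelPlugin(
    tp_size=1,        # the Mixtral policy in the next hunk rejects tensor parallelism
    pp_size=1,
    ep_size=2,        # size of the expert-parallel group (ends up as shard_config.ep_group)
    precision="bf16",
)
booster = Booster(plugin=plugin)

# model, optimizer, criterion, dataloader and lr_scheduler are user-provided:
# model, optimizer, criterion, dataloader, lr_scheduler = booster.boost(
#     model, optimizer, criterion, dataloader, lr_scheduler
# )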
colossalai/shardformer/policies/mixtral.py
@@ -40,21 +40,19 @@ class MixtralPolicy(Policy):
 
         if self.shard_config.enable_tensor_parallelism:
             raise NotImplementedError("Tensor parallelism is not supported for Mixtral model now.")
-        if getattr(self.shard_config, "ep_group", None) is None:
-            raise ValueError("You must pass in ep_group via shard_config for expert parallel!")
-
-        # expert parallel
-        self.append_or_create_submodule_replacement(
-            description=[
-                SubModuleReplacementDescription(
-                    suffix="block_sparse_moe",
-                    target_module=EPMixtralSparseMoeBlock,
-                    kwargs={"ep_group": self.shard_config.ep_group},
-                )
-            ],
-            policy=policy,
-            target_key=MixtralDecoderLayer,
-        )
+        if getattr(self.shard_config, "ep_group", None) is not None:
+            # expert parallel
+            self.append_or_create_submodule_replacement(
+                description=[
+                    SubModuleReplacementDescription(
+                        suffix="block_sparse_moe",
+                        target_module=EPMixtralSparseMoeBlock,
+                        kwargs={"ep_group": self.shard_config.ep_group},
+                    )
+                ],
+                policy=policy,
+                target_key=MixtralDecoderLayer,
+            )
 
         # optimization configuration
         if self.shard_config.enable_fused_normalization:
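This hunk turns a hard requirement into an optional path: previously MixtralPolicy raised a ValueError whenever shard_config.ep_group was unset, whereas now the expert-parallel replacement of block_sparse_moe with EPMixtralSparseMoeBlock is simply skipped, so the policy also works without expert parallelism. The standalone sketch below mirrors that control flow; DummyShardConfig and plan_moe_replacement are hypothetical illustration helpers, not ColossalAI APIs.

from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class DummyShardConfig:
    enable_tensor_parallelism: bool = False
    ep_group: Optional[Any] = None  # stands in for a torch.distributed ProcessGroup


def plan_moe_replacement(shard_config: DummyShardConfig) -> Optional[str]:
    """Mirror the new control flow: return the replacement target name, or None."""
    if shard_config.enable_tensor_parallelism:
        raise NotImplementedError("Tensor parallelism is not supported for Mixtral model now.")
    if getattr(shard_config, "ep_group", None) is not None:
        # expert parallel: swap block_sparse_moe for EPMixtralSparseMoeBlock
        return "EPMixtralSparseMoeBlock"
    # no ep_group: keep the original MoE block (the old code raised ValueError here)
    return None


assert plan_moe_replacement(DummyShardConfig()) is None
assert plan_moe_replacement(DummyShardConfig(ep_group=object())) == "EPMixtralSparseMoeBlock"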