diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py
index b0463a3..741538e 100644
--- a/internlm/initialize/launch.py
+++ b/internlm/initialize/launch.py
@@ -306,7 +306,7 @@ def args_sanity_check():
     if "moe_use_residual" not in model:
         model._add_item("moe_use_residual", False)
     if "moe_type" not in model:
-        model._add_item("moe_type", None)
+        model._add_item("moe_type", "GShard")
     # process the parallel config
     if "sequence_parallel" not in gpc.config.parallel:
         gpc.config.parallel._add_item("sequence_parallel", False)
diff --git a/internlm/model/moe.py b/internlm/model/moe.py
index 992d422..e1e2edf 100644
--- a/internlm/model/moe.py
+++ b/internlm/model/moe.py
@@ -3,8 +3,8 @@ import torch
 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.model.linear import FeedForward
-from internlm.moe import GShardMOELayer
 from internlm.utils.logger import get_logger
+from internlm.utils.registry import MODEL_INITIALIZER

 # global llm logger
 logger = get_logger(__file__)
@@ -44,12 +44,10 @@ class MoE(torch.nn.Module):

         super().__init__()

-        moe_impl = self.get_moe_impl(gpc.config.model.moe_type)
-
         if not hasattr(gpc.config, "moe"):
             gpc.config.moe = dict()

-        self.moe_layer = moe_impl(
+        self.moe_layer = MODEL_INITIALIZER.get_module(module_name=gpc.config.model.moe_type)(
             hidden_size=hidden_size,
             num_experts=num_experts,
             ep_group=ep_group,
@@ -74,12 +72,6 @@ class MoE(torch.nn.Module):
             # coefficient is used for weighted sum of the output of expert and residual mlp
             self.coefficient = torch.nn.Linear(hidden_size, 2)

-    def get_moe_impl(self, moe_type):
-        if moe_type is None or moe_type == "GShard":
-            return GShardMOELayer
-        else:
-            assert False, "unsupported moe type"
-
     def forward(self, hidden_states, used_token=None):
         """MoE forward

diff --git a/internlm/moe/sharded_moe.py b/internlm/moe/sharded_moe.py
index b71c1ee..8cbcaf6 100644
--- a/internlm/moe/sharded_moe.py
+++ b/internlm/moe/sharded_moe.py
@@ -17,6 +17,7 @@ from internlm.core.context import global_context as gpc
 from internlm.model.linear import FeedForward
 from internlm.utils.logger import get_logger
 from internlm.utils.megatron_timers import megatron_timer as timer
+from internlm.utils.registry import MODEL_INITIALIZER

 from .base_moe import BaseMoELayer
 from .utils import _AllToAll
@@ -364,6 +365,7 @@ class TopKGate(Module):
         return gate_output


+@MODEL_INITIALIZER.register_module(module_name="GShard")
 class GShardMOELayer(BaseMoELayer):
     """MOELayer module which implements MixtureOfExperts as described in Gshard_. ::