use register to get moe impl

pull/567/head
Wenwen Qu 2024-01-12 13:20:17 +08:00
parent 13f3eeb994
commit 8acf8455fe
3 changed files with 5 additions and 11 deletions
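This commit swaps the hard-coded `get_moe_impl` branch in `MoE` for a lookup through the `MODEL_INITIALIZER` registry, with `GShardMOELayer` registering itself under the name "GShard". The registry implementation itself is not part of this diff; the sketch below is a minimal, assumed version of the decorator-plus-lookup pattern behind `register_module` / `get_module` (only those two method names and `MODEL_INITIALIZER` come from the diff, everything else is illustrative).

# Minimal sketch of a name -> class registry (assumed; the real
# internlm.utils.registry implementation is not shown in this commit).
class Registry:
    def __init__(self, name: str):
        self.name = name
        self._registry = {}  # maps module_name -> registered class

    def register_module(self, module_name: str):
        # Used as a class decorator, e.g.
        # @MODEL_INITIALIZER.register_module(module_name="GShard")
        def decorator(cls):
            if module_name in self._registry:
                raise KeyError(f"{module_name} is already registered in {self.name}")
            self._registry[module_name] = cls
            return cls

        return decorator

    def get_module(self, module_name: str):
        # Look up the class by name; an unknown name now fails here instead of
        # in the removed `assert False, "unsupported moe type"` branch.
        if module_name not in self._registry:
            raise KeyError(f"{module_name} is not registered in {self.name}")
        return self._registry[module_name]


MODEL_INITIALIZER = Registry("model_initializer")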

View File

@@ -306,7 +306,7 @@ def args_sanity_check():
     if "moe_use_residual" not in model:
         model._add_item("moe_use_residual", False)
     if "moe_type" not in model:
-        model._add_item("moe_type", None)
+        model._add_item("moe_type", "GShard")
     # process the parallel config
     if "sequence_parallel" not in gpc.config.parallel:
         gpc.config.parallel._add_item("sequence_parallel", False)

View File

@@ -3,8 +3,8 @@ import torch
 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.model.linear import FeedForward
-from internlm.moe import GShardMOELayer
 from internlm.utils.logger import get_logger
+from internlm.utils.registry import MODEL_INITIALIZER

 # global llm logger
 logger = get_logger(__file__)
@@ -44,12 +44,10 @@ class MoE(torch.nn.Module):
         super().__init__()

-        moe_impl = self.get_moe_impl(gpc.config.model.moe_type)
-
         if not hasattr(gpc.config, "moe"):
             gpc.config.moe = dict()

-        self.moe_layer = moe_impl(
+        self.moe_layer = MODEL_INITIALIZER.get_module(module_name=gpc.config.model.moe_type)(
             hidden_size=hidden_size,
             num_experts=num_experts,
             ep_group=ep_group,
@@ -74,12 +72,6 @@ class MoE(torch.nn.Module):
             # coefficient is used for weighted sum of the output of expert and residual mlp
             self.coefficient = torch.nn.Linear(hidden_size, 2)

-    def get_moe_impl(self, moe_type):
-        if moe_type is None or moe_type == "GShard":
-            return GShardMOELayer
-        else:
-            assert False, "unsupported moe type"
-
     def forward(self, hidden_states, used_token=None):
         """MoE forward

View File

@@ -17,6 +17,7 @@ from internlm.core.context import global_context as gpc
 from internlm.model.linear import FeedForward
 from internlm.utils.logger import get_logger
 from internlm.utils.megatron_timers import megatron_timer as timer
+from internlm.utils.registry import MODEL_INITIALIZER

 from .base_moe import BaseMoELayer
 from .utils import _AllToAll
@@ -364,6 +365,7 @@ class TopKGate(Module):
         return gate_output


+@MODEL_INITIALIZER.register_module(module_name="GShard")
 class GShardMOELayer(BaseMoELayer):
     """MOELayer module which implements MixtureOfExperts as described in Gshard_.
     ::