mirror of https://github.com/InternLM/InternLM
use registry to get moe impl
parent 13f3eeb994
commit 8acf8455fe
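
This commit replaces the hard-coded `get_moe_impl` lookup in the `MoE` wrapper with a registry lookup: `GShardMOELayer` is registered under the name "GShard" via `MODEL_INITIALIZER`, `args_sanity_check()` defaults `moe_type` to "GShard" instead of None, and `MoE.__init__` fetches the layer class by name. A minimal sketch of the registry pattern this relies on (illustrative only, not InternLM's actual `MODEL_INITIALIZER` implementation):

    # Illustrative registry sketch; the real MODEL_INITIALIZER may differ in detail.
    class Registry:
        def __init__(self, name):
            self.name = name
            self._modules = {}

        def register_module(self, module_name):
            # Used as a class decorator: @registry.register_module(module_name="GShard")
            def decorator(cls):
                assert module_name not in self._modules, f"{module_name} already registered"
                self._modules[module_name] = cls
                return cls
            return decorator

        def get_module(self, module_name):
            assert module_name in self._modules, f"unknown module {module_name}"
            return self._modules[module_name]

    MODEL_INITIALIZER = Registry("model_initializer")

Resolving the implementation by name keeps the `MoE` wrapper free of per-implementation branching, so new MoE variants only need to register themselves.
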
@@ -306,7 +306,7 @@ def args_sanity_check():
     if "moe_use_residual" not in model:
         model._add_item("moe_use_residual", False)
     if "moe_type" not in model:
-        model._add_item("moe_type", None)
+        model._add_item("moe_type", "GShard")
     # process the parallel config
     if "sequence_parallel" not in gpc.config.parallel:
         gpc.config.parallel._add_item("sequence_parallel", False)
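
With the new default, a model config that omits `moe_type` now resolves to the registered "GShard" implementation. A hypothetical config fragment (the surrounding fields are illustrative):

    # Hypothetical model-config fragment; args_sanity_check() fills in the
    # defaults noted in the comments when the fields are omitted.
    model = dict(
        num_experts=4,
        moe_use_residual=False,  # added with default False when missing
        moe_type="GShard",       # default is now "GShard" rather than None
    )
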
@@ -3,8 +3,8 @@ import torch
 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.model.linear import FeedForward
-from internlm.moe import GShardMOELayer
 from internlm.utils.logger import get_logger
+from internlm.utils.registry import MODEL_INITIALIZER
 
 # global llm logger
 logger = get_logger(__file__)
@@ -44,12 +44,10 @@ class MoE(torch.nn.Module):
 
         super().__init__()
 
-        moe_impl = self.get_moe_impl(gpc.config.model.moe_type)
-
         if not hasattr(gpc.config, "moe"):
             gpc.config.moe = dict()
 
-        self.moe_layer = moe_impl(
+        self.moe_layer = MODEL_INITIALIZER.get_module(module_name=gpc.config.model.moe_type)(
             hidden_size=hidden_size,
             num_experts=num_experts,
             ep_group=ep_group,
@@ -74,12 +72,6 @@ class MoE(torch.nn.Module):
             # coefficient is used for weighted sum of the output of expert and residual mlp
             self.coefficient = torch.nn.Linear(hidden_size, 2)
 
-    def get_moe_impl(self, moe_type):
-        if moe_type is None or moe_type == "GShard":
-            return GShardMOELayer
-        else:
-            assert False, "unsupported moe type"
-
     def forward(self, hidden_states, used_token=None):
         """MoE forward
@@ -17,6 +17,7 @@ from internlm.core.context import global_context as gpc
 from internlm.model.linear import FeedForward
 from internlm.utils.logger import get_logger
 from internlm.utils.megatron_timers import megatron_timer as timer
+from internlm.utils.registry import MODEL_INITIALIZER
 
 from .base_moe import BaseMoELayer
 from .utils import _AllToAll
@@ -364,6 +365,7 @@ class TopKGate(Module):
         return gate_output
 
 
+@MODEL_INITIALIZER.register_module(module_name="GShard")
 class GShardMOELayer(BaseMoELayer):
     """MOELayer module which implements MixtureOfExperts as described in Gshard_.
     ::
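
With the class registered under "GShard", a further MoE variant could be made selectable through `model.moe_type` simply by registering it under its own name; a hypothetical sketch (the name "MyMoE" and the class are invented for illustration):

    # Hypothetical: register another MoE layer so that moe_type = "MyMoE"
    # in the model config resolves to it via MODEL_INITIALIZER.get_module(...).
    from internlm.utils.registry import MODEL_INITIALIZER

    from .base_moe import BaseMoELayer


    @MODEL_INITIALIZER.register_module(module_name="MyMoE")
    class MyMoELayer(BaseMoELayer):
        ...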