mirror of https://github.com/InternLM/InternLM
use registry to get moe impl
parent 13f3eeb994
commit 8acf8455fe
@@ -306,7 +306,7 @@ def args_sanity_check():
     if "moe_use_residual" not in model:
         model._add_item("moe_use_residual", False)
     if "moe_type" not in model:
-        model._add_item("moe_type", None)
+        model._add_item("moe_type", "GShard")
     # process the parallel config
     if "sequence_parallel" not in gpc.config.parallel:
         gpc.config.parallel._add_item("sequence_parallel", False)
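With this default in place, a model config that omits moe_type now resolves to the GShard implementation instead of None. A minimal sketch of the defaulting behavior, using a plain dict in place of the real gpc.config.model object (an assumption for illustration only; the actual sanity check goes through _add_item):

# Stand-in for the sanity check's defaulting logic; the real code mutates
# gpc.config.model via _add_item, a plain dict is used here only to illustrate.
def apply_moe_defaults(model):
    if "moe_use_residual" not in model:
        model["moe_use_residual"] = False
    if "moe_type" not in model:
        model["moe_type"] = "GShard"  # previously defaulted to None
    return model

# A config that never mentions moe_type now picks up the GShard layer.
print(apply_moe_defaults({"num_experts": 4}))
# {'num_experts': 4, 'moe_use_residual': False, 'moe_type': 'GShard'}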
@@ -3,8 +3,8 @@ import torch
 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.model.linear import FeedForward
-from internlm.moe import GShardMOELayer
 from internlm.utils.logger import get_logger
+from internlm.utils.registry import MODEL_INITIALIZER

 # global llm logger
 logger = get_logger(__file__)
@@ -44,12 +44,10 @@ class MoE(torch.nn.Module):

         super().__init__()

-        moe_impl = self.get_moe_impl(gpc.config.model.moe_type)
-
         if not hasattr(gpc.config, "moe"):
             gpc.config.moe = dict()

-        self.moe_layer = moe_impl(
+        self.moe_layer = MODEL_INITIALIZER.get_module(module_name=gpc.config.model.moe_type)(
             hidden_size=hidden_size,
             num_experts=num_experts,
             ep_group=ep_group,
@@ -74,12 +72,6 @@ class MoE(torch.nn.Module):
             # coefficient is used for weighted sum of the output of expert and residual mlp
             self.coefficient = torch.nn.Linear(hidden_size, 2)

-    def get_moe_impl(self, moe_type):
-        if moe_type is None or moe_type == "GShard":
-            return GShardMOELayer
-        else:
-            assert False, "unsupported moe type"
-
     def forward(self, hidden_states, used_token=None):
         """MoE forward
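The two hunks above replace the hardcoded get_moe_impl mapping with a registry lookup: MoE now asks MODEL_INITIALIZER for whatever class is registered under gpc.config.model.moe_type. The registry's internals are not part of this diff, so the following is only a sketch of the decorator-plus-lookup pattern it relies on; the Registry class and its _modules dict are assumptions, not InternLM's actual implementation:

# Hypothetical name-to-class registry mirroring how MODEL_INITIALIZER is used
# in this commit; internlm.utils.registry may differ in its details.
class Registry:
    def __init__(self, name):
        self.name = name
        self._modules = {}

    def register_module(self, module_name):
        # Decorator that stores a class under the given name.
        def decorator(cls):
            if module_name in self._modules:
                raise KeyError(f"{module_name} already registered in {self.name}")
            self._modules[module_name] = cls
            return cls
        return decorator

    def get_module(self, module_name):
        # Return the class previously registered under module_name.
        if module_name not in self._modules:
            raise KeyError(f"{module_name} not found in registry {self.name}")
        return self._modules[module_name]


MODEL_INITIALIZER = Registry("model_initializer")


@MODEL_INITIALIZER.register_module(module_name="GShard")
class GShardMOELayer:  # placeholder for the real layer
    pass


# What the MoE constructor now does, minus the constructor arguments:
moe_cls = MODEL_INITIALIZER.get_module(module_name="GShard")
assert moe_cls is GShardMOELayer

The payoff is that MoE no longer has to import every concrete layer; an implementation only needs to register itself under a name that the config can reference.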
@@ -17,6 +17,7 @@ from internlm.core.context import global_context as gpc
 from internlm.model.linear import FeedForward
 from internlm.utils.logger import get_logger
 from internlm.utils.megatron_timers import megatron_timer as timer
+from internlm.utils.registry import MODEL_INITIALIZER

 from .base_moe import BaseMoELayer
 from .utils import _AllToAll
@@ -364,6 +365,7 @@ class TopKGate(Module):
         return gate_output


+@MODEL_INITIALIZER.register_module(module_name="GShard")
 class GShardMOELayer(BaseMoELayer):
     """MOELayer module which implements MixtureOfExperts as described in Gshard_.
     ::
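With registration handled by a decorator on the layer class itself, adding another MoE implementation would not require touching the MoE wrapper at all. A hypothetical sketch (the class MyMoELayer and the name "MyMoE" are illustrative, not part of this change; the snippet assumes it lives inside the internlm.moe package so the relative import resolves):

# Hypothetical new MoE layer registered under its own name.
from internlm.utils.registry import MODEL_INITIALIZER

from .base_moe import BaseMoELayer


@MODEL_INITIALIZER.register_module(module_name="MyMoE")
class MyMoELayer(BaseMoELayer):
    """Placeholder; a real layer would define __init__ and forward."""


# Selecting it is then a pure config change, with no edit to the MoE class:
#     model = dict(..., num_experts=8, moe_type="MyMoE")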