From 13f3eeb994cf22d59e612537f7fb848e94209369 Mon Sep 17 00:00:00 2001
From: Wenwen Qu
Date: Wed, 10 Jan 2024 17:13:21 +0800
Subject: [PATCH] add moe_type assert

---
 internlm/model/moe.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/internlm/model/moe.py b/internlm/model/moe.py
index e4902c8..992d422 100644
--- a/internlm/model/moe.py
+++ b/internlm/model/moe.py
@@ -44,7 +44,7 @@ class MoE(torch.nn.Module):
 
         super().__init__()
 
-        moe_impl = self.get_moe(getattr(gpc.config.model, "moe_type", None))
+        moe_impl = self.get_moe_impl(gpc.config.model.moe_type)
 
         if not hasattr(gpc.config, "moe"):
             gpc.config.moe = dict()
@@ -74,9 +74,11 @@ class MoE(torch.nn.Module):
             # coefficient is used for weighted sum of the output of expert and residual mlp
             self.coefficient = torch.nn.Linear(hidden_size, 2)
 
-    def get_moe(self, moe_type):
+    def get_moe_impl(self, moe_type):
         if moe_type is None or moe_type == "GShard":
             return GShardMOELayer
+        else:
+            assert False, "unsupported moe type"
 
     def forward(self, hidden_states, used_token=None):
         """MoE forward
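
Note: the sketch below illustrates the dispatch behaviour this patch introduces (renamed `get_moe_impl` plus the new assert for unknown types). It is a minimal standalone approximation: the `GShardMOELayer` placeholder and the unsupported-type string used in the demo are stand-ins, not the real InternLM class or config values.

```python
class GShardMOELayer:
    """Placeholder standing in for internlm.model.moe.GShardMOELayer."""


def get_moe_impl(moe_type):
    # Mirrors the patched method: None or "GShard" selects the GShard
    # implementation; any other value now fails fast with the new assert.
    if moe_type is None or moe_type == "GShard":
        return GShardMOELayer
    else:
        assert False, "unsupported moe type"


if __name__ == "__main__":
    print(get_moe_impl("GShard"))        # <class '__main__.GShardMOELayer'>
    try:
        get_moe_impl("NotARealMoEType")  # illustrative unsupported value
    except AssertionError as err:
        print("rejected:", err)          # rejected: unsupported moe type
```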