diff --git a/internlm/model/moe.py b/internlm/model/moe.py
index e4902c8..992d422 100644
--- a/internlm/model/moe.py
+++ b/internlm/model/moe.py
@@ -44,7 +44,7 @@ class MoE(torch.nn.Module):
 
         super().__init__()
 
-        moe_impl = self.get_moe(getattr(gpc.config.model, "moe_type", None))
+        moe_impl = self.get_moe_impl(gpc.config.model.moe_type)
 
         if not hasattr(gpc.config, "moe"):
             gpc.config.moe = dict()
@@ -74,9 +74,11 @@ class MoE(torch.nn.Module):
             # coefficient is used for weighted sum of the output of expert and residual mlp
             self.coefficient = torch.nn.Linear(hidden_size, 2)
 
-    def get_moe(self, moe_type):
+    def get_moe_impl(self, moe_type):
         if moe_type is None or moe_type == "GShard":
             return GShardMOELayer
+        else:
+            assert False, "unsupported moe type"
 
     def forward(self, hidden_states, used_token=None):
         """MoE forward
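
For readers skimming the change: the selector is renamed from get_moe to get_moe_impl, it now reads gpc.config.model.moe_type directly instead of via getattr with a None default, and an unrecognised type trips an assertion instead of silently returning None. The sketch below is a minimal, self-contained illustration of that dispatch pattern under those assumptions; the placeholder GShardMOELayer class and the demo calls are for illustration only, not InternLM code.

    # Standalone sketch (not InternLM code) of the dispatch pattern the diff
    # introduces: map a configured moe_type string to a layer class and
    # reject anything unrecognised.

    class GShardMOELayer:
        """Placeholder standing in for internlm.model.moe.GShardMOELayer."""


    def get_moe_impl(moe_type):
        # Default to GShard when moe_type is unset; any other value fails
        # loudly instead of falling through and returning None.
        if moe_type is None or moe_type == "GShard":
            return GShardMOELayer
        assert False, f"unsupported moe type: {moe_type}"


    if __name__ == "__main__":
        print(get_moe_impl("GShard"))  # <class '__main__.GShardMOELayer'>
        print(get_moe_impl(None))      # <class '__main__.GShardMOELayer'>

One design note on the new branch: assert statements are stripped when Python runs with optimizations enabled (-O), so raising an explicit exception such as NotImplementedError would keep the guard active in that mode as well.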