diff --git a/internlm/model/modeling_moe.py b/internlm/model/modeling_moe.py index 4e609bc..ff6197d 100644 --- a/internlm/model/modeling_moe.py +++ b/internlm/model/modeling_moe.py @@ -527,6 +527,7 @@ def build_model_with_moe_cfg( moe_drop_tokens: bool = True, # pylint: disable=W0613 moe_use_rts: bool = True, # pylint: disable=W0613 moe_use_residual: bool = False, # pylint: disable=W0613 + moe_type: str = None, # pylint: disable=W0613 ): """ Build model with config. diff --git a/internlm/model/moe.py b/internlm/model/moe.py index ab18f69..ff37a6d 100644 --- a/internlm/model/moe.py +++ b/internlm/model/moe.py @@ -39,7 +39,6 @@ class MoE(torch.nn.Module): ep_size=1, device=None, dtype=None, - moe_type: str = None, ): super().__init__() @@ -51,6 +50,8 @@ class MoE(torch.nn.Module): self.num_experts = num_experts self.num_local_experts = num_experts // self.ep_size + moe_type = getattr(gpc.config.model, "moe_type", None) + if moe_type is None or moe_type == "GShard": self.moe_layer = GShardMOELayer( hidden_size,