From c423f1159b6849dd9cdf1560e13cddde201f2cff Mon Sep 17 00:00:00 2001
From: Wenwen Qu
Date: Tue, 9 Jan 2024 15:56:59 +0800
Subject: [PATCH] add moe_type to model config

---
 internlm/model/modeling_moe.py | 1 +
 internlm/model/moe.py          | 3 ++-
 2 files changed, 3 insertions(+), 1 deletion(-)

diff --git a/internlm/model/modeling_moe.py b/internlm/model/modeling_moe.py
index 4e609bc..ff6197d 100644
--- a/internlm/model/modeling_moe.py
+++ b/internlm/model/modeling_moe.py
@@ -527,6 +527,7 @@ def build_model_with_moe_cfg(
     moe_drop_tokens: bool = True,  # pylint: disable=W0613
     moe_use_rts: bool = True,  # pylint: disable=W0613
     moe_use_residual: bool = False,  # pylint: disable=W0613
+    moe_type: str = None,  # pylint: disable=W0613
 ):
     """
     Build model with config.
diff --git a/internlm/model/moe.py b/internlm/model/moe.py
index ab18f69..ff37a6d 100644
--- a/internlm/model/moe.py
+++ b/internlm/model/moe.py
@@ -39,7 +39,6 @@ class MoE(torch.nn.Module):
         ep_size=1,
         device=None,
         dtype=None,
-        moe_type: str = None,
     ):
         super().__init__()
 
@@ -51,6 +50,8 @@ class MoE(torch.nn.Module):
         self.num_experts = num_experts
         self.num_local_experts = num_experts // self.ep_size
 
+        moe_type = getattr(gpc.config.model, "moe_type", None)
+
         if moe_type is None or moe_type == "GShard":
             self.moe_layer = GShardMOELayer(
                 hidden_size,
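
Note: the net effect of this patch is that MoE no longer takes moe_type as a constructor argument; it reads the field from the runtime model config instead. Below is a minimal sketch of how the field might be set in a training config, assuming InternLM's dict-style Python config files; the file name and every key except moe_type are illustrative placeholders, not part of this patch.

    # hypothetical excerpt from a training config, e.g. configs/7B_moe.py
    model = dict(
        num_layers=32,       # illustrative value
        hidden_size=4096,    # illustrative value
        num_experts=8,       # illustrative value
        moe_type="GShard",   # read in MoE.__init__ via getattr(gpc.config.model, "moe_type", None)
    )

Because the lookup uses getattr(..., None) and the selection branch accepts both None and "GShard", configs that omit moe_type keep the previous GShard behavior.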