add moe_type to model config

pull/567/head
Wenwen Qu 2024-01-09 15:56:59 +08:00
parent dcfdab6aaf
commit c423f1159b
2 changed files with 3 additions and 1 deletion


@@ -527,6 +527,7 @@ def build_model_with_moe_cfg(
     moe_drop_tokens: bool = True,  # pylint: disable=W0613
     moe_use_rts: bool = True,  # pylint: disable=W0613
     moe_use_residual: bool = False,  # pylint: disable=W0613
+    moe_type: str = None,  # pylint: disable=W0613
 ):
     """
     Build model with config.
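
As context, a minimal sketch of how the new field would be supplied through the training config, assuming the usual InternLM-style model = dict(...) layout; every key except moe_type below is illustrative and not part of this commit.

# Hypothetical config excerpt (illustrative values, not from this commit).
model = dict(
    num_experts=8,          # illustrative
    moe_type="GShard",      # new key; omitting it (None) also selects the GShard layer
)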


@@ -39,7 +39,6 @@ class MoE(torch.nn.Module):
         ep_size=1,
         device=None,
         dtype=None,
-        moe_type: str = None,
     ):
         super().__init__()
@@ -51,6 +50,8 @@ class MoE(torch.nn.Module):
         self.num_experts = num_experts
         self.num_local_experts = num_experts // self.ep_size
+        moe_type = getattr(gpc.config.model, "moe_type", None)
         if moe_type is None or moe_type == "GShard":
             self.moe_layer = GShardMOELayer(
                 hidden_size,
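
The dispatch now keys off the global config instead of a constructor argument; below is a small standalone sketch of that lookup-and-fallback pattern, with SimpleNamespace standing in for gpc.config.model and print calls standing in for the real GShardMOELayer construction.

from types import SimpleNamespace

# Stand-in for gpc.config.model; an older config without moe_type still works
# because getattr falls back to None.
model_cfg = SimpleNamespace(num_experts=8)

moe_type = getattr(model_cfg, "moe_type", None)

if moe_type is None or moe_type == "GShard":
    print("default path: build GShardMOELayer")
else:
    print("build alternative MoE layer:", moe_type)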