add moe_type to model config

pull/567/head
Wenwen Qu 2024-01-09 15:56:59 +08:00
parent dcfdab6aaf
commit c423f1159b
2 changed files with 3 additions and 1 deletions

View File

@ -527,6 +527,7 @@ def build_model_with_moe_cfg(
moe_drop_tokens: bool = True, # pylint: disable=W0613
moe_use_rts: bool = True, # pylint: disable=W0613
moe_use_residual: bool = False, # pylint: disable=W0613
moe_type: str = None, # pylint: disable=W0613
):
"""
Build model with config.

View File

@ -39,7 +39,6 @@ class MoE(torch.nn.Module):
ep_size=1,
device=None,
dtype=None,
moe_type: str = None,
):
super().__init__()
@ -51,6 +50,8 @@ class MoE(torch.nn.Module):
self.num_experts = num_experts
self.num_local_experts = num_experts // self.ep_size
moe_type = getattr(gpc.config.model, "moe_type", None)
if moe_type is None or moe_type == "GShard":
self.moe_layer = GShardMOELayer(
hidden_size,