mirror of https://github.com/InternLM/InternLM
add moe_type to model config
parent dcfdab6aaf
commit c423f1159b
@@ -527,6 +527,7 @@ def build_model_with_moe_cfg(
     moe_drop_tokens: bool = True,  # pylint: disable=W0613
     moe_use_rts: bool = True,  # pylint: disable=W0613
     moe_use_residual: bool = False,  # pylint: disable=W0613
+    moe_type: str = None,  # pylint: disable=W0613
 ):
     """
     Build model with config.
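A hedged sketch of why the new parameter carries a "# pylint: disable=W0613" suppression even though the function body never reads it: builders of this shape are typically invoked by splatting the whole model config as keyword arguments, so every config key must be accepted by the signature or the call fails. The builder below is an illustrative stand-in, not the repository's actual call path.

# Hypothetical, simplified stand-in for build_model_with_moe_cfg.
def build_moe_model(
    num_experts: int = 1,
    moe_drop_tokens: bool = True,  # pylint: disable=W0613
    moe_use_rts: bool = True,  # pylint: disable=W0613
    moe_use_residual: bool = False,  # pylint: disable=W0613
    moe_type: str = None,  # pylint: disable=W0613
):
    return f"MoE model with {num_experts} experts"

# Config keys are splatted into the builder; without the new moe_type
# parameter, a config that sets moe_type would raise TypeError here.
model_cfg = dict(
    num_experts=8,
    moe_drop_tokens=True,
    moe_use_rts=True,
    moe_use_residual=False,
    moe_type="GShard",
)
print(build_moe_model(**model_cfg))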
@@ -39,7 +39,6 @@ class MoE(torch.nn.Module):
         ep_size=1,
         device=None,
         dtype=None,
-        moe_type: str = None,
     ):

         super().__init__()
@@ -51,6 +50,8 @@ class MoE(torch.nn.Module):
         self.num_experts = num_experts
         self.num_local_experts = num_experts // self.ep_size

+        moe_type = getattr(gpc.config.model, "moe_type", None)
+
         if moe_type is None or moe_type == "GShard":
             self.moe_layer = GShardMOELayer(
                 hidden_size,
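A minimal sketch of the fallback semantics of the getattr lookup added above, using types.SimpleNamespace as a stand-in for gpc.config.model (the real object lives on InternLM's global process context and is not constructed here). An absent or None moe_type selects the default GShardMOELayer branch, so existing configs keep working unchanged.

from types import SimpleNamespace

# Config that predates this commit: no moe_type key at all.
cfg_old = SimpleNamespace(num_experts=8)
assert getattr(cfg_old, "moe_type", None) is None  # falls into the GShard branch

# Config that sets the new key explicitly.
cfg_new = SimpleNamespace(num_experts=8, moe_type="GShard")
assert getattr(cfg_new, "moe_type", None) == "GShard"  # same branch, chosen by name

Reading moe_type from the global config rather than from a constructor argument keeps the MoE signature stable while letting future configs select other implementations by name.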