mirror of https://github.com/InternLM/InternLM
add moe_type to model config
parent dcfdab6aaf
commit c423f1159b
@@ -527,6 +527,7 @@ def build_model_with_moe_cfg(
     moe_drop_tokens: bool = True,  # pylint: disable=W0613
     moe_use_rts: bool = True,  # pylint: disable=W0613
     moe_use_residual: bool = False,  # pylint: disable=W0613
+    moe_type: str = None,  # pylint: disable=W0613
 ):
     """
     Build model with config.

@@ -39,7 +39,6 @@ class MoE(torch.nn.Module):
         ep_size=1,
         device=None,
         dtype=None,
-        moe_type: str = None,
     ):

         super().__init__()

@@ -51,6 +50,8 @@ class MoE(torch.nn.Module):
         self.num_experts = num_experts
         self.num_local_experts = num_experts // self.ep_size

+        moe_type = getattr(gpc.config.model, "moe_type", None)
+
         if moe_type is None or moe_type == "GShard":
             self.moe_layer = GShardMOELayer(
                 hidden_size,
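Taken together, the hunks move moe_type from an explicit MoE constructor argument to a field looked up on the model config at runtime via getattr(gpc.config.model, "moe_type", None), with GShardMOELayer used when the field is absent or set to "GShard". The fragment below is a minimal sketch of how a model config might carry the new field; apart from moe_type and the keywords visible in the diff (moe_drop_tokens, moe_use_rts, moe_use_residual), the field names and values are illustrative assumptions, not the repository's exact schema.

# Sketch of a model config fragment (assumed layout; only the keywords shown
# in this diff are grounded, the surrounding fields are illustrative).
model = dict(
    num_experts=8,            # assumed field enabling the MoE path
    moe_drop_tokens=True,     # keyword visible in build_model_with_moe_cfg
    moe_use_rts=True,
    moe_use_residual=False,
    moe_type="GShard",        # new field; omitting it (None) also selects GShardMOELayer
)

With this change, callers no longer pass moe_type into MoE(...) directly; the class reads it from gpc.config.model, so existing configs without the field keep working because getattr falls back to None.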