mirror of https://github.com/InternLM/InternLM
add moe_type assert
parent 7cec7e985f
commit 13f3eeb994
@@ -44,7 +44,7 @@ class MoE(torch.nn.Module):
         super().__init__()
 
-        moe_impl = self.get_moe(getattr(gpc.config.model, "moe_type", None))
+        moe_impl = self.get_moe_impl(gpc.config.model.moe_type)
 
         if not hasattr(gpc.config, "moe"):
             gpc.config.moe = dict()
 
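With the getattr fallback gone, the call site reads gpc.config.model.moe_type directly, so the model config is now assumed to always define moe_type. A hypothetical InternLM-style config fragment under that assumption (only moe_type relates to this commit; the other fields are illustrative placeholders):

model = dict(
    hidden_size=4096,   # placeholder value
    num_layers=32,      # placeholder value
    moe_type="GShard",  # any value other than "GShard" or None is rejected by get_moe_impl below
)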
@@ -74,9 +74,11 @@ class MoE(torch.nn.Module):
             # coefficient is used for weighted sum of the output of expert and residual mlp
             self.coefficient = torch.nn.Linear(hidden_size, 2)
 
-    def get_moe(self, moe_type):
+    def get_moe_impl(self, moe_type):
         if moe_type is None or moe_type == "GShard":
             return GShardMOELayer
+        else:
+            assert False, "unsupported moe type"
 
     def forward(self, hidden_states, used_token=None):
         """MoE forward
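A minimal standalone sketch of the dispatch behaviour the second hunk pins down, assuming only the GShard path exists; GShardMOELayer here is a stand-in class, not the real layer:

class GShardMOELayer:
    """Stand-in for the real GShard MoE layer implementation."""


def get_moe_impl(moe_type):
    # Mirrors the dispatch in the diff: "GShard" (or an unset moe_type)
    # maps to GShardMOELayer, anything else fails fast.
    if moe_type is None or moe_type == "GShard":
        return GShardMOELayer
    else:
        assert False, "unsupported moe type"


assert get_moe_impl("GShard") is GShardMOELayer
assert get_moe_impl(None) is GShardMOELayer
# get_moe_impl("Switch")  # AssertionError: unsupported moe type

Note that assert statements are stripped when Python runs with -O, so this guard disappears if assertions are disabled; raising a ValueError would be the stricter alternative.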