mirror of https://github.com/InternLM/InternLM
add moe_type assert
parent
7cec7e985f
commit
13f3eeb994
|
@ -44,7 +44,7 @@ class MoE(torch.nn.Module):
|
|||
|
||||
super().__init__()
|
||||
|
||||
moe_impl = self.get_moe(getattr(gpc.config.model, "moe_type", None))
|
||||
moe_impl = self.get_moe_impl(gpc.config.model.moe_type)
|
||||
|
||||
if not hasattr(gpc.config, "moe"):
|
||||
gpc.config.moe = dict()
|
||||
|
@ -74,9 +74,11 @@ class MoE(torch.nn.Module):
|
|||
# coefficient is used for weighted sum of the output of expert and residual mlp
|
||||
self.coefficient = torch.nn.Linear(hidden_size, 2)
|
||||
|
||||
def get_moe(self, moe_type):
|
||||
def get_moe_impl(self, moe_type):
|
||||
if moe_type is None or moe_type == "GShard":
|
||||
return GShardMOELayer
|
||||
else:
|
||||
assert False, "unsupported moe type"
|
||||
|
||||
def forward(self, hidden_states, used_token=None):
|
||||
"""MoE forward
|
||||
|
|
Loading…
Reference in New Issue