add moe_type assert

pull/567/head
Wenwen Qu 2024-01-10 17:13:21 +08:00
parent 7cec7e985f
commit 13f3eeb994
1 changed file with 4 additions and 2 deletions


@@ -44,7 +44,7 @@ class MoE(torch.nn.Module):
         super().__init__()
-        moe_impl = self.get_moe(getattr(gpc.config.model, "moe_type", None))
+        moe_impl = self.get_moe_impl(gpc.config.model.moe_type)
         if not hasattr(gpc.config, "moe"):
             gpc.config.moe = dict()
@@ -74,9 +74,11 @@ class MoE(torch.nn.Module):
             # coefficient is used for weighted sum of the output of expert and residual mlp
             self.coefficient = torch.nn.Linear(hidden_size, 2)
-    def get_moe(self, moe_type):
+    def get_moe_impl(self, moe_type):
         if moe_type is None or moe_type == "GShard":
             return GShardMOELayer
+        else:
+            assert False, "unsupported moe type"
     def forward(self, hidden_states, used_token=None):
         """MoE forward