removed noisy function during evaluation of MoE router (#419)

pull/421/head
HELSON 2022-03-15 12:06:09 +08:00 committed by GitHub
parent adebb3e041
commit 3f70a2b12f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 2 additions and 2 deletions

View File

@ -52,7 +52,7 @@ class Top1Router(nn.Module):
def forward(self, inputs: torch.Tensor, cuda_mode: bool = False):
if self.noisy_func is not None:
if self.noisy_func is not None and self.training:
inputs_noisy = self.noisy_func(inputs)
else:
inputs_noisy = inputs
@ -126,7 +126,7 @@ class Top2Router(nn.Module):
def forward(self, inputs: torch.Tensor, cuda_mode: bool = False):
# inputs: [s, h]
if self.noisy_func is not None:
if self.noisy_func is not None and self.training:
inputs = self.noisy_func(inputs)
logits = autocast_softmax(inputs, dim=-1) # logits: [s, e]