Browse Source

removed noisy function during evaluation of MoE router (#419)

pull/421/head
HELSON 3 years ago committed by GitHub
parent
commit
3f70a2b12f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 4
      colossalai/nn/layer/moe/layers.py

4
colossalai/nn/layer/moe/layers.py

@ -52,7 +52,7 @@ class Top1Router(nn.Module):
def forward(self, inputs: torch.Tensor, cuda_mode: bool = False):
if self.noisy_func is not None:
if self.noisy_func is not None and self.training:
inputs_noisy = self.noisy_func(inputs)
else:
inputs_noisy = inputs
@ -126,7 +126,7 @@ class Top2Router(nn.Module):
def forward(self, inputs: torch.Tensor, cuda_mode: bool = False):
# inputs: [s, h]
if self.noisy_func is not None:
if self.noisy_func is not None and self.training:
inputs = self.noisy_func(inputs)
logits = autocast_softmax(inputs, dim=-1) # logits: [s, e]

Loading…
Cancel
Save