diff --git a/colossalai/nn/layer/moe/layers.py b/colossalai/nn/layer/moe/layers.py index 71d54c298..49a9645bc 100644 --- a/colossalai/nn/layer/moe/layers.py +++ b/colossalai/nn/layer/moe/layers.py @@ -52,7 +52,7 @@ class Top1Router(nn.Module): def forward(self, inputs: torch.Tensor, cuda_mode: bool = False): - if self.noisy_func is not None: + if self.noisy_func is not None and self.training: inputs_noisy = self.noisy_func(inputs) else: inputs_noisy = inputs @@ -126,7 +126,7 @@ class Top2Router(nn.Module): def forward(self, inputs: torch.Tensor, cuda_mode: bool = False): # inputs: [s, h] - if self.noisy_func is not None: + if self.noisy_func is not None and self.training: inputs = self.noisy_func(inputs) logits = autocast_softmax(inputs, dim=-1) # logits: [s, e]