diff --git a/internlm/moe/base_moe.py b/internlm/moe/base_moe.py index 5f19e07..c97ec83 100644 --- a/internlm/moe/base_moe.py +++ b/internlm/moe/base_moe.py @@ -31,5 +31,5 @@ class BaseMoELayer(Base): self.ep_group = ep_group self.ep_size = ep_size self.num_local_experts = num_local_experts - self.l_aux = torch.zeros(1, device=torch.cuda.current_device()) + self.l_aux = torch.tensor(0.0, device=torch.cuda.current_device(), dtype=gpc.config.model.get("dtype")) self.exp_counts = None