refactor code

pull/567/head
Wenwen Qu 2024-01-08 16:23:53 +08:00
parent 41f8283a3e
commit f5226b5152
1 changed files with 1 additions and 1 deletions

View File

@ -31,5 +31,5 @@ class BaseMoELayer(Base):
self.ep_group = ep_group self.ep_group = ep_group
self.ep_size = ep_size self.ep_size = ep_size
self.num_local_experts = num_local_experts self.num_local_experts = num_local_experts
self.l_aux = torch.zeros(1, device=torch.cuda.current_device()) self.l_aux = torch.tensor(0.0, device=torch.cuda.current_device(), dtype=gpc.config.model.get("dtype"))
self.exp_counts = None self.exp_counts = None