Fix groups_norms computation in hybrid_zero_optim: skip norm computation for MoE parameter groups

pull/182/head
Wenwen Qu 2023-08-31 18:46:13 +08:00
parent 2ad5f512b5
commit 7ca5da27e8
1 changed file with 4 additions and 1 deletion

View File

@@ -574,7 +574,10 @@ class HybridZeroOptimizer(BaseOptimizer):
         # compute norm for gradients in the before bucket
         groups_norms = []
         for group_id in range(self.num_param_groups):
-            groups_norms.append(self._compute_norm_with_stage(group_id=group_id))
+            if self._is_moe_group(self.optim.param_groups[group_id]):
+                groups_norms.append([])
+            else:
+                groups_norms.append(self._compute_norm_with_stage(group_id=group_id))
 
         # clear reduced grads
         if self._overlap_sync_grad: