refactor code

pull/182/head
zhanglei 2023-09-22 12:30:02 +08:00
parent 80972ff314
commit 548d1bd7af
1 changed file with 5 additions and 3 deletions

@@ -695,9 +695,11 @@ class HybridZeroOptimizer(BaseOptimizer):
 # Parameters shared within a TP group, such as norm and moe gate, have precision inconsistency in gradients.
 # Therefore, it is recommended to synchronize gradients within the TP group to eliminate accumulated errors.
-is_tp_shared_params = (self._is_norm_group(self.optim.param_groups[group_id])
-                       or self._is_gate_group(self.optim.param_groups[group_id]))
-if is_tp_shared_params:
+is_tp_sync_groups = (
+    self._is_norm_group(self.optim.param_groups[group_id]),
+    self._is_gate_group(self.optim.param_groups[group_id]),
+)
+if any(is_tp_sync_groups):
     dist.all_reduce(
         flat_fp32_avg_grads,
         op=dist.ReduceOp.AVG,
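
For context: the refactor replaces the single boolean flag is_tp_shared_params with a tuple of per-group checks, is_tp_sync_groups, and gates the all-reduce on any() of them, presumably so further group predicates can be appended later. Below is a minimal sketch of the synchronization pattern itself, assuming a hypothetical helper name sync_tp_shared_grads and a tp_group process-group handle; neither name appears in this diff.

import torch
import torch.distributed as dist

def sync_tp_shared_grads(flat_fp32_avg_grads: torch.Tensor,
                         tp_group: dist.ProcessGroup) -> None:
    # Hypothetical helper (names are illustrative, not from the PR).
    # Ranks in a tensor-parallel group hold replicas of shared parameters
    # (e.g. norm weights, the MoE gate), so their gradients are averaged
    # to cancel accumulated floating-point divergence between replicas.
    # ReduceOp.AVG sums across ranks and divides by the group size;
    # it requires a backend that supports AVG, e.g. NCCL in recent PyTorch.
    dist.all_reduce(flat_fp32_avg_grads, op=dist.ReduceOp.AVG, group=tp_group)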