fix the merge

pull/6023/head
wangbluo 2024-08-19 08:07:51 +00:00
parent 1a5847e6d1
commit 3353042525
1 changed files with 4 additions and 3 deletions

View File

@ -588,9 +588,10 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
self.pg_to_tensor_bucket[pg].all_gather(pg, fp8_communication=self._fp8_communication)
self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param)
self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id]
for pg, tensor_bucket in self.pg_to_tensor_bucket.items():
if not tensor_bucket.is_empty():
tensor_bucket.all_gather(pg)
if not self._overlap_allgather:
for pg, tensor_bucket in self.pg_to_tensor_bucket.items():
if not tensor_bucket.is_empty():
tensor_bucket.all_gather(pg, fp8_communication=self._fp8_communication)
def _compute_grad_norm(self, dp_pg: ProcessGroup, gradients: List[Tensor], norm_type: int = 2) -> float:
r"""