@@ -589,9 +589,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param)
            self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id]
        if not self._overlap_allgather:
            for pg, tensor_bucket in self.pg_to_tensor_bucket.items():
                if not tensor_bucket.is_empty():
                    tensor_bucket.all_gather(pg, fp8_communication=self._fp8_communication)
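For orientation: the tail of `step()` above packs each rank's updated parameter shard into a per-process-group bucket and, unless `_overlap_allgather` defers the gather into later compute, flushes every non-empty bucket with a single collective. Below is a minimal sketch of such a bucket, assuming `torch.distributed`; `MiniTensorBucket` and its internals are illustrative stand-ins, not ColossalAI's actual `TensorBucket`, and the fp8 path is stubbed out.

import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

class MiniTensorBucket:
    # Hypothetical stand-in for the bucket object used in the diff above.
    def __init__(self):
        self._shards = []       # local shards queued for gathering (same dtype assumed)
        self._write_back = []   # full-size working params to fill after the gather

    def is_empty(self) -> bool:
        return not self._shards

    def add_to_bucket(self, tensor, write_back_tensor):
        self._shards.append(tensor)
        self._write_back.append(write_back_tensor)

    def all_gather(self, group, fp8_communication=False):
        # fp8_communication would cast the payload before the collective;
        # omitted here to keep the sketch minimal.
        world_size = dist.get_world_size(group)
        flat = _flatten_dense_tensors(self._shards)
        out = torch.empty(flat.numel() * world_size, dtype=flat.dtype, device=flat.device)
        # One collective for the whole bucket instead of one per parameter.
        dist.all_gather_into_tensor(out, flat, group=group)
        # Each rank's chunk holds that rank's shard of every tensor, in order.
        per_rank = [_unflatten_dense_tensors(chunk, self._shards) for chunk in out.chunk(world_size)]
        for i, dest in enumerate(self._write_back):
            full = torch.cat([shards[i].flatten() for shards in per_rank])
            # Slice in case sharding padded the last rank's piece.
            dest.copy_(full[: dest.numel()].view_as(dest))
        self._shards.clear()
        self._write_back.clear()

Flattening the shards first means one `all_gather_into_tensor` call per bucket rather than one collective per parameter, which is the point of bucketing here; `write_back_tensor` records where each gathered result should land.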

    def _compute_grad_norm(self, dp_pg: ProcessGroup, gradients: List[Tensor], norm_type: int = 2) -> float:
        r"""
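`_compute_grad_norm` takes the data-parallel group because each rank holds only its gradient shards, so the local norm has to be combined across `dp_pg` before it describes the full gradient. A minimal sketch of that reduction, assuming `torch.distributed`; the function name, the inf-norm branch, and the reduction details are assumptions, not the repository's implementation.

import torch
import torch.distributed as dist

def compute_grad_norm_sketch(dp_pg, gradients, norm_type=2.0):
    # Each rank sees only its shard of every gradient, so norms are
    # computed locally and then reduced across the data-parallel group.
    if not gradients:
        return 0.0
    device = gradients[0].device
    if norm_type == float("inf"):
        # Max-norm: the global max is the max of per-rank maxima.
        total = max(g.abs().max() for g in gradients).detach().clone().to(device)
        dist.all_reduce(total, op=dist.ReduceOp.MAX, group=dp_pg)
        return total.item()
    # p-norm: sum the p-th powers locally, all-reduce, then take the root.
    local = sum(g.norm(norm_type) ** norm_type for g in gradients)
    total = torch.tensor(float(local), device=device)
    dist.all_reduce(total, op=dist.ReduceOp.SUM, group=dp_pg)
    return total.item() ** (1.0 / norm_type)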