diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py
index 97004eb..5004f8a 100644
--- a/internlm/solver/optimizer/hybrid_zero_optim.py
+++ b/internlm/solver/optimizer/hybrid_zero_optim.py
@@ -710,6 +710,9 @@ class HybridZeroOptimizer(BaseOptimizer):
         with torch.cuda.stream(self._comm_bcast_stream):
             self.broadcast_params()
 
+        if not self._overlap_sync_param:
+            torch.cuda.synchronize()
+
         timer("step").stop()
 
         # update gradients may not be needed here, because the sync_params function is used in initialization,