diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py index e0ed687..70c63a0 100644 --- a/internlm/solver/optimizer/hybrid_zero_optim.py +++ b/internlm/solver/optimizer/hybrid_zero_optim.py @@ -135,6 +135,8 @@ class HybridZeroOptimizer(BaseOptimizer): # self._overlap_communication = overlap_communication self._reduce_bucket_size = reduce_bucket_size + self._comm_bcast_stream = torch.cuda.Stream() + # gradient scaler self.grad_scaler = DynamicGradScaler( initial_scale=initial_scale, @@ -653,7 +655,9 @@ class HybridZeroOptimizer(BaseOptimizer): fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id] fp16_param.data.copy_(fp32_param) - self.broadcast_params() + torch.cuda.synchronize() + with torch.cuda.stream(self._comm_bcast_stream): + self.broadcast_params() timer("step").stop()