diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py
index 63d2bfa..700d0dc 100644
--- a/internlm/solver/optimizer/hybrid_zero_optim.py
+++ b/internlm/solver/optimizer/hybrid_zero_optim.py
@@ -164,9 +164,6 @@ class HybridZeroOptimizer(BaseOptimizer):
         self._param_bcast_sync_handler = param_bcast_sync_handler
         if self._overlap_sync_param:
             assert self._param_bcast_sync_handler is not None
-            self._broadcast_comm_stream = torch.cuda.Stream()
-        else:
-            self._broadcast_comm_stream = torch.cuda.current_stream()
 
         # iterate over the param group in the optimizer
         # partition these param groups for data parallel training
@@ -648,8 +645,7 @@ class HybridZeroOptimizer(BaseOptimizer):
                 fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id]
                 fp16_param.data.copy_(fp32_param)
 
-        with torch.cuda.stream(self._broadcast_comm_stream):
-            self.broadcast_params()
+        self.broadcast_params()
 
         timer("step").stop()
 
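
For context, a minimal sketch of the CUDA stream semantics this diff changes, assuming a toy single-process setup. `broadcast_params` is the method named in the diff; `side_stream` and the placeholder body are illustrative stand-ins, not InternLM code. The old path enqueued the broadcast on a dedicated side stream, which can overlap with default-stream work but only stays correct with explicit cross-stream synchronization; the new path runs the broadcast inline on the current stream, so ordering with subsequent kernels is automatic.

```python
import torch

def broadcast_params():
    """Placeholder for the optimizer's real dist.broadcast(...) calls."""

if torch.cuda.is_available():
    # Before: broadcast enqueued on a dedicated side stream. This allows
    # overlap with the default stream, but later reads of the broadcast
    # parameters are only safe after an explicit cross-stream sync.
    side_stream = torch.cuda.Stream()
    with torch.cuda.stream(side_stream):
        broadcast_params()
    torch.cuda.current_stream().wait_stream(side_stream)  # required for correctness

    # After: broadcast runs on the current stream, so it is automatically
    # ordered before any subsequent kernels and needs no extra sync.
    broadcast_params()
```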