From 7f61505fa014d909c3382576cef808839ba080be Mon Sep 17 00:00:00 2001
From: Sun Peng
Date: Tue, 5 Sep 2023 17:47:50 +0800
Subject: [PATCH] fix/broadcast should not in commu stream (#276)

* fix/broadcast should not be in comm stream

* fix/broadcast should not be in comm stream

---------

Co-authored-by: yingtongxiong <974106207@qq.com>
---
 internlm/solver/optimizer/hybrid_zero_optim.py | 6 +-----
 1 file changed, 1 insertion(+), 5 deletions(-)

diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py
index 63d2bfa..700d0dc 100644
--- a/internlm/solver/optimizer/hybrid_zero_optim.py
+++ b/internlm/solver/optimizer/hybrid_zero_optim.py
@@ -164,9 +164,6 @@ class HybridZeroOptimizer(BaseOptimizer):
         self._param_bcast_sync_handler = param_bcast_sync_handler
         if self._overlap_sync_param:
             assert self._param_bcast_sync_handler is not None
-            self._broadcast_comm_stream = torch.cuda.Stream()
-        else:
-            self._broadcast_comm_stream = torch.cuda.current_stream()
 
         # iterate over the param group in the optimizer
         # partition these param groups for data parallel training
@@ -648,8 +645,7 @@ class HybridZeroOptimizer(BaseOptimizer):
             fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id]
             fp16_param.data.copy_(fp32_param)
 
-        with torch.cuda.stream(self._broadcast_comm_stream):
-            self.broadcast_params()
+        self.broadcast_params()
 
         timer("step").stop()
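
For context, the ordering hazard this patch removes can be sketched as follows. This is a minimal illustration, not code from the patch: it assumes an already-initialized torch.distributed process group, and the names flat_fp16, flat_fp32, unsafe_broadcast, and safe_broadcast are hypothetical stand-ins for the optimizer's flattened parameter shards and step logic.

    import torch
    import torch.distributed as dist

    def unsafe_broadcast(flat_fp16: torch.Tensor,
                         flat_fp32: torch.Tensor,
                         comm_stream: torch.cuda.Stream) -> None:
        # The optimizer's update kernels are queued on the current
        # (default) stream, e.g. copying master fp32 weights into the
        # fp16 shard:
        flat_fp16.copy_(flat_fp32)  # async device-to-device copy

        # Pattern removed by the patch: the side stream has no
        # dependency on the default stream, so the broadcast may read
        # flat_fp16 before the copy above has finished.
        with torch.cuda.stream(comm_stream):
            dist.broadcast(flat_fp16, src=0)

    def safe_broadcast(flat_fp16: torch.Tensor,
                       flat_fp32: torch.Tensor) -> None:
        flat_fp16.copy_(flat_fp32)
        # Issuing the broadcast on the same (current) stream orders it
        # after the copy; this matches what the patch restores.
        dist.broadcast(flat_fp16, src=0)

A side stream can in general be made safe with comm_stream.wait_stream(torch.cuda.current_stream()) before launching the collective, but the diff above takes the simpler route: broadcast_params() now runs on the current stream, where it is naturally ordered after the preceding fp32-to-fp16 copies.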