fix/broadcast should not be in the comm stream (#276)

* fix/broadcast should not be in the comm stream

---------

Co-authored-by: yingtongxiong <974106207@qq.com>
Sun Peng 2023-09-05 17:47:50 +08:00 committed by GitHub
parent 5238f15e2d
commit 7f61505fa0
1 changed file with 1 addition and 5 deletions


@@ -164,9 +164,6 @@ class HybridZeroOptimizer(BaseOptimizer):
         self._param_bcast_sync_handler = param_bcast_sync_handler
         if self._overlap_sync_param:
             assert self._param_bcast_sync_handler is not None
-            self._broadcast_comm_stream = torch.cuda.Stream()
-        else:
-            self._broadcast_comm_stream = torch.cuda.current_stream()
 
         # iterate over the param group in the optimizer
         # partition these param groups for data parallel training
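
Why these deletions matter: a freshly created torch.cuda.Stream() has no ordering relationship with the stream on which the optimizer's other kernels were queued, so work launched on it can race with pending writes. A minimal sketch of the hazard, and of the explicit synchronization a side stream would otherwise require (illustrative only, not code from this repo):

import torch

x = torch.zeros(1 << 20, device="cuda")
x.add_(1)                        # queued on the current (default) stream

side = torch.cuda.Stream()
with torch.cuda.stream(side):
    # UNSAFE: nothing orders `side` after the default stream, so this
    # multiply may read x before add_ has finished writing it.
    y = x * 2

# Safe side-stream pattern: create explicit ordering edges.
side.wait_stream(torch.cuda.current_stream())   # side waits for pending work
with torch.cuda.stream(side):
    y = x * 2
torch.cuda.current_stream().wait_stream(side)   # later work waits for side
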
@@ -648,8 +645,7 @@ class HybridZeroOptimizer(BaseOptimizer):
             fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id]
             fp16_param.data.copy_(fp32_param)
 
-        with torch.cuda.stream(self._broadcast_comm_stream):
-            self.broadcast_params()
+        self.broadcast_params()
 
         timer("step").stop()