Merge remote-tracking branch 'origin/main' into develop

pull/275/head^2
yingtongxiong 2023-09-05 17:50:35 +08:00
commit 0fb8d4141f
1 changed file with 1 addition and 5 deletions

View File

@@ -164,9 +164,6 @@ class HybridZeroOptimizer(BaseOptimizer):
self._param_bcast_sync_handler = param_bcast_sync_handler
if self._overlap_sync_param:
assert self._param_bcast_sync_handler is not None
-self._broadcast_comm_stream = torch.cuda.Stream()
-else:
-self._broadcast_comm_stream = torch.cuda.current_stream()
# iterate over the param group in the optimizer
# partition these param groups for data parallel training
@@ -648,8 +645,7 @@ class HybridZeroOptimizer(BaseOptimizer):
fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id]
fp16_param.data.copy_(fp32_param)
-with torch.cuda.stream(self._broadcast_comm_stream):
-self.broadcast_params()
+self.broadcast_params()
timer("step").stop()