mirror of https://github.com/InternLM/InternLM
fix/broadcast should not in commu stream (#276)
* fix/brocast should not in commu stream

* fix/brocast should not in commu stream

---------

Co-authored-by: yingtongxiong <974106207@qq.com>
parent 5238f15e2d
commit 7f61505fa0
@@ -164,9 +164,6 @@ class HybridZeroOptimizer(BaseOptimizer):
         self._param_bcast_sync_handler = param_bcast_sync_handler
         if self._overlap_sync_param:
             assert self._param_bcast_sync_handler is not None
-            self._broadcast_comm_stream = torch.cuda.Stream()
-        else:
-            self._broadcast_comm_stream = torch.cuda.current_stream()

         # iterate over the param group in the optimizer
         # partition these param groups for data parallel training
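The first hunk drops the dedicated broadcast stream: previously, when overlap_sync_param was enabled, a separate torch.cuda.Stream() was created for parameter broadcast, with the current stream used otherwise. A side stream is only safe to use if it is explicitly ordered against the stream that produces the data it reads; the lines removed here show no such ordering. A minimal sketch of that requirement, assuming a CUDA device is available (the tensor and the in-place update are illustrative stand-ins; only the torch.cuda stream APIs are real):

import torch

# The removed pattern: a dedicated stream for communication work.
comm_stream = torch.cuda.Stream()

x = torch.randn(1024, device="cuda")
x.mul_(2.0)  # producer work, enqueued on the default stream

# Safe use of a side stream requires ordering it after the producer...
comm_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(comm_stream):
    y = x.sum()  # stand-in for a collective that reads `x`

# ...and ordering the default stream after the consumer before `x`
# (or its memory) is modified or freed again.
torch.cuda.current_stream().wait_stream(comm_stream)

Absent those waits, dropping the side stream altogether, as this commit does, is the simpler correct fix.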
@@ -648,8 +645,7 @@ class HybridZeroOptimizer(BaseOptimizer):
             fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id]
             fp16_param.data.copy_(fp32_param)

-        with torch.cuda.stream(self._broadcast_comm_stream):
-            self.broadcast_params()
+        self.broadcast_params()

         timer("step").stop()
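The second hunk is where the hazard presumably bit: fp16_param.data.copy_(fp32_param) is enqueued on the default stream, while the old code launched broadcast_params() inside self._broadcast_comm_stream without making that stream wait for the copy, so the collective could transmit stale or half-written fp16 data. A hedged sketch of the before/after pattern, with hypothetical stand-ins fp16_param, fp32_param, and group (not the repo's actual broadcast_params implementation):

import torch
import torch.distributed as dist

def broadcast_after_copy_unsafe(fp16_param, fp32_param, comm_stream, group):
    # The copy is enqueued on the default stream...
    fp16_param.data.copy_(fp32_param)
    # ...but the broadcast runs on a side stream that never waits for it,
    # so it may send stale or partially written fp16 data.
    with torch.cuda.stream(comm_stream):
        dist.broadcast(fp16_param, src=0, group=group)

def broadcast_after_copy_safe(fp16_param, fp32_param, group):
    # After the fix: both operations sit on the current stream, and CUDA's
    # in-order execution within a stream guarantees the broadcast sees the
    # completed copy.
    fp16_param.data.copy_(fp32_param)
    dist.broadcast(fp16_param, src=0, group=group)

Keeping the copy and the broadcast on the same stream trades away compute/communication overlap for correctness: per-stream in-order execution alone guarantees the broadcast reads the finished copy.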