mirror of https://github.com/InternLM/InternLM
Merge remote-tracking branch 'origin/main' into develop
commit
0fb8d4141f
|
@ -164,9 +164,6 @@ class HybridZeroOptimizer(BaseOptimizer):
|
|||
self._param_bcast_sync_handler = param_bcast_sync_handler
|
||||
if self._overlap_sync_param:
|
||||
assert self._param_bcast_sync_handler is not None
|
||||
self._broadcast_comm_stream = torch.cuda.Stream()
|
||||
else:
|
||||
self._broadcast_comm_stream = torch.cuda.current_stream()
|
||||
|
||||
# iterate over the param group in the optimizer
|
||||
# partition these param groups for data parallel training
|
||||
|
@ -648,8 +645,7 @@ class HybridZeroOptimizer(BaseOptimizer):
|
|||
fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id]
|
||||
fp16_param.data.copy_(fp32_param)
|
||||
|
||||
with torch.cuda.stream(self._broadcast_comm_stream):
|
||||
self.broadcast_params()
|
||||
self.broadcast_params()
|
||||
|
||||
timer("step").stop()
|
||||
|
||||
|
|
Loading…
Reference in New Issue