mirror of https://github.com/InternLM/InternLM
fix: fix the bug to do bcast in a stream (#294)
* fix: fix the bug to do bcast in a stream * fix: fix the bug to do bcast in a stream --------- Co-authored-by: yingtongxiong <974106207@qq.com>pull/293/head
parent
0c276d8de2
commit
0423426c4c
|
@ -135,6 +135,8 @@ class HybridZeroOptimizer(BaseOptimizer):
|
|||
# self._overlap_communication = overlap_communication
|
||||
self._reduce_bucket_size = reduce_bucket_size
|
||||
|
||||
self._comm_bcast_stream = torch.cuda.Stream()
|
||||
|
||||
# gradient scaler
|
||||
self.grad_scaler = DynamicGradScaler(
|
||||
initial_scale=initial_scale,
|
||||
|
@ -653,6 +655,8 @@ class HybridZeroOptimizer(BaseOptimizer):
|
|||
fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id]
|
||||
fp16_param.data.copy_(fp32_param)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
with torch.cuda.stream(self._comm_bcast_stream):
|
||||
self.broadcast_params()
|
||||
|
||||
timer("step").stop()
|
||||
|
|
Loading…
Reference in New Issue