fix: fix the bug to do bcast in a stream (#294)

* fix: fix the bug to do bcast in a stream

* fix: fix the bug to do bcast in a stream

---------

Co-authored-by: yingtongxiong <974106207@qq.com>
pull/293/head
Sun Peng 2023-09-08 13:53:40 +08:00 committed by GitHub
parent 0c276d8de2
commit 0423426c4c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 5 additions and 1 deletion

View File

@@ -135,6 +135,8 @@ class HybridZeroOptimizer(BaseOptimizer):
         # self._overlap_communication = overlap_communication
         self._reduce_bucket_size = reduce_bucket_size

+        self._comm_bcast_stream = torch.cuda.Stream()
+
         # gradient scaler
         self.grad_scaler = DynamicGradScaler(
             initial_scale=initial_scale,
@@ -653,6 +655,8 @@ class HybridZeroOptimizer(BaseOptimizer):
                 fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id]
                 fp16_param.data.copy_(fp32_param)

+        torch.cuda.synchronize()
+        with torch.cuda.stream(self._comm_bcast_stream):
             self.broadcast_params()

         timer("step").stop()