mirror of https://github.com/InternLM/InternLM
fix: fix the bug to do bcast in a stream (#294)
* fix: fix the bug to do bcast in a stream

Co-authored-by: yingtongxiong <974106207@qq.com>
parent 0c276d8de2
commit 0423426c4c
@@ -135,6 +135,8 @@ class HybridZeroOptimizer(BaseOptimizer):
         # self._overlap_communication = overlap_communication
         self._reduce_bucket_size = reduce_bucket_size
 
+        self._comm_bcast_stream = torch.cuda.Stream()
+
         # gradient scaler
         self.grad_scaler = DynamicGradScaler(
             initial_scale=initial_scale,
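The first hunk gives the optimizer a dedicated CUDA stream for the parameter broadcast. As a minimal sketch of why a side stream helps (illustrative code, not from InternLM; the tensor names are made up), kernels queued on a separate stream can execute concurrently with kernels queued on the default stream:

import torch

comm_stream = torch.cuda.Stream()          # plays the role of self._comm_bcast_stream

x = torch.randn(1024, 1024, device="cuda")
y = x @ x                                  # matmul queued on the default stream
torch.cuda.synchronize()                   # drain the default stream, as the commit does before the bcast
with torch.cuda.stream(comm_stream):
    z = x.mul(2.0)                         # queued on comm_stream; later default-stream work may overlap it
comm_stream.synchronize()                  # join the side stream before consuming z
print(z.sum().item())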
|
@@ -653,6 +655,8 @@ class HybridZeroOptimizer(BaseOptimizer):
             fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id]
             fp16_param.data.copy_(fp32_param)
 
-        self.broadcast_params()
+        torch.cuda.synchronize()
+        with torch.cuda.stream(self._comm_bcast_stream):
+            self.broadcast_params()
 
         timer("step").stop()
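The second hunk is the fix itself: torch.cuda.synchronize() drains the default stream before the broadcast is launched on the dedicated stream, so the broadcast cannot read fp16 parameters that the preceding copy_ has not finished writing. A minimal sketch of the same ordering, assuming an initialized torch.distributed process group on the NCCL backend (the helper name and arguments are hypothetical, not the optimizer's API):

import torch
import torch.distributed as dist

def broadcast_in_side_stream(params, bcast_stream, src_rank=0):
    # Hypothetical helper mirroring the pattern in the hunk above.
    # Wait for the fp32 -> fp16 copies queued on the default stream to
    # finish before the side stream reads the fp16 parameters.
    torch.cuda.synchronize()
    with torch.cuda.stream(bcast_stream):
        for p in params:
            # The NCCL broadcast is issued on bcast_stream, so work queued
            # on the default stream afterwards can overlap with it.
            dist.broadcast(p.data, src=src_rank)

Note that any later kernel on the default stream that consumes the broadcast parameters must still wait on the side stream (for example via torch.cuda.current_stream().wait_stream(bcast_stream)) before reading them.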