change the order of dp and sp all-reduce

pull/443/head
yingtongxiong 2023-10-25 13:27:47 +08:00
parent 1655a90f34
commit 1bc3c33b75
1 changed files with 2 additions and 3 deletions

View File

@ -318,14 +318,13 @@ class HybridZeroOptimizer(BaseOptimizer):
if self.skip_grad_reduce is False:
reduction_sp_func()
accum_grad_obj.register_hook(reduce_grad_hook)
# if sequence_parallel is True, the grad of norm should be all-reduce across the tp process group
if gpc.config.parallel.sequence_parallel is True:
if hasattr(param, IS_SEQUENCE_PARALLEL) and getattr(param, IS_SEQUENCE_PARALLEL) is True:
accum_grad_obj_sp = get_grad_accumulate_object(param)
accum_grad_obj_sp.register_hook(reduce_grad_hook_sp)
accum_grad_obj.register_hook(reduce_grad_hook)
_define_and_attach(param, reduce_rank)