fix the bug so that the sequence parallel norm is all-reduced when overlap is False

pull/534/head
yingtongxiong 2023-12-11 17:36:33 +08:00
parent 2dbbab7418
commit c7db6db066
1 changed file with 9 additions and 9 deletions

@@ -218,10 +218,7 @@ class HybridZeroOptimizer(BaseOptimizer):
         # flag used to skip unnecessary gradient reduce operation when gradient accumulation is enabled.
         self.skip_grad_reduce = False

-        # reduction hook is only used if overlapping communication
-        # if it is stage 1 without overlapping, no hook will be attached
-        if self._overlap_sync_grad:
-            self._attach_reduction_hook()
+        self._attach_reduction_hook()

     @property
     def zero_local_rank(self):
@@ -320,12 +317,15 @@ class HybridZeroOptimizer(BaseOptimizer):
             # if sequence_parallel is True,
             # the grad of norm should be all-reduce across the tp process group
-            if gpc.config.parallel.sequence_parallel is True:
-                if hasattr(param, IS_SEQUENCE_PARALLEL) and getattr(param, IS_SEQUENCE_PARALLEL) is True:
-                    accum_grad_obj_sp = get_grad_accumulate_object(param)
-                    accum_grad_obj_sp.register_hook(reduce_grad_hook_sp)
+            if (
+                gpc.config.parallel.sequence_parallel is True
+                and hasattr(param, IS_SEQUENCE_PARALLEL)
+                and getattr(param, IS_SEQUENCE_PARALLEL) is True
+            ):
+                accum_grad_obj.register_hook(reduce_grad_hook_sp)

-            accum_grad_obj.register_hook(reduce_grad_hook)
+            if self._overlap_sync_grad:
+                accum_grad_obj.register_hook(reduce_grad_hook)

         _define_and_attach(param, reduce_rank)
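
For context, below is a minimal sketch of the hook pattern this diff relies on: the parameter's AccumulateGrad node is obtained through a trivial `expand_as` view, and a hook registered on that node all-reduces the gradient of sequence-parallel parameters (e.g. norm weights) across the tensor-parallel group. This is not the repository's implementation; the attribute string behind IS_SEQUENCE_PARALLEL, the `tp_group` argument, and the helper names here are illustrative assumptions, and the real optimizer buckets and schedules these reductions rather than calling all_reduce directly in the hook.

```python
# Minimal sketch, assuming a PyTorch setup with torch.distributed initialized.
import torch
import torch.distributed as dist

IS_SEQUENCE_PARALLEL = "is_sequence_parallel"  # assumed attribute name


def get_grad_accumulate_object(param: torch.Tensor):
    # The AccumulateGrad node of a leaf tensor is reachable via the grad_fn of
    # a trivial view; hooks on this node fire once the gradient for `param`
    # has been accumulated into param.grad. Requires param.requires_grad=True.
    return param.expand_as(param).grad_fn.next_functions[0][0]


def attach_sp_allreduce_hook(param: torch.Tensor, tp_group) -> None:
    """Register a hook that all-reduces the grad of a sequence-parallel
    parameter (e.g. a norm weight) across the tensor-parallel group.

    The hook is registered unconditionally, so it also runs when
    gradient/communication overlap is disabled, which is the point of the fix.
    """
    if not getattr(param, IS_SEQUENCE_PARALLEL, False):
        return

    accum_grad_obj = get_grad_accumulate_object(param)

    def reduce_grad_hook_sp(*args):
        # All-reduce the accumulated gradient across the tensor-parallel ranks.
        if param.grad is not None:
            dist.all_reduce(param.grad, group=tp_group)

    accum_grad_obj.register_hook(reduce_grad_hook_sp)
```

In the optimizer above, the same idea is split in two: `reduce_grad_hook_sp` is attached whenever sequence parallelism is enabled, while the ordinary `reduce_grad_hook` stays gated by `self._overlap_sync_grad`.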