fix the bug so that the sequence parallel norm is all-reduced when overlap is False (#534)

pull/538/head
ytxiong 2023-12-12 16:22:39 +08:00 committed by GitHub
parent d904730be7
commit 432bd5ee9f
1 changed file with 9 additions and 9 deletions


@@ -219,10 +219,7 @@ class HybridZeroOptimizer(BaseOptimizer):
         # flag used to skip unnecessary gradient reduce operation when gradient accumulation is enabled.
         self.skip_grad_reduce = False
 
-        # reduction hook is only used if overlapping communication
-        # if it is stage 1 without overlapping, no hook will be attached
-        if self._overlap_sync_grad:
-            self._attach_reduction_hook()
+        self._attach_reduction_hook()
 
     @property
     def zero_local_rank(self):
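
With this hunk, _attach_reduction_hook() is called unconditionally and each hook guards itself (see the next hunk). A minimal sketch of why that matters, using assumed stand-in names rather than InternLM's actual implementation:

class HookAttachSketch:
    def __init__(self, overlap_sync_grad: bool, sequence_parallel: bool):
        self._overlap_sync_grad = overlap_sync_grad
        self._sequence_parallel = sequence_parallel
        self.registered = []

    def _attach_reduction_hook(self):
        # per-hook conditions now live inside the attach step,
        # mirroring the second hunk below
        if self._sequence_parallel:
            self.registered.append("reduce_grad_hook_sp")  # all-reduce norm grads across TP ranks
        if self._overlap_sync_grad:
            self.registered.append("reduce_grad_hook")     # overlapped grad reduction

# Under the old guard `if self._overlap_sync_grad: self._attach_reduction_hook()`,
# overlap_sync_grad=False attached nothing, so norm grads were never all-reduced:
opt = HookAttachSketch(overlap_sync_grad=False, sequence_parallel=True)
opt._attach_reduction_hook()
assert opt.registered == ["reduce_grad_hook_sp"]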
@@ -321,12 +318,15 @@ class HybridZeroOptimizer(BaseOptimizer):
             # if sequence_parallel is True,
             # the grad of norm should be all-reduce across the tp process group
-            if gpc.config.parallel.sequence_parallel is True:
-                if hasattr(param, IS_SEQUENCE_PARALLEL) and getattr(param, IS_SEQUENCE_PARALLEL) is True:
-                    accum_grad_obj_sp = get_grad_accumulate_object(param)
-                    accum_grad_obj_sp.register_hook(reduce_grad_hook_sp)
+            if (
+                gpc.config.parallel.sequence_parallel is True
+                and hasattr(param, IS_SEQUENCE_PARALLEL)
+                and getattr(param, IS_SEQUENCE_PARALLEL) is True
+            ):
+                accum_grad_obj.register_hook(reduce_grad_hook_sp)
 
-            accum_grad_obj.register_hook(reduce_grad_hook)
+            if self._overlap_sync_grad:
+                accum_grad_obj.register_hook(reduce_grad_hook)
 
         _define_and_attach(param, reduce_rank)
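
The hooks above are registered on each parameter's AccumulateGrad autograd node (obtained via get_grad_accumulate_object), so reduce_grad_hook_sp fires after the gradient has been fully written for the backward pass, whether or not overlap is enabled. A hedged sketch of that mechanism; the expand_as trick is one common way such a helper is implemented and is an assumption here, not necessarily InternLM's exact code:

import torch
import torch.distributed as dist

def get_grad_accumulate_object(param: torch.Tensor):
    # assumes param.requires_grad is True: expand_as produces a non-leaf
    # view whose grad_fn's next_functions[0][0] is the parameter's
    # AccumulateGrad node; hooks on that node run after param.grad has
    # been accumulated for the current backward pass
    return param.expand_as(param).grad_fn.next_functions[0][0]

def attach_sp_allreduce_hook(param: torch.Tensor, tp_group=None):
    accum_grad_obj = get_grad_accumulate_object(param)

    def reduce_grad_hook_sp(*_):
        if param.grad is not None:
            # all-reduce the norm-layer gradient across the tensor-parallel group
            dist.all_reduce(param.grad, group=tp_group)

    accum_grad_obj.register_hook(reduce_grad_hook_sp)
    # keep a reference so the AccumulateGrad node is not garbage-collected
    param._accum_grad_obj = accum_grad_obj

Running this requires an initialized torch.distributed process group; attach_sp_allreduce_hook and tp_group are illustrative names, not identifiers from the patched file.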