fix the bug so that the sequence parallel norm is all-reduced when overlap is False (#534)

pull/538/head
ytxiong 2023-12-12 16:22:39 +08:00 committed by GitHub
parent d904730be7
commit 432bd5ee9f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 9 additions and 9 deletions

View File

@ -219,10 +219,7 @@ class HybridZeroOptimizer(BaseOptimizer):
# flag used to skip unnecessary gradient reduce operation when gradient accumulation is enabled.
self.skip_grad_reduce = False
# reduction hook is only used if overlapping communication
# if it is stage 1 without overlapping, no hook will be attached
if self._overlap_sync_grad:
self._attach_reduction_hook()
self._attach_reduction_hook()
@property
def zero_local_rank(self):
@ -321,12 +318,15 @@ class HybridZeroOptimizer(BaseOptimizer):
# if sequence_parallel is True,
# the grad of norm should be all-reduce across the tp process group
if gpc.config.parallel.sequence_parallel is True:
if hasattr(param, IS_SEQUENCE_PARALLEL) and getattr(param, IS_SEQUENCE_PARALLEL) is True:
accum_grad_obj_sp = get_grad_accumulate_object(param)
accum_grad_obj_sp.register_hook(reduce_grad_hook_sp)
if (
gpc.config.parallel.sequence_parallel is True
and hasattr(param, IS_SEQUENCE_PARALLEL)
and getattr(param, IS_SEQUENCE_PARALLEL) is True
):
accum_grad_obj.register_hook(reduce_grad_hook_sp)
accum_grad_obj.register_hook(reduce_grad_hook)
if self._overlap_sync_grad:
accum_grad_obj.register_hook(reduce_grad_hook)
_define_and_attach(param, reduce_rank)