From 432bd5ee9ffb8ff2dcb744cb50ef8530d6509f80 Mon Sep 17 00:00:00 2001
From: ytxiong <45058324+yingtongxiong@users.noreply.github.com>
Date: Tue, 12 Dec 2023 16:22:39 +0800
Subject: [PATCH] fix the bug so that the sequence parallel norm is
 all-reduced when overlap is False (#534)

---
 internlm/solver/optimizer/hybrid_zero_optim.py | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py
index 01b40ab..eb7aae3 100644
--- a/internlm/solver/optimizer/hybrid_zero_optim.py
+++ b/internlm/solver/optimizer/hybrid_zero_optim.py
@@ -219,10 +219,7 @@ class HybridZeroOptimizer(BaseOptimizer):
         # flag used to skip unnecessary gradient reduce operation when gradient accumulation is enabled.
         self.skip_grad_reduce = False
 
-        # reduction hook is only used if overlapping communication
-        # if it is stage 1 without overlapping, no hook will be attached
-        if self._overlap_sync_grad:
-            self._attach_reduction_hook()
+        self._attach_reduction_hook()
 
     @property
     def zero_local_rank(self):
@@ -321,12 +318,15 @@ class HybridZeroOptimizer(BaseOptimizer):
 
                         # if sequence_parallel is True,
                         # the grad of norm should be all-reduce across the tp process group
-                        if gpc.config.parallel.sequence_parallel is True:
-                            if hasattr(param, IS_SEQUENCE_PARALLEL) and getattr(param, IS_SEQUENCE_PARALLEL) is True:
-                                accum_grad_obj_sp = get_grad_accumulate_object(param)
-                                accum_grad_obj_sp.register_hook(reduce_grad_hook_sp)
+                        if (
+                            gpc.config.parallel.sequence_parallel is True
+                            and hasattr(param, IS_SEQUENCE_PARALLEL)
+                            and getattr(param, IS_SEQUENCE_PARALLEL) is True
+                        ):
+                            accum_grad_obj.register_hook(reduce_grad_hook_sp)
 
-                        accum_grad_obj.register_hook(reduce_grad_hook)
+                        if self._overlap_sync_grad:
+                            accum_grad_obj.register_hook(reduce_grad_hook)
 
                         _define_and_attach(param, reduce_rank)
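
For context, a minimal single-process sketch of the mechanism this patch
relies on (illustrative, not InternLM's code): reduce_grad_hook_sp is
registered on the parameter's AccumulateGrad autograd node, which fires as
soon as .grad is accumulated, independently of the overlap setting, so the
sequence-parallel all-reduce now happens even when overlap is False.
attach_hooks, tp_group, and the demo parameter are hypothetical names;
get_grad_accumulate_object follows the standard expand_as idiom for reaching
the AccumulateGrad node.

# Sketch (assumed names; not the project's implementation).
import torch
import torch.distributed as dist


def get_grad_accumulate_object(param):
    # Standard PyTorch idiom: the AccumulateGrad node of a leaf tensor is
    # reachable through the grad_fn of a throwaway view of that tensor.
    return param.expand_as(param).grad_fn.next_functions[0][0]


def attach_hooks(param, is_sequence_parallel, overlap_sync_grad, tp_group=None):
    accum_grad_obj = get_grad_accumulate_object(param)

    def reduce_grad_hook_sp(*args):
        # All-reduce the norm grad across the (hypothetical) tensor-parallel
        # group; guarded so this sketch also runs without torch.distributed.
        if dist.is_available() and dist.is_initialized():
            dist.all_reduce(param.grad, group=tp_group)

    def reduce_grad_hook(*args):
        pass  # stand-in for the bucketed ZeRO gradient reduction

    # As in the patch: the sequence-parallel hook is attached whenever the
    # param needs it, while the bucket hook still depends on the overlap flag.
    if is_sequence_parallel:
        accum_grad_obj.register_hook(reduce_grad_hook_sp)
    if overlap_sync_grad:
        accum_grad_obj.register_hook(reduce_grad_hook)

    # AccumulateGrad is only weakly referenced by the parameter, so the caller
    # must keep this object alive or the hooks can silently disappear.
    return accum_grad_obj


# Demo: the sequence-parallel hook fires on backward with overlap disabled.
w = torch.nn.Parameter(torch.ones(4))
keep_alive = attach_hooks(w, is_sequence_parallel=True, overlap_sync_grad=False)
(w * 2).sum().backward()
print(w.grad)  # tensor([2., 2., 2., 2.])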