diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py index 6920b8f..3487324 100644 --- a/internlm/solver/optimizer/hybrid_zero_optim.py +++ b/internlm/solver/optimizer/hybrid_zero_optim.py @@ -318,14 +318,13 @@ class HybridZeroOptimizer(BaseOptimizer): if self.skip_grad_reduce is False: reduction_sp_func() - - accum_grad_obj.register_hook(reduce_grad_hook) - # if sequence_parallel is True, the grad of norm should be all-reduce across the tp process group if gpc.config.parallel.sequence_parallel is True: if hasattr(param, IS_SEQUENCE_PARALLEL) and getattr(param, IS_SEQUENCE_PARALLEL) is True: accum_grad_obj_sp = get_grad_accumulate_object(param) accum_grad_obj_sp.register_hook(reduce_grad_hook_sp) + + accum_grad_obj.register_hook(reduce_grad_hook) _define_and_attach(param, reduce_rank)