diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py
index b2b16dc..1472aa8 100644
--- a/internlm/solver/optimizer/hybrid_zero_optim.py
+++ b/internlm/solver/optimizer/hybrid_zero_optim.py
@@ -308,6 +308,15 @@ class HybridZeroOptimizer(BaseOptimizer):
                             reduce_rank=reduce_rank,
                         )
 
+                        def reduction_sp_func():
+                            handle = reduce_tensor(
+                                param.grad,
+                                dtype=None,
+                                dst_rank=reduce_rank,
+                                parallel_mode=ParallelMode.TENSOR,
+                            )
+                            handle.wait()
+
                         # define hook
                         # NOT IMPORTANT BUT GOOD TO KNOW:
                         # args here is not grad, but allow_unreacable and accumulate_grad
@@ -319,11 +328,25 @@
                         def accum_grad_hook(*args):  # pylint: disable=W0613
                             reduce_scatter_checker()
 
+                        # define hook for sequence_parallel
+                        def reduce_grad_hook_sp(*args):  # pylint: disable=W0613
+                            if self.skip_grad_reduce is False:
+                                reduction_sp_func()
+
                         # get the AccumulateGrad object of the param itself
                         # If these objects are not kept, reduction hooks may not be attached successfully.
                         accum_grad_obj = get_grad_accumulate_object(param)
                         self._grad_store.add_accumulate_grad_object(accum_grad_obj)
 
+                        # if sequence_parallel is True,
+                        # the grad of norm should be all-reduce across the tp process group
+                        if (
+                            gpc.config.parallel.sequence_parallel is True
+                            and hasattr(param, IS_SEQUENCE_PARALLEL)
+                            and getattr(param, IS_SEQUENCE_PARALLEL) is True
+                        ):
+                            accum_grad_obj.register_hook(reduce_grad_hook_sp)
+
                         # we should not only register for parameters which have _fstp_reduce_scatter_str attr.
                         # we must keep up with reduce_grad_hook.
                         if self._fstp_handler is not None and self._reduce_scatter_overlap is True:
@@ -621,26 +644,6 @@ class HybridZeroOptimizer(BaseOptimizer):
         """
         assert closure is None, "closure is not supported by step()"
 
-        # do all-reduce for layernorm when sequence_parallel is True
-        if gpc.config.parallel.sequence_parallel is True:
-            for group_id in range(len(self._fp16_param_groups)):
-                norm_bucket = TensorBucket(size=0)
-                for param in self._fp16_param_groups[group_id]:
-                    if hasattr(param, IS_SEQUENCE_PARALLEL) and getattr(param, IS_SEQUENCE_PARALLEL) is True:
-                        norm_bucket.add_to_bucket(param.grad, allow_oversize=True)
-                # import pdb; pdb.set_trace()
-                if not norm_bucket.is_empty():
-                    norm_bucket.flatten()
-                    norm_bucket.commu_handle = reduce_tensor(
-                        tensor=norm_bucket.get_flat_tensor(),
-                        dtype=None,
-                        dst_rank=None,
-                        parallel_mode=ParallelMode.TENSOR,
-                    )
-                    norm_bucket.commu_handle.wait()
-                    norm_bucket.unflatten_and_copy()
-                    # norm_bucket.empty()
-
         # if not overlapping communication (no reduction hook is attached)
         # we need to manually reduce these gradients
         if not self._overlap_sync_grad:
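
The pattern this patch moves to is: register a hook on each sequence-parallel (norm) parameter's AccumulateGrad node and all-reduce that parameter's gradient across the tensor-parallel group as soon as it has been accumulated, instead of bucketing those gradients and all-reducing them in `step()`. Below is a minimal, self-contained sketch of that pattern using plain `torch.distributed`; the InternLM helpers (`reduce_tensor`, `get_grad_accumulate_object`, `gpc`, `IS_SEQUENCE_PARALLEL`) are not used here, and `attach_sp_allreduce_hook` is a hypothetical name for illustration only.

```python
import torch
import torch.distributed as dist


def attach_sp_allreduce_hook(param: torch.nn.Parameter, process_group=None):
    """Sketch: all-reduce ``param.grad`` over ``process_group`` right after the
    gradient has been accumulated during backward (assumed to mirror what the
    patch does via InternLM's own helpers)."""
    # expand_as() yields a view whose grad_fn chains back to the parameter's
    # AccumulateGrad node; this is the usual way to reach that node.
    accum_grad_obj = param.expand_as(param).grad_fn.next_functions[0][0]

    def allreduce_hook(*args):  # hook args are not the gradient; read param.grad
        if param.grad is not None:
            handle = dist.all_reduce(param.grad, group=process_group, async_op=True)
            handle.wait()

    accum_grad_obj.register_hook(allreduce_hook)
    # The AccumulateGrad object must be kept alive by the caller, otherwise the
    # hook may never fire (same reason the patch stores it in _grad_store).
    return accum_grad_obj
```

Hooking the AccumulateGrad node rather than the tensor means the callback runs after the gradient has been written into `param.grad`, so the all-reduce sees the fully accumulated value; this is also why the bucketed all-reduce previously done in `step()` becomes redundant and is removed by the patch.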