From c517ec5b8cdf9c675f97dcc615bfd39c2ffda010 Mon Sep 17 00:00:00 2001
From: huangting4201 <1538303371@qq.com>
Date: Mon, 6 Nov 2023 11:57:14 +0800
Subject: [PATCH] feat(model/overlap_handler.py): delete reduce_scatter_overlap switch

---
 configs/7B_sft.py                              |  2 +-
 internlm/model/overlap_handler.py              |  9 ++++-----
 internlm/model/utils.py                        | 11 +++++++----
 internlm/solver/optimizer/hybrid_zero_optim.py | 17 +++++++----------
 4 files changed, 19 insertions(+), 20 deletions(-)

diff --git a/configs/7B_sft.py b/configs/7B_sft.py
index 9928508..e85d2df 100644
--- a/configs/7B_sft.py
+++ b/configs/7B_sft.py
@@ -163,7 +163,7 @@ pipeline parallel (dict):
 """
 parallel = dict(
     zero1=dict(size=-1, fsdp=False),
-    tensor=dict(size=4, sp="intern", intern_overlap=True, reduce_scatter_overlap=True),
+    tensor=dict(size=4, sp="intern", intern_overlap=True),
     pipeline=dict(size=1, interleaved_overlap=True),
 )
 
diff --git a/internlm/model/overlap_handler.py b/internlm/model/overlap_handler.py
index ed0a8d2..e3198bb 100644
--- a/internlm/model/overlap_handler.py
+++ b/internlm/model/overlap_handler.py
@@ -73,11 +73,10 @@ class FSTPOverlapHandler:
 
                         setattr(child, "_fstp_name", name)
 
-                        if gpc.config.parallel["tensor"].get("reduce_scatter_overlap", False):
-                            _full_name = f"{_chunk_name}.{idx}.{_sub_name}.{name}"
-                            setattr(child.weight, "_fstp_reduce_scatter_str", f"{_full_name}.weight")
-                            if child.bias is not None:
-                                setattr(child.bias, "_fstp_reduce_scatter_str", f"{_full_name}.bias")
+                        _full_name = f"{_chunk_name}.{idx}.{_sub_name}.{name}"
+                        setattr(child.weight, "_fstp_reduce_scatter_str", f"{_full_name}.weight")
+                        if child.bias is not None:
+                            setattr(child.bias, "_fstp_reduce_scatter_str", f"{_full_name}.bias")
 
         self.num_blocks = len(self.index_to_fstp_modules)
 
diff --git a/internlm/model/utils.py b/internlm/model/utils.py
index 4f197b1..556752a 100644
--- a/internlm/model/utils.py
+++ b/internlm/model/utils.py
@@ -568,7 +568,7 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
                 total_x.reshape(batch_dim, total_x.shape[-1]), grad_output, ctx.needs_input_grad[2]
             )
             if world_size > 1:
-                if overlap_handler is not None and gpc.config.parallel["tensor"].get("reduce_scatter_overlap", False):
+                if overlap_handler is not None:
                     grad_weight_async, handle_grad_weight = reduce_scatter_raw_memory_pool(
                         grad_weight, process_group, async_op=True
                     )
@@ -621,14 +621,16 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
             del total_weight
 
         if ctx.needs_input_grad[1]:
-            if world_size > 1 and not (overlap_handler is not None and gpc.config.parallel["tensor"].get("reduce_scatter_overlap", False)):
+            if world_size > 1 and overlap_handler is None:
                 handle_grad_weight.wait()
                 if grad_bias is not None:
                     handle_grad_bias.wait()
         return grad_input, grad_weight, grad_bias, None, None, None, None, None, None
 
+
 class FSTPFusedDenseFuncTorch(FSTPFusedDenseFunc):
     "FusedDenseFunc for FSTP, which is optimized based on flash implementation."
+
     @staticmethod
     @custom_bwd
     def backward(ctx, grad_output, *args):
@@ -667,7 +669,7 @@ class FSTPFusedDenseFuncTorch(FSTPFusedDenseFunc):
                 total_x.reshape(batch_dim, total_x.shape[-1]), grad_output, ctx.needs_input_grad[2]
             )
             if world_size > 1:
-                if overlap_handler is not None and gpc.config.parallel["tensor"].get("reduce_scatter_overlap", False):
+                if overlap_handler is not None:
                     grad_weight_async, handle_grad_weight = reduce_scatter_raw_memory_pool(
                         grad_weight, process_group, async_op=True
                     )
@@ -720,12 +722,13 @@ class FSTPFusedDenseFuncTorch(FSTPFusedDenseFunc):
             del total_weight
 
         if ctx.needs_input_grad[1]:
-            if world_size > 1 and not (overlap_handler is not None and gpc.config.parallel["tensor"].get("reduce_scatter_overlap", False)):
+            if world_size > 1 and overlap_handler is None:
                 handle_grad_weight.wait()
                 if grad_bias is not None:
                     handle_grad_bias.wait()
         return grad_input, grad_weight, grad_bias, None, None, None, None, None, None
 
+
 def fused_dense_func_torch(
     x: Tensor,
     weight: Tensor,
diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py
index b2b16dc..9a277ae 100644
--- a/internlm/solver/optimizer/hybrid_zero_optim.py
+++ b/internlm/solver/optimizer/hybrid_zero_optim.py
@@ -133,7 +133,6 @@ class HybridZeroOptimizer(BaseOptimizer):
             self._fstp_handler = gpc.fstp_handler
         else:
             self._fstp_handler = None
-        self._reduce_scatter_overlap = gpc.config.parallel["tensor"].get("reduce_scatter_overlap", False)
 
         # iterate over the param group in the optimizer
         # partition these param groups for data parallel training
@@ -326,7 +325,7 @@ class HybridZeroOptimizer(BaseOptimizer):
 
             # we should not only register for parameters which have _fstp_reduce_scatter_str attr.
             # we must keep up with reduce_grad_hook.
-            if self._fstp_handler is not None and self._reduce_scatter_overlap is True:
+            if self._fstp_handler is not None:
                 accum_grad_obj.register_hook(accum_grad_hook)
 
             if self._overlap_sync_grad:
@@ -335,7 +334,7 @@ class HybridZeroOptimizer(BaseOptimizer):
                 _define_and_attach(param, reduce_rank)
 
     def accumulate_left_grads_after_backward(self):
-        if self._fstp_handler is None or self._reduce_scatter_overlap is False:
+        if self._fstp_handler is None:
             return
 
         for group_id in range(self.num_param_groups):
@@ -628,18 +627,16 @@ class HybridZeroOptimizer(BaseOptimizer):
             for param in self._fp16_param_groups[group_id]:
                 if hasattr(param, IS_SEQUENCE_PARALLEL) and getattr(param, IS_SEQUENCE_PARALLEL) is True:
                     norm_bucket.add_to_bucket(param.grad, allow_oversize=True)
-            # import pdb; pdb.set_trace()
             if not norm_bucket.is_empty():
                 norm_bucket.flatten()
                 norm_bucket.commu_handle = reduce_tensor(
-                                tensor=norm_bucket.get_flat_tensor(),
-                                dtype=None,
-                                dst_rank=None,
-                                parallel_mode=ParallelMode.TENSOR,
-                            )
+                    tensor=norm_bucket.get_flat_tensor(),
+                    dtype=None,
+                    dst_rank=None,
+                    parallel_mode=ParallelMode.TENSOR,
+                )
                 norm_bucket.commu_handle.wait()
                 norm_bucket.unflatten_and_copy()
-                # norm_bucket.empty()
 
         # if not overlapping communication (no reduction hook is attached)
         # we need to manually reduce these gradients
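
For reference, a minimal sketch (not part of the patch) of the tensor-parallel configuration after this change, mirroring configs/7B_sft.py above: with sp="intern" and intern_overlap=True, the asynchronous reduce-scatter path is gated solely on the presence of the FSTP overlap handler, so there is no longer a separate reduce_scatter_overlap switch to set.

# Sketch of the simplified parallel config after this patch (values taken from
# configs/7B_sft.py above); a leftover reduce_scatter_overlap key would simply
# no longer be read by the overlap handler or the optimizer.
parallel = dict(
    zero1=dict(size=-1, fsdp=False),
    tensor=dict(size=4, sp="intern", intern_overlap=True),
    pipeline=dict(size=1, interleaved_overlap=True),
)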