From b3def4c1628dbba652ffb9b089eeb7be9de584af Mon Sep 17 00:00:00 2001
From: huangting4201 <1538303371@qq.com>
Date: Tue, 31 Oct 2023 20:40:58 +0800
Subject: [PATCH] fix(optimizer/hybrid_zero_optim.py): add reduce_scatter_overlap switch

---
 configs/7B_sft.py                              | 4 ++--
 internlm/model/overlap_handler.py              | 9 +++++----
 internlm/model/utils.py                        | 4 ++--
 internlm/solver/optimizer/hybrid_zero_optim.py | 5 +++--
 4 files changed, 12 insertions(+), 10 deletions(-)

diff --git a/configs/7B_sft.py b/configs/7B_sft.py
index 2d6a3be..b34a838 100644
--- a/configs/7B_sft.py
+++ b/configs/7B_sft.py
@@ -57,7 +57,7 @@ data = dict(
     # defaults to 0, means disable evaluate
     valid_every=50,
     pack_sample_into_one=True,
-    total_steps=50,
+    total_steps=10,
     skip_batches="",
     rampup_batch_size="",
     # Datasets with less than 50 rows will be discarded
@@ -163,7 +163,7 @@ pipeline parallel (dict):
 """
 parallel = dict(
     zero1=dict(size=-1, fsdp=False),
-    tensor=dict(size=8, sp="intern", intern_overlap=True),
+    tensor=dict(size=8, sp="intern", intern_overlap=True, reduce_scatter_overlap=True),
     pipeline=dict(size=1, interleaved_overlap=True),
 )
 
diff --git a/internlm/model/overlap_handler.py b/internlm/model/overlap_handler.py
index 7805e11..418c4aa 100644
--- a/internlm/model/overlap_handler.py
+++ b/internlm/model/overlap_handler.py
@@ -70,10 +70,11 @@ class FSTPOverlapHandler:
 
                         setattr(child, "_fstp_name", name)
 
-                        _full_name = f"{_chunk_name}.{idx}.{_sub_name}.{name}"
-                        setattr(child.weight, "_fstp_reduce_scatter_str", f"{_full_name}.weight")
-                        if child.bias is not None:
-                            setattr(child.bias, "_fstp_reduce_scatter_str", f"{_full_name}.bias")
+                        if gpc.config.parallel["tensor"].get("reduce_scatter_overlap", False):
+                            _full_name = f"{_chunk_name}.{idx}.{_sub_name}.{name}"
+                            setattr(child.weight, "_fstp_reduce_scatter_str", f"{_full_name}.weight")
+                            if child.bias is not None:
+                                setattr(child.bias, "_fstp_reduce_scatter_str", f"{_full_name}.bias")
 
         self.num_blocks = len(self.index_to_fstp_modules)
 
diff --git a/internlm/model/utils.py b/internlm/model/utils.py
index 982c0e0..63dd09d 100644
--- a/internlm/model/utils.py
+++ b/internlm/model/utils.py
@@ -568,7 +568,7 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
                 total_x.reshape(batch_dim, total_x.shape[-1]), grad_output, ctx.needs_input_grad[2]
             )
             if world_size > 1:
-                if overlap_handler is not None:
+                if overlap_handler is not None and gpc.config.parallel["tensor"].get("reduce_scatter_overlap", False):
                     grad_weight_async, handle_grad_weight = reduce_scatter_raw_memory_pool(
                         grad_weight, process_group, async_op=True
                     )
@@ -621,7 +621,7 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
             del total_weight
 
         if ctx.needs_input_grad[1]:
-            if world_size > 1 and overlap_handler is None:
+            if world_size > 1 and not (overlap_handler is not None and gpc.config.parallel["tensor"].get("reduce_scatter_overlap", False)):
                 handle_grad_weight.wait()
                 if grad_bias is not None:
                     handle_grad_bias.wait()
diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py
index 0ab6396..a8b524a 100644
--- a/internlm/solver/optimizer/hybrid_zero_optim.py
+++ b/internlm/solver/optimizer/hybrid_zero_optim.py
@@ -133,6 +133,7 @@ class HybridZeroOptimizer(BaseOptimizer):
             self._fstp_handler = gpc.fstp_handler
         else:
             self._fstp_handler = None
+        self._reduce_scatter_overlap = gpc.config.parallel["tensor"].get("reduce_scatter_overlap", False)
 
         # iterate over the param group in the optimizer
         # partition these param groups for data parallel training
@@ -348,7 +349,7 @@ class HybridZeroOptimizer(BaseOptimizer):
 
                         # we should not only register for parameters which have _fstp_reduce_scatter_str attr.
                         # we must keep up with reduce_grad_hook.
-                        if self._fstp_handler is not None:
+                        if self._fstp_handler is not None and self._reduce_scatter_overlap is True:
                             accum_grad_obj.register_hook(accum_grad_hook)
 
                         if self._overlap_sync_grad:
@@ -357,7 +358,7 @@ class HybridZeroOptimizer(BaseOptimizer):
                     _define_and_attach(param, reduce_rank)
 
     def accumulate_left_grads_after_backward(self):
-        if self._fstp_handler is None:
+        if self._fstp_handler is None or self._reduce_scatter_overlap is False:
             return
 
         for group_id in range(self.num_param_groups):
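
A minimal usage sketch, not part of the patch itself: the new key is read with dict.get, so existing configs that omit it default to False and keep the synchronous reduce-scatter path (the handle_grad_weight.wait() branch in internlm/model/utils.py). To enable the overlap, the parallel section of a config would mirror the configs/7B_sft.py change above:

# hypothetical excerpt from a training config; the flag only takes effect
# together with the FSTP handler, i.e. when sp="intern" and intern_overlap=True
# (see the self._fstp_handler checks in hybrid_zero_optim.py above).
parallel = dict(
    zero1=dict(size=-1, fsdp=False),
    tensor=dict(size=8, sp="intern", intern_overlap=True, reduce_scatter_overlap=True),
    pipeline=dict(size=1, interleaved_overlap=True),
)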