diff --git a/configs/7B_sft.py b/configs/7B_sft.py
index 9928508..e85d2df 100644
--- a/configs/7B_sft.py
+++ b/configs/7B_sft.py
@@ -163,7 +163,7 @@ pipeline parallel (dict):
 """
 parallel = dict(
     zero1=dict(size=-1, fsdp=False),
-    tensor=dict(size=4, sp="intern", intern_overlap=True, reduce_scatter_overlap=True),
+    tensor=dict(size=4, sp="intern", intern_overlap=True),
     pipeline=dict(size=1, interleaved_overlap=True),
 )
diff --git a/internlm/model/overlap_handler.py b/internlm/model/overlap_handler.py
index ed0a8d2..e3198bb 100644
--- a/internlm/model/overlap_handler.py
+++ b/internlm/model/overlap_handler.py
@@ -73,11 +73,10 @@ class FSTPOverlapHandler:
 
                         setattr(child, "_fstp_name", name)
 
-                        if gpc.config.parallel["tensor"].get("reduce_scatter_overlap", False):
-                            _full_name = f"{_chunk_name}.{idx}.{_sub_name}.{name}"
-                            setattr(child.weight, "_fstp_reduce_scatter_str", f"{_full_name}.weight")
-                            if child.bias is not None:
-                                setattr(child.bias, "_fstp_reduce_scatter_str", f"{_full_name}.bias")
+                        _full_name = f"{_chunk_name}.{idx}.{_sub_name}.{name}"
+                        setattr(child.weight, "_fstp_reduce_scatter_str", f"{_full_name}.weight")
+                        if child.bias is not None:
+                            setattr(child.bias, "_fstp_reduce_scatter_str", f"{_full_name}.bias")
 
         self.num_blocks = len(self.index_to_fstp_modules)
 
diff --git a/internlm/model/utils.py b/internlm/model/utils.py
index 4f197b1..556752a 100644
--- a/internlm/model/utils.py
+++ b/internlm/model/utils.py
@@ -568,7 +568,7 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
                     total_x.reshape(batch_dim, total_x.shape[-1]), grad_output, ctx.needs_input_grad[2]
                 )
                 if world_size > 1:
-                    if overlap_handler is not None and gpc.config.parallel["tensor"].get("reduce_scatter_overlap", False):
+                    if overlap_handler is not None:
                         grad_weight_async, handle_grad_weight = reduce_scatter_raw_memory_pool(
                             grad_weight, process_group, async_op=True
                         )
@@ -621,14 +621,16 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
             del total_weight
 
         if ctx.needs_input_grad[1]:
-            if world_size > 1 and not (overlap_handler is not None and gpc.config.parallel["tensor"].get("reduce_scatter_overlap", False)):
+            if world_size > 1 and overlap_handler is None:
                 handle_grad_weight.wait()
                 if grad_bias is not None:
                     handle_grad_bias.wait()
         return grad_input, grad_weight, grad_bias, None, None, None, None, None, None
 
+
 class FSTPFusedDenseFuncTorch(FSTPFusedDenseFunc):
     "FusedDenseFunc for FSTP, which is optimized based on flash implementation."
+
     @staticmethod
     @custom_bwd
     def backward(ctx, grad_output, *args):
@@ -667,7 +669,7 @@ class FSTPFusedDenseFuncTorch(FSTPFusedDenseFunc):
                     total_x.reshape(batch_dim, total_x.shape[-1]), grad_output, ctx.needs_input_grad[2]
                 )
                 if world_size > 1:
-                    if overlap_handler is not None and gpc.config.parallel["tensor"].get("reduce_scatter_overlap", False):
+                    if overlap_handler is not None:
                         grad_weight_async, handle_grad_weight = reduce_scatter_raw_memory_pool(
                             grad_weight, process_group, async_op=True
                         )
@@ -720,12 +722,13 @@ class FSTPFusedDenseFuncTorch(FSTPFusedDenseFunc):
             del total_weight
 
         if ctx.needs_input_grad[1]:
-            if world_size > 1 and not (overlap_handler is not None and gpc.config.parallel["tensor"].get("reduce_scatter_overlap", False)):
+            if world_size > 1 and overlap_handler is None:
                 handle_grad_weight.wait()
                 if grad_bias is not None:
                     handle_grad_bias.wait()
         return grad_input, grad_weight, grad_bias, None, None, None, None, None, None
 
+
 def fused_dense_func_torch(
     x: Tensor,
     weight: Tensor,
diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py
index 1472aa8..e5927e6 100644
--- a/internlm/solver/optimizer/hybrid_zero_optim.py
+++ b/internlm/solver/optimizer/hybrid_zero_optim.py
@@ -133,7 +133,6 @@ class HybridZeroOptimizer(BaseOptimizer):
             self._fstp_handler = gpc.fstp_handler
         else:
             self._fstp_handler = None
-        self._reduce_scatter_overlap = gpc.config.parallel["tensor"].get("reduce_scatter_overlap", False)
 
         # iterate over the param group in the optimizer
        # partition these param groups for data parallel training
@@ -349,7 +348,7 @@ class HybridZeroOptimizer(BaseOptimizer):
 
                 # we should not only register for parameters which have _fstp_reduce_scatter_str attr.
                 # we must keep up with reduce_grad_hook.
-                if self._fstp_handler is not None and self._reduce_scatter_overlap is True:
+                if self._fstp_handler is not None:
                     accum_grad_obj.register_hook(accum_grad_hook)
 
                 if self._overlap_sync_grad:
@@ -358,7 +357,7 @@ class HybridZeroOptimizer(BaseOptimizer):
             _define_and_attach(param, reduce_rank)
 
     def accumulate_left_grads_after_backward(self):
-        if self._fstp_handler is None or self._reduce_scatter_overlap is False:
+        if self._fstp_handler is None:
            return
 
        for group_id in range(self.num_param_groups):
@@ -644,6 +643,24 @@ class HybridZeroOptimizer(BaseOptimizer):
         """
         assert closure is None, "closure is not supported by step()"
 
+        # do all-reduce for layernorm when sequence_parallel is True
+        if gpc.config.parallel.sequence_parallel is True:
+            for group_id in range(len(self._fp16_param_groups)):
+                norm_bucket = TensorBucket(size=0)
+                for param in self._fp16_param_groups[group_id]:
+                    if hasattr(param, IS_SEQUENCE_PARALLEL) and getattr(param, IS_SEQUENCE_PARALLEL) is True:
+                        norm_bucket.add_to_bucket(param.grad, allow_oversize=True)
+                if not norm_bucket.is_empty():
+                    norm_bucket.flatten()
+                    norm_bucket.commu_handle = reduce_tensor(
+                        tensor=norm_bucket.get_flat_tensor(),
+                        dtype=None,
+                        dst_rank=None,
+                        parallel_mode=ParallelMode.TENSOR,
+                    )
+                    norm_bucket.commu_handle.wait()
+                    norm_bucket.unflatten_and_copy()
+
         # if not overlapping communication (no reduction hook is attached)
         # we need to manually reduce these gradients
         if not self._overlap_sync_grad:
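
For reference, the decision the simplified condition boils down to can be sketched as follows. This is a minimal illustration only, not the repository's implementation: reduce_scatter_weight_grad is a hypothetical helper, and plain torch.distributed calls stand in for reduce_scatter_raw_memory_pool and the handler's memory-pool bookkeeping. With reduce_scatter_overlap removed, the presence of the FSTP overlap handler alone decides whether the weight-gradient reduce-scatter is launched asynchronously (and awaited later by the accum_grad_hook registered in HybridZeroOptimizer) or issued and waited on inside backward, which is why the explicit handle_grad_weight.wait() is now guarded only by overlap_handler is None.

# Minimal sketch, not InternLM code: a hypothetical helper illustrating the
# branch that this patch simplifies.
import torch
import torch.distributed as dist


def reduce_scatter_weight_grad(grad_weight, process_group, overlap_handler=None):
    """Reduce-scatter a weight gradient across `process_group`.

    With an overlap handler, the collective is launched asynchronously and the
    handle is returned so a later gradient hook can wait on it; without one,
    the call blocks before returning (the `handle_grad_weight.wait()` path).
    """
    world_size = dist.get_world_size(process_group)
    if world_size == 1:
        return grad_weight, None

    # Each rank keeps only its shard of the summed gradient.
    shard = torch.empty(
        grad_weight.shape[0] // world_size,
        *grad_weight.shape[1:],
        dtype=grad_weight.dtype,
        device=grad_weight.device,
    )

    if overlap_handler is not None:
        # Overlap path: launch now, wait later in the accumulation hook.
        handle = dist.reduce_scatter_tensor(shard, grad_weight, group=process_group, async_op=True)
        return shard, handle

    # No handler: synchronous reduce-scatter, equivalent to waiting immediately.
    dist.reduce_scatter_tensor(shard, grad_weight, group=process_group)
    return shard, None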