From 9b1265c59107edd44063684c96446af20892fd25 Mon Sep 17 00:00:00 2001
From: yingtongxiong <974106207@qq.com>
Date: Mon, 6 Nov 2023 10:45:08 +0800
Subject: [PATCH] modify the sp allreduce and support tf32 for fstp linear

---
 .gitignore                                     |   2 +
 configs/generate.py                            |   8 ++
 internlm/model/utils.py                        | 102 +++++++++++++++++-
 .../solver/optimizer/hybrid_zero_optim.py      |  43 ++++----
 4 files changed, 129 insertions(+), 26 deletions(-)

diff --git a/.gitignore b/.gitignore
index 9bdc7ec..ef18a4a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -152,6 +152,8 @@ fstp_logs/
 configs/7B_train/*
 configs/13B_train/*
 configs/30B_train/*
+configs/test_loss/*
+loss_tensorboard/*
 atb
 pip

diff --git a/configs/generate.py b/configs/generate.py
index 038998c..5f044e7 100644
--- a/configs/generate.py
+++ b/configs/generate.py
@@ -39,6 +39,14 @@ for idx, root_name in enumerate(root_names):
 
     log_name = root_name + "_" + output_file_name[:-3]
 
+    skip = True
+
+    if sp_mode == "intern" and intern_overlap[i] is True:
+        skip = False
+
+    if skip:
+        continue
+
     print(log_name)
     command = f"srun -p llm_s -N 8 -n 64 --ntasks-per-node=8 --gpus-per-task=1 --time=30 python train.py --config {write_file} --profiling 2>&1 | tee ./fstp_logs/{log_name}.log"
     process = subprocess.Popen(command, shell=True, executable='/bin/bash')
diff --git a/internlm/model/utils.py b/internlm/model/utils.py
index 63dd09d..4f197b1 100644
--- a/internlm/model/utils.py
+++ b/internlm/model/utils.py
@@ -627,6 +627,104 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
                 handle_grad_bias.wait()
         return grad_input, grad_weight, grad_bias, None, None, None, None, None, None
 
+class FSTPFusedDenseFuncTorch(FSTPFusedDenseFunc):
+    "FusedDenseFunc for FSTP, which falls back to a torch-based linear backward (e.g. to support tf32) instead of the flash-attn fused kernels."
+    @staticmethod
+    @custom_bwd
+    def backward(ctx, grad_output, *args):
+        grad_output = grad_output.contiguous()
+        if ctx.return_residual:
+            (grad_input,) = args
+            grad_input = grad_input.contiguous()
+        process_group = ctx.process_group
+        overlap_handler = ctx.overlap_handler
+        module = ctx.module
+
+        if ctx.compute_weight_gradient:
+            x, weight, bias = ctx.saved_tensors
+            total_x = x
+        else:
+            weight, bias = ctx.saved_tensors
+            total_x = None
+        batch_shape = grad_output.shape[:-1]
+        batch_dim = batch_shape.numel()
+        grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
+
+        world_size = gpc.get_world_size(ParallelMode.TENSOR)
+        if world_size > 1:
+            if overlap_handler is not None:
+                total_weight = gpc.fstp_handler.get_all_gather_memory(module=module)
+            else:
+                total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True)
+                handle_weight.wait()
+        else:
+            total_weight = weight
+
+        # compute weight grad
+        if ctx.needs_input_grad[1]:
+            assert ctx.compute_weight_gradient
+            grad_weight, grad_bias = linear_bias_wgrad_torch(
+                total_x.reshape(batch_dim, total_x.shape[-1]), grad_output, ctx.needs_input_grad[2]
+            )
+            if world_size > 1:
+                if overlap_handler is not None and gpc.config.parallel["tensor"].get("reduce_scatter_overlap", False):
+                    grad_weight_async, handle_grad_weight = reduce_scatter_raw_memory_pool(
+                        grad_weight, process_group, async_op=True
+                    )
+                    assert hasattr(weight, "_fstp_reduce_scatter_str")
+                    overlap_handler.reduce_scatter_handlers[weight._fstp_reduce_scatter_str] = (
+                        handle_grad_weight,
+                        grad_weight_async,
+                    )
+                    grad_weight = overlap_handler.get_zero_by_shape(
+                        (
+                            grad_weight.shape[0] // torch.distributed.get_world_size(process_group),
+                            *grad_weight.shape[1:],
+                        ),
+                        dtype=grad_weight.dtype,
+                        device=grad_weight.device,
+                    )
+                    if grad_bias is not None:
+                        grad_bias_async, handle_grad_bias = reduce_scatter_raw_memory_pool(
+                            grad_bias, process_group, async_op=True
+                        )
+                        assert hasattr(bias, "_fstp_reduce_scatter_str")
+                        overlap_handler.reduce_scatter_handlers[bias._fstp_reduce_scatter_str] = (
+                            handle_grad_bias,
+                            grad_bias_async,
+                        )
+                        grad_bias = overlap_handler.get_zero_by_shape(
+                            (
+                                grad_bias.shape[0] // torch.distributed.get_world_size(process_group),
+                                *grad_bias.shape[1:],
+                            ),
+                            dtype=grad_bias.dtype,
+                            device=grad_bias.device,
+                        )
+                else:
+                    grad_weight, handle_grad_weight = reduce_scatter_raw(grad_weight, process_group, async_op=True)
+                    if grad_bias is not None:
+                        grad_bias, handle_grad_bias = reduce_scatter_raw(grad_bias, process_group, async_op=True)
+        else:
+            grad_weight = None
+            grad_bias = grad_output if ctx.needs_input_grad[2] else None
+
+        if ctx.needs_input_grad[0]:
+            if not ctx.return_residual:
+                grad_input = F.linear(grad_output, total_weight.t())
+            else:
+                grad_input = torch.addmm(grad_input.reshape(batch_dim, grad_input.shape[-1]), grad_output, total_weight)
+            grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
+        else:
+            grad_input = None
+        del total_weight
+
+        if ctx.needs_input_grad[1]:
+            if world_size > 1 and not (overlap_handler is not None and gpc.config.parallel["tensor"].get("reduce_scatter_overlap", False)):
+                handle_grad_weight.wait()
+                if grad_bias is not None:
+                    handle_grad_bias.wait()
+        return grad_input, grad_weight, grad_bias, None, None, None, None, None, None
 
 def fused_dense_func_torch(
     x: Tensor,
@@ -683,9 +781,7 @@ def fstp_fused_dense_func(
     if x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and dtype_eligible:
         return FSTPFusedDenseFunc.apply(x, weight, bias, return_residual, process_group, module, handler)
     else:
-        assert process_group is None
-        out = F.linear(x, weight, bias)
-        return out if not return_residual else (out, x)
+        return FSTPFusedDenseFuncTorch.apply(x, weight, bias, return_residual, process_group, module, handler)
 
 
 def try_import_RMSNorm():
diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py
index 1472aa8..b2b16dc 100644
--- a/internlm/solver/optimizer/hybrid_zero_optim.py
+++ b/internlm/solver/optimizer/hybrid_zero_optim.py
@@ -308,15 +308,6 @@ class HybridZeroOptimizer(BaseOptimizer):
                 reduce_rank=reduce_rank,
             )
 
-            def reduction_sp_func():
-                handle = reduce_tensor(
-                    param.grad,
-                    dtype=None,
-                    dst_rank=reduce_rank,
-                    parallel_mode=ParallelMode.TENSOR,
-                )
-                handle.wait()
-
             # define hook
             # NOT IMPORTANT BUT GOOD TO KNOW:
             # args here is not grad, but allow_unreacable and accumulate_grad
@@ -328,25 +319,11 @@ class HybridZeroOptimizer(BaseOptimizer):
             def accum_grad_hook(*args):  # pylint: disable=W0613
                 reduce_scatter_checker()
 
-            # define hook for sequence_parallel
-            def reduce_grad_hook_sp(*args):  # pylint: disable=W0613
-                if self.skip_grad_reduce is False:
-                    reduction_sp_func()
-
             # get the AccumulateGrad object of the param itself
             # If these objects are not kept, reduction hooks may not be attached successfully.
             accum_grad_obj = get_grad_accumulate_object(param)
             self._grad_store.add_accumulate_grad_object(accum_grad_obj)
 
-            # if sequence_parallel is True,
-            # the grad of norm should be all-reduce across the tp process group
-            if (
-                gpc.config.parallel.sequence_parallel is True
-                and hasattr(param, IS_SEQUENCE_PARALLEL)
-                and getattr(param, IS_SEQUENCE_PARALLEL) is True
-            ):
-                accum_grad_obj.register_hook(reduce_grad_hook_sp)
-
             # we should not only register for parameters which have _fstp_reduce_scatter_str attr.
             # we must keep up with reduce_grad_hook.
             if self._fstp_handler is not None and self._reduce_scatter_overlap is True:
@@ -644,6 +621,26 @@ class HybridZeroOptimizer(BaseOptimizer):
         """
         assert closure is None, "closure is not supported by step()"
 
+        # do all-reduce for layernorm when sequence_parallel is True
+        if gpc.config.parallel.sequence_parallel is True:
+            for group_id in range(len(self._fp16_param_groups)):
+                norm_bucket = TensorBucket(size=0)
+                for param in self._fp16_param_groups[group_id]:
+                    if hasattr(param, IS_SEQUENCE_PARALLEL) and getattr(param, IS_SEQUENCE_PARALLEL) is True:
+                        norm_bucket.add_to_bucket(param.grad, allow_oversize=True)
+                # import pdb; pdb.set_trace()
+                if not norm_bucket.is_empty():
+                    norm_bucket.flatten()
+                    norm_bucket.commu_handle = reduce_tensor(
+                        tensor=norm_bucket.get_flat_tensor(),
+                        dtype=None,
+                        dst_rank=None,
+                        parallel_mode=ParallelMode.TENSOR,
+                    )
+                    norm_bucket.commu_handle.wait()
+                    norm_bucket.unflatten_and_copy()
+                # norm_bucket.empty()
+
         # if not overlapping communication (no reduction hook is attached)
         # we need to manually reduce these gradients
         if not self._overlap_sync_grad:
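
Note on the hybrid_zero_optim.py change: the per-parameter all-reduce hooks for sequence-parallel parameters (reduce_grad_hook_sp) are removed, and the layernorm gradients are instead flattened into one TensorBucket per parameter group and all-reduced across the tensor-parallel group once at the start of step(). A single flattened collective replaces one collective per parameter, which cuts launch overhead and keeps the reduction out of the backward pass. The sketch below is a minimal standalone approximation of that bucketing idea using only plain torch.distributed; it is not InternLM's TensorBucket/reduce_tensor implementation, the `is_sequence_parallel` attribute is a hypothetical stand-in for the IS_SEQUENCE_PARALLEL marker, and an initialized process group is assumed.

import torch
import torch.distributed as dist


def allreduce_sequence_parallel_grads(params, group=None):
    # Gather the grads of parameters marked as sequence-parallel (e.g. layernorm
    # weights/biases); with sequence parallelism each rank holds only a partial grad.
    # `is_sequence_parallel` is a hypothetical stand-in attribute for this sketch.
    grads = [p.grad for p in params if getattr(p, "is_sequence_parallel", False) and p.grad is not None]
    if not grads:
        return
    # Flatten into one buffer so a single all-reduce replaces one collective per parameter.
    flat = torch.cat([g.reshape(-1) for g in grads])
    dist.all_reduce(flat, group=group)  # sum the partial grads over the tensor-parallel group
    # Copy the reduced values back into the original grad tensors.
    offset = 0
    for g in grads:
        n = g.numel()
        g.copy_(flat[offset:offset + n].view_as(g))
        offset += n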