diff --git a/configs/7B_sft.py b/configs/7B_sft.py index 6758167..3e1d078 100644 --- a/configs/7B_sft.py +++ b/configs/7B_sft.py @@ -5,7 +5,7 @@ SEQ_LEN = 2048 HIDDEN_SIZE = 4096 NUM_ATTENTION_HEAD = 32 MLP_RATIO = 8 / 3 -NUM_LAYER = 4 +NUM_LAYER = 32 VOCAB_SIZE = 103168 MODEL_ONLY_FOLDER = "local:llm_ckpts/xxxx" @@ -155,7 +155,7 @@ tensor parallel: tensor parallel size, usually the number of GPUs per node. """ parallel = dict( zero1=-1, - tensor=dict(size=2, mode='fstp'), # the mode should be 'origin_tp' or 'fstp' + tensor=dict(size=2, mode='origin_tp'), # the mode should be 'origin_tp' or 'fstp' pipeline=dict(size=1, interleaved_overlap=True), sequence_parallel=True, ) diff --git a/internlm/model/linear.py b/internlm/model/linear.py index 60a3d27..fbe6f14 100644 --- a/internlm/model/linear.py +++ b/internlm/model/linear.py @@ -54,7 +54,7 @@ class ScaleColumnParallelLinear(nn.Linear): self.process_group = process_group self.weight_scale = weight_scale - def forward(self, input): # pylint: disable=W0622 + def forward(self, input, gather_dim=0): # pylint: disable=W0622 # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism: # we do an all_gather of x before doing the matmul. # If not, then the input is already gathered. @@ -68,6 +68,7 @@ class ScaleColumnParallelLinear(nn.Linear): self.bias, process_group=self.process_group, sequence_parallel=gpc.config.parallel.sequence_parallel, + gather_dim=gather_dim, ) @@ -121,13 +122,13 @@ class RewardModelLinear(ScaleColumnParallelLinear): class ColumnParallelLinearTorch(ColumnParallelLinear): - def forward(self, x): + def forward(self, x, gather_dim=0): # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism: # we do an all_gather of x before doing the matmul. # If not, then the input is already gathered. return fused_dense_func_torch( - x, self.weight, self.bias, process_group=self.process_group, sequence_parallel=self.sequence_parallel + x, self.weight, self.bias, process_group=self.process_group, sequence_parallel=self.sequence_parallel, gather_dim=gather_dim, ) diff --git a/internlm/model/modeling_internlm.py b/internlm/model/modeling_internlm.py index 47d706f..56a8efa 100644 --- a/internlm/model/modeling_internlm.py +++ b/internlm/model/modeling_internlm.py @@ -405,12 +405,10 @@ class PackedFlashInternLm1D(nn.Module): if hasattr(self, "norm"): hidden_states = self.norm(hidden_states.float()) if hasattr(self, "head"): - # if hidden_states.ndim == 3: - # import pdb; pdb.set_trace() - # hidden_states = self.head(hidden_states, dim=1) - # else: - # hidden_states = self.head(hidden_states) - hidden_states = self.head(hidden_states) + if hidden_states.ndim == 3: + hidden_states = self.head(hidden_states, gather_dim=1) + else: + hidden_states = self.head(hidden_states) if not self.parallel_output: hidden_states = gather_forward_split_backward(hidden_states, ParallelMode.TENSOR, dim=-1) diff --git a/internlm/model/utils.py b/internlm/model/utils.py index 570a86f..33c8c46 100644 --- a/internlm/model/utils.py +++ b/internlm/model/utils.py @@ -5,16 +5,18 @@ from typing import Optional import torch import torch.nn.functional as F -from flash_attn.ops.fused_dense import FusedDenseFunc +# from flash_attn.ops.fused_dense import FusedDenseFunc from flash_attn.utils.distributed import ( - all_gather_raw, + # all_gather_raw, all_reduce_raw, reduce_scatter_raw, ) from torch import Tensor -from torch.cuda.amp import custom_bwd +from torch.cuda.amp import custom_bwd, custom_fwd from torch.distributed import ProcessGroup +import fused_dense_lib as fused_dense_cuda + from internlm.core.context import global_context as gpc from internlm.utils.logger import get_logger @@ -94,6 +96,109 @@ def linear_bias_wgrad_torch(my_input, grad_output, has_d_bias): grad_bias = grad_output.sum(dim=0) if has_d_bias else None return grad_weight, grad_bias +def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False, gather_dim: int = 0): + world_size = torch.distributed.get_world_size(process_group) + shape = list(input_.shape) + shape[gather_dim] = shape[gather_dim] * world_size + # output = torch.empty(world_size * input_.shape[0], *input_.shape[1:], + # dtype=input_.dtype, device=input_.device) + output = torch.empty(shape, dtype=input_.dtype, device=input_.device) + handle = torch.distributed.all_gather_into_tensor(output, input_.contiguous(), + group=process_group, async_op=async_op) + return output, handle + +class FusedDenseFunc(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward(ctx, x, weight, bias, return_residual=False, process_group=None, + sequence_parallel=True, gather_dim=0): + """ + If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel + with sequence parallelism: we do an all_gather_raw of x before doing the matmul. + """ + ctx.compute_weight_gradient = weight.requires_grad + ctx.return_residual = return_residual + ctx.process_group = process_group + ctx.sequence_parallel = sequence_parallel + ctx.gather_dim = gather_dim + + if torch.is_autocast_enabled(): + x = x.to(dtype=torch.get_autocast_gpu_dtype()) + x = x.contiguous() + if process_group is not None and sequence_parallel: + # We want to kick off the all_gather early, before weight dtype conversion + total_x, handle_x = all_gather_raw(x, process_group, async_op=True, gather_dim=gather_dim) + else: + total_x = x + + if torch.is_autocast_enabled(): + weight = weight.to(dtype=torch.get_autocast_gpu_dtype()) + bias = bias.to(dtype=torch.get_autocast_gpu_dtype()) if bias is not None else None + weight = weight.contiguous() + if process_group is not None and sequence_parallel: + handle_x.wait() + batch_shape, n = total_x.shape[:-1], total_x.shape[-1] + batch_dim = batch_shape.numel() + # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174 + if min(batch_dim, n, *weight.shape) > 65535 * 32: + raise RuntimeError('fused_dense only supports matrix dims <= 2M') + output = F.linear(total_x, weight, bias) + if ctx.compute_weight_gradient: + ctx.save_for_backward(x, weight) + else: + ctx.save_for_backward(weight) + return output if not return_residual else (output, x) + + @staticmethod + @custom_bwd + def backward(ctx, grad_output, *args): + grad_output = grad_output.contiguous() + if ctx.return_residual: + grad_input, = args + grad_input = grad_input.contiguous() + process_group = ctx.process_group + sequence_parallel = ctx.sequence_parallel + gather_dim = ctx.gather_dim + + if ctx.compute_weight_gradient: + x, weight = ctx.saved_tensors + if process_group is not None and sequence_parallel: + total_x, handle_x = all_gather_raw(x, process_group, async_op=True, gather_dim=gather_dim) + else: + total_x = x + else: + weight, = ctx.saved_tensors + total_x = None + batch_shape = grad_output.shape[:-1] + batch_dim = batch_shape.numel() + grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1]) + if ctx.needs_input_grad[0]: + if not ctx.return_residual: + grad_input = F.linear(grad_output, weight.t()) + else: + grad_input = torch.addmm(grad_input.reshape(batch_dim, grad_input.shape[-1]), + grad_output, weight) + grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1]) + if process_group is not None: + reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw + grad_input, handle_grad_input = reduce_fn(grad_input, process_group, async_op=True) + else: + grad_input = None + if ctx.needs_input_grad[1]: + assert ctx.compute_weight_gradient + if process_group is not None and sequence_parallel: + handle_x.wait() + grad_weight, grad_bias = fused_dense_cuda.linear_bias_wgrad( + total_x.reshape(batch_dim, total_x.shape[-1]), grad_output, ctx.needs_input_grad[2] + ) + else: + grad_weight = None + grad_bias = grad_output if ctx.needs_input_grad[2] else None + if process_group is not None and ctx.needs_input_grad[0]: + handle_grad_input.wait() + return grad_input, grad_weight, grad_bias, None, None, None, None + # adpated from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/fused_dense.py class FusedDenseFuncTorch(FusedDenseFunc): @@ -108,10 +213,11 @@ class FusedDenseFuncTorch(FusedDenseFunc): grad_input = grad_input.contiguous() process_group = ctx.process_group sequence_parallel = ctx.sequence_parallel + gather_dim = ctx.gather_dim if ctx.compute_weight_gradient: x, weight = ctx.saved_tensors if process_group is not None and sequence_parallel: - total_x, handle_x = all_gather_raw(x, process_group, async_op=True) + total_x, handle_x = all_gather_raw(x, process_group, async_op=True, gather_dim=gather_dim) else: total_x = x else: @@ -144,7 +250,7 @@ class FusedDenseFuncTorch(FusedDenseFunc): grad_bias = grad_output if ctx.needs_input_grad[2] else None if process_group is not None and ctx.needs_input_grad[0]: handle_grad_input.wait() - return grad_input, grad_weight, grad_bias, None, None, None + return grad_input, grad_weight, grad_bias, None, None, None, None def fused_dense_func_torch( @@ -154,14 +260,15 @@ def fused_dense_func_torch( return_residual: bool = False, process_group: Optional[ProcessGroup] = None, sequence_parallel: bool = True, + gather_dim: int = 0, ): dtype_eligible = x.dtype in [torch.float16, torch.bfloat16] or ( x.dtype == torch.float32 and torch.is_autocast_enabled() ) if x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and dtype_eligible: - return FusedDenseFunc.apply(x, weight, bias, return_residual, process_group, sequence_parallel) + return FusedDenseFunc.apply(x, weight, bias, return_residual, process_group, sequence_parallel, gather_dim) else: - return FusedDenseFuncTorch.apply(x, weight, bias, return_residual, process_group, sequence_parallel) + return FusedDenseFuncTorch.apply(x, weight, bias, return_residual, process_group, sequence_parallel, gather_dim) class _SplitForwardGatherBackward(torch.autograd.Function):