diff --git a/internlm/core/scheduler/no_pipeline_scheduler.py b/internlm/core/scheduler/no_pipeline_scheduler.py index 6777acc..56661d8 100644 --- a/internlm/core/scheduler/no_pipeline_scheduler.py +++ b/internlm/core/scheduler/no_pipeline_scheduler.py @@ -202,8 +202,10 @@ class NonPipelineScheduler(BaseScheduler): if return_output_label: outputs.append(_output) labels.append(_label) + if not return_output_label: outputs, labels = None, None + # Compatible for non-moe if hasattr(gpc.config.model, "num_experts"): return outputs, labels, loss, moe_loss diff --git a/internlm/model/linear.py b/internlm/model/linear.py index fbe6f14..4075e9e 100644 --- a/internlm/model/linear.py +++ b/internlm/model/linear.py @@ -6,17 +6,13 @@ from typing import Optional import torch import torch.nn.functional as F from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear -from flash_attn.utils.distributed import all_reduce, reduce_scatter, all_gather_raw, reduce_scatter_raw -from torch import Tensor +from flash_attn.utils.distributed import all_reduce, reduce_scatter from torch import nn -from torch.cuda.amp import custom_bwd, custom_fwd -# import fused_dense_cuda # from apex -import fused_dense_lib as fused_dense_cuda from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc -from internlm.model.utils import Silu, fused_dense_func_torch +from internlm.model.utils import Silu, fused_dense_func_torch, fsdp_fused_dense_func class ScaleColumnParallelLinear(nn.Linear): @@ -208,116 +204,6 @@ class FeedForward(nn.Module): out = self.w3(Silu(w1_o, w2_o)) return out -class FSDPFusedDenseFunc(torch.autograd.Function): - - @staticmethod - @custom_fwd - def forward(ctx, x, weight, bias, return_residual=False, process_group=None): - - ctx.compute_weight_gradient = weight.requires_grad - ctx.return_residual = return_residual - ctx.process_group = process_group - - if torch.is_autocast_enabled(): - x = x.to(dtype=torch.get_autocast_gpu_dtype()) - total_x = x.contiguous() - - world_size = gpc.get_world_size(ParallelMode.TENSOR) - if world_size > 1: - # do all_gather for weight and bias before actual computation - total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True) - if bias is not None: - total_bias, handle_bias = all_gather_raw(bias, process_group, async_op=True) - handle_bias.wait() - else: - total_bias = bias - handle_weight.wait() - else: - total_weight = weight - total_bias = bias - - if torch.is_autocast_enabled(): - total_weight = total_weight.to(dtype=torch.get_autocast_gpu_dtype()) - total_bias = total_bias.to(dtype=torch.get_autocast_gpu_dtype()) if bias is not None else None - - total_weight = total_weight.contiguous() - batch_shape, n = total_x.shape[:-1], total_x.shape[-1] - batch_dim = batch_shape.numel() - # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174 - if min(batch_dim, n, *total_weight.shape) > 65535 * 32: - raise RuntimeError('fused_dense only supports matrix dims <= 2M') - output = F.linear(total_x, total_weight, total_bias) - if ctx.compute_weight_gradient: - ctx.save_for_backward(x, weight) - else: - ctx.save_for_backward(weight) - return output if not return_residual else (output, x) - - @staticmethod - @custom_bwd - def backward(ctx, grad_output, *args): - grad_output = grad_output.contiguous() - if ctx.return_residual: - grad_input, = args - grad_input = grad_input.contiguous() - process_group = ctx.process_group - if ctx.compute_weight_gradient: - x, weight = ctx.saved_tensors - total_x = x - else: - weight, = ctx.saved_tensors - total_x = None - batch_shape = grad_output.shape[:-1] - batch_dim = batch_shape.numel() - grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1]) - - world_size = gpc.get_world_size(ParallelMode.TENSOR) - if world_size > 1: - # do all-gather for weight before backward - total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True) - handle_weight.wait() - else: - total_weight = weight - - if ctx.needs_input_grad[0]: - if not ctx.return_residual: - grad_input = F.linear(grad_output, total_weight.t()) - else: - grad_input = torch.addmm(grad_input.reshape(batch_dim, grad_input.shape[-1]), - grad_output, total_weight) - grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1]) - else: - grad_input = None - - if ctx.needs_input_grad[1]: - assert ctx.compute_weight_gradient - - grad_weight, grad_bias = fused_dense_cuda.linear_bias_wgrad( - total_x.reshape(batch_dim, total_x.shape[-1]), grad_output, ctx.needs_input_grad[2] - ) - if world_size > 1: - grad_weight, handle_grad_weight = reduce_scatter_raw(grad_weight, process_group, async_op=True) - if grad_bias is not None: - grad_bias, handle_grad_bias = reduce_scatter_raw(grad_bias, process_group, async_op=True) - handle_grad_bias.wait() - handle_grad_weight.wait() - else: - grad_weight = None - grad_bias = grad_output if ctx.needs_input_grad[2] else None - return grad_input, grad_weight, grad_bias, None, None, None - - -def fsdp_fused_dense_func(x: Tensor, weight: Tensor, bias: Optional[Tensor] = None, - return_residual: bool = False, process_group = None): - dtype_eligible = (x.dtype in [torch.float16, torch.bfloat16] - or (x.dtype == torch.float32 and torch.is_autocast_enabled())) - if x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and dtype_eligible: - return FSDPFusedDenseFunc.apply(x, weight, bias, return_residual, process_group) - else: - assert process_group is None - out = F.linear(x, weight, bias) - return out if not return_residual else (out, x) - class FSDPLinear(ColumnParallelLinear): def forward(self, x): diff --git a/internlm/model/utils.py b/internlm/model/utils.py index 33c8c46..c884544 100644 --- a/internlm/model/utils.py +++ b/internlm/model/utils.py @@ -5,9 +5,7 @@ from typing import Optional import torch import torch.nn.functional as F -# from flash_attn.ops.fused_dense import FusedDenseFunc from flash_attn.utils.distributed import ( - # all_gather_raw, all_reduce_raw, reduce_scatter_raw, ) @@ -17,6 +15,7 @@ from torch.distributed import ProcessGroup import fused_dense_lib as fused_dense_cuda +from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc from internlm.utils.logger import get_logger @@ -90,23 +89,53 @@ def gather_forward_split_backward(input_, parallel_mode, dim): return _GatherForwardSplitBackward.apply(input_, parallel_mode, dim) +class _SplitForwardGatherBackward(torch.autograd.Function): + """ + Split the input and keep only the corresponding chuck to the rank. + + Args: + input_: input matrix. + parallel_mode: parallel mode. + dim: dimension + """ + + @staticmethod + def symbolic(input_): + return _split(input_, parallel_mode=None) + + @staticmethod + def forward(ctx, input_, parallel_mode, dim): + ctx.mode = parallel_mode + ctx.dim = dim + return _split(input_, parallel_mode, dim) + + @staticmethod + def backward(ctx, grad_output): + return _gather(grad_output, ctx.mode, ctx.dim), None, None + + +def split_forward_gather_backward(input_, parallel_mode, dim): + return _SplitForwardGatherBackward.apply(input_, parallel_mode, dim) + + +def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False, gather_dim: int = 0): + world_size = torch.distributed.get_world_size(process_group) + shape = list(input_.shape) + shape[gather_dim] = shape[gather_dim] * world_size + output = torch.empty(shape, dtype=input_.dtype, device=input_.device) + handle = torch.distributed.all_gather_into_tensor(output, input_.contiguous(), + group=process_group, async_op=async_op) + return output, handle + + def linear_bias_wgrad_torch(my_input, grad_output, has_d_bias): assert my_input.dtype == grad_output.dtype grad_weight = torch.matmul(grad_output.t(), my_input) grad_bias = grad_output.sum(dim=0) if has_d_bias else None return grad_weight, grad_bias -def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False, gather_dim: int = 0): - world_size = torch.distributed.get_world_size(process_group) - shape = list(input_.shape) - shape[gather_dim] = shape[gather_dim] * world_size - # output = torch.empty(world_size * input_.shape[0], *input_.shape[1:], - # dtype=input_.dtype, device=input_.device) - output = torch.empty(shape, dtype=input_.dtype, device=input_.device) - handle = torch.distributed.all_gather_into_tensor(output, input_.contiguous(), - group=process_group, async_op=async_op) - return output, handle +# adpated from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/fused_dense.py class FusedDenseFunc(torch.autograd.Function): @staticmethod @@ -253,6 +282,105 @@ class FusedDenseFuncTorch(FusedDenseFunc): return grad_input, grad_weight, grad_bias, None, None, None, None +class FSDPFusedDenseFunc(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward(ctx, x, weight, bias, return_residual=False, process_group=None): + + ctx.compute_weight_gradient = weight.requires_grad + ctx.return_residual = return_residual + ctx.process_group = process_group + + if torch.is_autocast_enabled(): + x = x.to(dtype=torch.get_autocast_gpu_dtype()) + total_x = x.contiguous() + + world_size = gpc.get_world_size(ParallelMode.TENSOR) + if world_size > 1: + # do all_gather for weight and bias before actual computation + total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True) + if bias is not None: + total_bias, handle_bias = all_gather_raw(bias, process_group, async_op=True) + handle_bias.wait() + else: + total_bias = bias + handle_weight.wait() + else: + total_weight = weight + total_bias = bias + + if torch.is_autocast_enabled(): + total_weight = total_weight.to(dtype=torch.get_autocast_gpu_dtype()) + total_bias = total_bias.to(dtype=torch.get_autocast_gpu_dtype()) if bias is not None else None + + total_weight = total_weight.contiguous() + batch_shape, n = total_x.shape[:-1], total_x.shape[-1] + batch_dim = batch_shape.numel() + # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174 + if min(batch_dim, n, *total_weight.shape) > 65535 * 32: + raise RuntimeError('fused_dense only supports matrix dims <= 2M') + output = F.linear(total_x, total_weight, total_bias) + if ctx.compute_weight_gradient: + ctx.save_for_backward(x, weight) + else: + ctx.save_for_backward(weight) + return output if not return_residual else (output, x) + + @staticmethod + @custom_bwd + def backward(ctx, grad_output, *args): + grad_output = grad_output.contiguous() + if ctx.return_residual: + grad_input, = args + grad_input = grad_input.contiguous() + process_group = ctx.process_group + if ctx.compute_weight_gradient: + x, weight = ctx.saved_tensors + total_x = x + else: + weight, = ctx.saved_tensors + total_x = None + batch_shape = grad_output.shape[:-1] + batch_dim = batch_shape.numel() + grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1]) + + world_size = gpc.get_world_size(ParallelMode.TENSOR) + if world_size > 1: + # do all-gather for weight before backward + total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True) + handle_weight.wait() + else: + total_weight = weight + + if ctx.needs_input_grad[0]: + if not ctx.return_residual: + grad_input = F.linear(grad_output, total_weight.t()) + else: + grad_input = torch.addmm(grad_input.reshape(batch_dim, grad_input.shape[-1]), + grad_output, total_weight) + grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1]) + else: + grad_input = None + + if ctx.needs_input_grad[1]: + assert ctx.compute_weight_gradient + + grad_weight, grad_bias = fused_dense_cuda.linear_bias_wgrad( + total_x.reshape(batch_dim, total_x.shape[-1]), grad_output, ctx.needs_input_grad[2] + ) + if world_size > 1: + grad_weight, handle_grad_weight = reduce_scatter_raw(grad_weight, process_group, async_op=True) + if grad_bias is not None: + grad_bias, handle_grad_bias = reduce_scatter_raw(grad_bias, process_group, async_op=True) + handle_grad_bias.wait() + handle_grad_weight.wait() + else: + grad_weight = None + grad_bias = grad_output if ctx.needs_input_grad[2] else None + return grad_input, grad_weight, grad_bias, None, None, None + + def fused_dense_func_torch( x: Tensor, weight: Tensor, @@ -271,33 +399,16 @@ def fused_dense_func_torch( return FusedDenseFuncTorch.apply(x, weight, bias, return_residual, process_group, sequence_parallel, gather_dim) -class _SplitForwardGatherBackward(torch.autograd.Function): - """ - Split the input and keep only the corresponding chuck to the rank. - - Args: - input_: input matrix. - parallel_mode: parallel mode. - dim: dimension - """ - - @staticmethod - def symbolic(input_): - return _split(input_, parallel_mode=None) - - @staticmethod - def forward(ctx, input_, parallel_mode, dim): - ctx.mode = parallel_mode - ctx.dim = dim - return _split(input_, parallel_mode, dim) - - @staticmethod - def backward(ctx, grad_output): - return _gather(grad_output, ctx.mode, ctx.dim), None, None - - -def split_forward_gather_backward(input_, parallel_mode, dim): - return _SplitForwardGatherBackward.apply(input_, parallel_mode, dim) +def fsdp_fused_dense_func(x: Tensor, weight: Tensor, bias: Optional[Tensor] = None, + return_residual: bool = False, process_group = None): + dtype_eligible = (x.dtype in [torch.float16, torch.bfloat16] + or (x.dtype == torch.float32 and torch.is_autocast_enabled())) + if x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and dtype_eligible: + return FSDPFusedDenseFunc.apply(x, weight, bias, return_residual, process_group) + else: + assert process_group is None + out = F.linear(x, weight, bias) + return out if not return_residual else (out, x) def try_import_RMSNorm():