diff --git a/internlm/model/linear.py b/internlm/model/linear.py
index 36704e2..7089cee 100644
--- a/internlm/model/linear.py
+++ b/internlm/model/linear.py
@@ -6,11 +6,12 @@ from typing import Optional
 import torch
 import torch.nn.functional as F
 from torch import nn
-from torch.distributed import ProcessGroup
+from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear
+from flash_attn.utils.distributed import reduce_scatter, all_reduce
 
 from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
 from internlm.core.context import global_context as gpc
-from internlm.model.utils import all_reduce, fused_dense_func_torch, reduce_scatter
+from internlm.model.utils import fused_dense_func_torch
 
 
 class ScaleColumnParallelLinear(nn.Linear):
@@ -109,23 +110,7 @@ class RewardModelLinear(ScaleColumnParallelLinear):
         )
 
 
-class ColumnParallelLinear(nn.Linear):
-    def __init__(
-        self,
-        in_features: int,
-        out_features: int,
-        process_group: ProcessGroup,
-        bias: bool = True,
-        sequence_parallel=True,
-        device=None,
-        dtype=None,
-    ) -> None:
-        world_size = torch.distributed.get_world_size(process_group)
-        if out_features % world_size != 0:
-            raise ValueError(f"out_features ({out_features}) must be divisible by " f"world_size ({world_size})")
-        super().__init__(in_features, out_features // world_size, bias=bias, device=device, dtype=dtype)
-        self.process_group = process_group
-        self.sequence_parallel = sequence_parallel
+class ColumnParallelLinearTorch(ColumnParallelLinear):
 
     def forward(self, x):
         # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
@@ -137,25 +122,7 @@ class ColumnParallelLinear(nn.Linear):
         )
 
 
-class RowParallelLinear(nn.Linear):
-    def __init__(
-        self,
-        in_features: int,
-        out_features: int,
-        process_group: ProcessGroup,
-        bias: bool = True,
-        sequence_parallel=True,
-        device=None,
-        dtype=None,
-    ) -> None:
-        world_size = torch.distributed.get_world_size(process_group)
-        rank = torch.distributed.get_rank(process_group)
-        if in_features % world_size != 0:
-            raise ValueError(f"in_features ({in_features}) must be divisible by " f"world_size ({world_size})")
-        # Only rank 0 will have bias
-        super().__init__(in_features // world_size, out_features, bias=bias and rank == 0, device=device, dtype=dtype)
-        self.process_group = process_group
-        self.sequence_parallel = sequence_parallel
+class RowParallelLinearTorch(RowParallelLinear):
 
     def forward(self, x):
         """
@@ -198,7 +165,7 @@ class FeedForward(nn.Module):
 
         hidden_features = multiple_of * ((hidden_features + multiple_of - 1) // multiple_of)
 
-        self.w1 = ColumnParallelLinear(
+        self.w1 = ColumnParallelLinearTorch(
             in_features,
             hidden_features,
             process_group,
@@ -207,10 +174,10 @@ class FeedForward(nn.Module):
             device=device,
             dtype=dtype,
         )
-        self.w2 = ColumnParallelLinear(
+        self.w2 = ColumnParallelLinearTorch(
             in_features, hidden_features, process_group, bias, sequence_parallel=False, device=device, dtype=dtype
         )
-        self.w3 = RowParallelLinear(
+        self.w3 = RowParallelLinearTorch(
             hidden_features,
             out_features,
             process_group,
diff --git a/internlm/model/multi_head_attention.py b/internlm/model/multi_head_attention.py
index ba0c267..096c4e6 100644
--- a/internlm/model/multi_head_attention.py
+++ b/internlm/model/multi_head_attention.py
@@ -17,7 +17,7 @@ from torch import nn
 from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.model.embedding import RotaryEmbedding
-from internlm.model.linear import ColumnParallelLinear, RowParallelLinear
+from internlm.model.linear import ColumnParallelLinearTorch, RowParallelLinearTorch
 
 
 class MHA(nn.Module):
@@ -78,7 +78,7 @@ class MHA(nn.Module):
             self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base, device=device)
 
         # notice here should change bias=True
-        self.Wqkv = ColumnParallelLinear(
+        self.Wqkv = ColumnParallelLinearTorch(
             embed_dim,
             3 * embed_dim,
             process_group,
@@ -95,7 +95,7 @@ class MHA(nn.Module):
         )
 
         # output projection always have the bias (for now)
-        self.out_proj = RowParallelLinear(
+        self.out_proj = RowParallelLinearTorch(
             embed_dim, embed_dim, process_group, sequence_parallel=sequence_parallel, **factory_kwargs
         )
         # need to assign tp attribute so that internlm know it is tensor parallel module
diff --git a/internlm/model/utils.py b/internlm/model/utils.py
index 3c0dd9b..ba12578 100644
--- a/internlm/model/utils.py
+++ b/internlm/model/utils.py
@@ -7,8 +7,9 @@ import torch
 import torch.nn.functional as F
 from flash_attn.ops.fused_dense import FusedDenseFunc
 from torch import Tensor
-from torch.cuda.amp import custom_bwd, custom_fwd
+from torch.cuda.amp import custom_bwd
 from torch.distributed import ProcessGroup
+from flash_attn.utils.distributed import all_gather_raw, reduce_scatter_raw, all_reduce_raw
 
 from internlm.core.context import global_context as gpc
 
@@ -80,67 +81,6 @@ def gather_forward_split_backward(input_, parallel_mode, dim):
     return _GatherForwardSplitBackward.apply(input_, parallel_mode, dim)
 
 
-# the following communicators are adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/utils/distributed.py
-def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
-    world_size = torch.distributed.get_world_size(process_group)
-    output = torch.empty(world_size * input_.shape[0], *input_.shape[1:], dtype=input_.dtype, device=input_.device)
-    handle = torch.distributed.all_gather_into_tensor(
-        output, input_.contiguous(), group=process_group, async_op=async_op
-    )
-    return output, handle
-
-
-def reduce_scatter_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
-    world_size = torch.distributed.get_world_size(process_group)
-    assert input_.shape[0] % world_size == 0
-    output = torch.empty(input_.shape[0] // world_size, *input_.shape[1:], dtype=input_.dtype, device=input_.device)
-    handle = torch.distributed.reduce_scatter_tensor(
-        output, input_.contiguous(), group=process_group, async_op=async_op
-    )
-    return output, handle
-
-
-def all_reduce_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
-    input_ = input_.contiguous()
-    handle = torch.distributed.all_reduce(input_, group=process_group, async_op=async_op)
-    return input_, handle
-
-
-class ReduceScatterFunc(torch.autograd.Function):
-    """Reduce scatter the input from the sequence parallel region and concatenate."""
-
-    @staticmethod
-    def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
-        ctx.process_group = process_group
-        output, _ = reduce_scatter_raw(input_, process_group)
-        return output
-
-    @staticmethod
-    def backward(ctx, grad_output: Tensor):
-        grad_input, _ = all_gather_raw(grad_output, ctx.process_group)
-        return grad_input, None
-
-
-class AllReduceFunc(torch.autograd.Function):
-    """Gather the input from sequence parallel region and concatenate."""
-
-    @staticmethod
-    def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
-        ctx.process_group = process_group
-        output, _ = all_reduce_raw(input_, process_group)
-        return output
-
-    @staticmethod
-    def backward(ctx, grad_output: Tensor):
-        return grad_output, None
-
-
-# Supports autograd, but does not support async
-reduce_scatter = ReduceScatterFunc.apply
-# Supports autograd, but does not support async
-all_reduce = AllReduceFunc.apply
-
-
 def linear_bias_wgrad_torch(input, grad_output, has_d_bias):
     assert input.dtype == grad_output.dtype
     grad_weight = torch.matmul(grad_output.t(), input)
@@ -149,45 +89,7 @@ def linear_bias_wgrad_torch(input, grad_output, has_d_bias):
 
 
 # adpated from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/fused_dense.py
-class FusedDenseFuncTorch(torch.autograd.Function):
-    @staticmethod
-    @custom_fwd
-    def forward(ctx, x, weight, bias, return_residual=False, process_group=None, sequence_parallel=True):
-        """
-        If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel
-        with sequence parallelism: we do an all_gather_raw of x before doing the matmul.
-        """
-        ctx.compute_weight_gradient = weight.requires_grad
-        ctx.return_residual = return_residual
-        ctx.process_group = process_group
-        ctx.sequence_parallel = sequence_parallel
-
-        if torch.is_autocast_enabled():
-            x = x.to(dtype=torch.get_autocast_gpu_dtype())
-        x = x.contiguous()
-        if process_group is not None and sequence_parallel:
-            # We want to kick off the all_gather early, before weight dtype conversion
-            total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
-        else:
-            total_x = x
-
-        if torch.is_autocast_enabled():
-            weight = weight.to(dtype=torch.get_autocast_gpu_dtype())
-            bias = bias.to(dtype=torch.get_autocast_gpu_dtype()) if bias is not None else None
-        weight = weight.contiguous()
-        if process_group is not None and sequence_parallel:
-            handle_x.wait()
-        batch_shape, n = total_x.shape[:-1], total_x.shape[-1]
-        batch_dim = batch_shape.numel()
-        # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174
-        if min(batch_dim, n, *weight.shape) > 65535 * 32:
-            raise RuntimeError("fused_dense only supports matrix dims <= 2M")
-        output = F.linear(total_x, weight, bias)
-        if ctx.compute_weight_gradient:
-            ctx.save_for_backward(x, weight)
-        else:
-            ctx.save_for_backward(weight)
-        return output if not return_residual else (output, x)
+class FusedDenseFuncTorch(FusedDenseFunc):
 
     @staticmethod
     @custom_bwd
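
For context: the patch drops InternLM's locally copied tensor-parallel linear layers and communicator wrappers in favor of flash_attn's implementations, keeping only thin subclasses (ColumnParallelLinearTorch, RowParallelLinearTorch, FusedDenseFuncTorch) that override the parts still needing the pure-PyTorch fused_dense_func_torch path. Below is a minimal sketch of that subclass-and-override pattern only; the Upstream* base class and the single-rank F.linear fallback are simplified stand-ins invented for illustration, not the real flash_attn classes, which shard the weight across a torch.distributed process group and need the fused_dense CUDA extension.

# Sketch of the wrapper pattern applied in this diff (hypothetical stand-ins).
import torch
import torch.nn.functional as F
from torch import nn


class UpstreamColumnParallelLinear(nn.Linear):
    """Stand-in for flash_attn.ops.fused_dense.ColumnParallelLinear (world size 1, no sharding)."""

    def __init__(self, in_features, out_features, process_group=None, bias=True, sequence_parallel=True):
        super().__init__(in_features, out_features, bias=bias)
        self.process_group = process_group
        self.sequence_parallel = sequence_parallel


class ColumnParallelLinearTorchSketch(UpstreamColumnParallelLinear):
    def forward(self, x):
        # The real subclass routes through fused_dense_func_torch(), which
        # all-gathers x when sequence_parallel=True; with a single rank this
        # reduces to an ordinary linear layer.
        return F.linear(x, self.weight, self.bias)


if __name__ == "__main__":
    layer = ColumnParallelLinearTorchSketch(8, 16, process_group=None)
    print(layer(torch.randn(2, 8)).shape)  # torch.Size([2, 16])

The same idea carries the rest of the patch: FusedDenseFuncTorch now inherits flash_attn's FusedDenseFunc.forward and keeps only its custom @custom_bwd backward, and utils.py imports all_gather_raw / reduce_scatter_raw / all_reduce_raw from flash_attn.utils.distributed instead of maintaining local copies.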