From e551b4dffc156ff0147bd12476cbf6f4b6930e5c Mon Sep 17 00:00:00 2001 From: yingtongxiong <974106207@qq.com> Date: Mon, 31 Jul 2023 20:36:04 +0800 Subject: [PATCH] fix lint --- internlm/initialize/launch.py | 12 +++-- internlm/model/linear.py | 48 ++++++++++++------- internlm/model/multi_head_attention.py | 1 + internlm/model/utils.py | 65 ++++++++++++++------------ 4 files changed, 74 insertions(+), 52 deletions(-) diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py index 4c7ea7a..9e81b0a 100644 --- a/internlm/initialize/launch.py +++ b/internlm/initialize/launch.py @@ -155,15 +155,21 @@ def args_sanity_check(): elif gpc.config.model.dtype in ("torch.float16", "torch.half"): gpc.config.model.dtype = torch.float16 elif gpc.config.model.dtype == "torch.float32": - assert gpc.config.model.use_flash_attn == False, "when using float32, the use_flash_attn must be False" + assert gpc.config.model.use_flash_attn is False, "when using float32, the use_flash_attn must be False" gpc.config.model.dtype = torch.float32 elif gpc.config.model.dtype == "torch.tf32": - assert gpc.config.model.use_flash_attn == False, "when using tf32, the use_flash_attn must be False" + assert gpc.config.model.use_flash_attn is False, "when using tf32, the use_flash_attn must be False" torch.backends.cudnn.allow_tf32 = True torch.backends.cuda.matmul.allow_tf32 = True gpc.config.model.dtype = torch.float32 else: - assert gpc.config.model.dtype in ["torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32"] + assert gpc.config.model.dtype in [ + "torch.float16", + "torch.half", + "torch.bfloat16", + "torch.float32", + "torch.tf32", + ] if gpc.is_rank_for_log(): logger.info("+" * 15 + " Model Info " + "+" * 15) # pylint: disable=W1201 diff --git a/internlm/model/linear.py b/internlm/model/linear.py index 918fff6..36704e2 100644 --- a/internlm/model/linear.py +++ b/internlm/model/linear.py @@ -5,12 +5,13 @@ from typing import Optional import torch import torch.nn.functional as F -from torch.distributed import ProcessGroup from torch import nn +from torch.distributed import ProcessGroup from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode from internlm.core.context import global_context as gpc -from internlm.model.utils import fused_dense_func_torch, reduce_scatter, all_reduce +from internlm.model.utils import all_reduce, fused_dense_func_torch, reduce_scatter + class ScaleColumnParallelLinear(nn.Linear): """ @@ -109,15 +110,20 @@ class RewardModelLinear(ScaleColumnParallelLinear): class ColumnParallelLinear(nn.Linear): - - def __init__(self, in_features: int, out_features: int, process_group: ProcessGroup, - bias: bool = True, sequence_parallel=True, device=None, dtype=None) -> None: + def __init__( + self, + in_features: int, + out_features: int, + process_group: ProcessGroup, + bias: bool = True, + sequence_parallel=True, + device=None, + dtype=None, + ) -> None: world_size = torch.distributed.get_world_size(process_group) if out_features % world_size != 0: - raise ValueError(f'out_features ({out_features}) must be divisible by ' - f'world_size ({world_size})') - super().__init__(in_features, out_features // world_size, bias=bias, - device=device, dtype=dtype) + raise ValueError(f"out_features ({out_features}) must be divisible by " f"world_size ({world_size})") + super().__init__(in_features, out_features // world_size, bias=bias, device=device, dtype=dtype) self.process_group = process_group self.sequence_parallel = sequence_parallel @@ -126,22 +132,28 @@ class 
ColumnParallelLinear(nn.Linear): # we do an all_gather of x before doing the matmul. # If not, then the input is already gathered. - return fused_dense_func_torch(x, self.weight, self.bias, process_group=self.process_group, - sequence_parallel=self.sequence_parallel) + return fused_dense_func_torch( + x, self.weight, self.bias, process_group=self.process_group, sequence_parallel=self.sequence_parallel + ) class RowParallelLinear(nn.Linear): - - def __init__(self, in_features: int, out_features: int, process_group: ProcessGroup, - bias: bool = True, sequence_parallel=True, device=None, dtype=None) -> None: + def __init__( + self, + in_features: int, + out_features: int, + process_group: ProcessGroup, + bias: bool = True, + sequence_parallel=True, + device=None, + dtype=None, + ) -> None: world_size = torch.distributed.get_world_size(process_group) rank = torch.distributed.get_rank(process_group) if in_features % world_size != 0: - raise ValueError(f'in_features ({in_features}) must be divisible by ' - f'world_size ({world_size})') + raise ValueError(f"in_features ({in_features}) must be divisible by " f"world_size ({world_size})") # Only rank 0 will have bias - super().__init__(in_features // world_size, out_features, bias=bias and rank == 0, - device=device, dtype=dtype) + super().__init__(in_features // world_size, out_features, bias=bias and rank == 0, device=device, dtype=dtype) self.process_group = process_group self.sequence_parallel = sequence_parallel diff --git a/internlm/model/multi_head_attention.py b/internlm/model/multi_head_attention.py index fe3e152..ba0c267 100644 --- a/internlm/model/multi_head_attention.py +++ b/internlm/model/multi_head_attention.py @@ -19,6 +19,7 @@ from internlm.core.context import global_context as gpc from internlm.model.embedding import RotaryEmbedding from internlm.model.linear import ColumnParallelLinear, RowParallelLinear + class MHA(nn.Module): """ Multi-head self-attention and cross-attention. 
diff --git a/internlm/model/utils.py b/internlm/model/utils.py index 9d430c7..8fddf1e 100644 --- a/internlm/model/utils.py +++ b/internlm/model/utils.py @@ -1,16 +1,15 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- +from typing import Optional + import torch import torch.nn.functional as F - -from typing import Optional +from flash_attn.ops.fused_dense import FusedDenseFunc from torch import Tensor from torch.cuda.amp import custom_bwd, custom_fwd from torch.distributed import ProcessGroup -from flash_attn.ops.fused_dense import FusedDenseFunc - from internlm.core.context import global_context as gpc @@ -80,29 +79,32 @@ class _GatherForwardSplitBackward(torch.autograd.Function): def gather_forward_split_backward(input_, parallel_mode, dim): return _GatherForwardSplitBackward.apply(input_, parallel_mode, dim) + def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False): world_size = torch.distributed.get_world_size(process_group) - output = torch.empty(world_size * input_.shape[0], *input_.shape[1:], - dtype=input_.dtype, device=input_.device) - handle = torch.distributed.all_gather_into_tensor(output, input_.contiguous(), - group=process_group, async_op=async_op) + output = torch.empty(world_size * input_.shape[0], *input_.shape[1:], dtype=input_.dtype, device=input_.device) + handle = torch.distributed.all_gather_into_tensor( + output, input_.contiguous(), group=process_group, async_op=async_op + ) return output, handle + def reduce_scatter_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False): world_size = torch.distributed.get_world_size(process_group) assert input_.shape[0] % world_size == 0 - output = torch.empty(input_.shape[0] // world_size, *input_.shape[1:], - dtype=input_.dtype, device=input_.device) - handle = torch.distributed.reduce_scatter_tensor(output, input_.contiguous(), - group=process_group, - async_op=async_op) + output = torch.empty(input_.shape[0] // world_size, *input_.shape[1:], dtype=input_.dtype, device=input_.device) + handle = torch.distributed.reduce_scatter_tensor( + output, input_.contiguous(), group=process_group, async_op=async_op + ) return output, handle + def all_reduce_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False): input_ = input_.contiguous() handle = torch.distributed.all_reduce(input_, group=process_group, async_op=async_op) return input_, handle + class ReduceScatterFunc(torch.autograd.Function): """Reduce scatter the input from the sequence parallel region and concatenate.""" @@ -137,6 +139,7 @@ reduce_scatter = ReduceScatterFunc.apply # Supports autograd, but does not support async all_reduce = AllReduceFunc.apply + def linear_bias_wgrad_torch(input, grad_output, has_d_bias): assert input.dtype == grad_output.dtype grad_weight = torch.matmul(grad_output.t(), input) @@ -145,11 +148,9 @@ def linear_bias_wgrad_torch(input, grad_output, has_d_bias): class FusedDenseFuncTorch(torch.autograd.Function): - @staticmethod @custom_fwd - def forward(ctx, x, weight, bias, return_residual=False, process_group=None, - sequence_parallel=True): + def forward(ctx, x, weight, bias, return_residual=False, process_group=None, sequence_parallel=True): """ If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel with sequence parallelism: we do an all_gather_raw of x before doing the matmul. 
@@ -178,7 +179,7 @@ class FusedDenseFuncTorch(torch.autograd.Function): batch_dim = batch_shape.numel() # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174 if min(batch_dim, n, *weight.shape) > 65535 * 32: - raise RuntimeError('fused_dense only supports matrix dims <= 2M') + raise RuntimeError("fused_dense only supports matrix dims <= 2M") output = F.linear(total_x, weight, bias) if ctx.compute_weight_gradient: ctx.save_for_backward(x, weight) @@ -191,7 +192,7 @@ class FusedDenseFuncTorch(torch.autograd.Function): def backward(ctx, grad_output, *args): grad_output = grad_output.contiguous() if ctx.return_residual: - grad_input, = args + (grad_input,) = args grad_input = grad_input.contiguous() process_group = ctx.process_group sequence_parallel = ctx.sequence_parallel @@ -202,7 +203,7 @@ class FusedDenseFuncTorch(torch.autograd.Function): else: total_x = x else: - weight, = ctx.saved_tensors + (weight,) = ctx.saved_tensors total_x = None batch_shape = grad_output.shape[:-1] batch_dim = batch_shape.numel() @@ -211,8 +212,7 @@ class FusedDenseFuncTorch(torch.autograd.Function): if not ctx.return_residual: grad_input = F.linear(grad_output, weight.t()) else: - grad_input = torch.addmm(grad_input.reshape(batch_dim, grad_input.shape[-1]), - grad_output, weight) + grad_input = torch.addmm(grad_input.reshape(batch_dim, grad_input.shape[-1]), grad_output, weight) grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1]) if process_group is not None: reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw @@ -234,15 +234,18 @@ class FusedDenseFuncTorch(torch.autograd.Function): return grad_input, grad_weight, grad_bias, None, None, None - -def fused_dense_func_torch(x: Tensor, weight: Tensor, bias: Optional[Tensor] = None, - return_residual: bool = False, process_group: Optional[ProcessGroup] = None, - sequence_parallel: bool = True): - dtype_eligible = (x.dtype in [torch.float16, torch.bfloat16] - or (x.dtype == torch.float32 and torch.is_autocast_enabled())) +def fused_dense_func_torch( + x: Tensor, + weight: Tensor, + bias: Optional[Tensor] = None, + return_residual: bool = False, + process_group: Optional[ProcessGroup] = None, + sequence_parallel: bool = True, +): + dtype_eligible = x.dtype in [torch.float16, torch.bfloat16] or ( + x.dtype == torch.float32 and torch.is_autocast_enabled() + ) if x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and dtype_eligible: - return FusedDenseFunc.apply(x, weight, bias, return_residual, process_group, - sequence_parallel) + return FusedDenseFunc.apply(x, weight, bias, return_residual, process_group, sequence_parallel) else: - return FusedDenseFuncTorch.apply(x, weight, bias, return_residual, process_group, - sequence_parallel) \ No newline at end of file + return FusedDenseFuncTorch.apply(x, weight, bias, return_residual, process_group, sequence_parallel)
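
For context on the layers whose signatures are reformatted above, here is a minimal single-process sketch of the sharding scheme that ColumnParallelLinear and RowParallelLinear implement. It uses only plain torch tensors: world_size, batch, in_features and out_features are illustrative values, and the per-rank shards are simulated by slicing one full weight matrix instead of calling the torch.distributed collectives (all_gather_raw, all_reduce_raw, reduce_scatter_raw) that the real code in internlm/model/utils.py relies on. It is a sketch of the splitting math only, not of the library's API.

    import torch

    torch.manual_seed(0)

    world_size = 4                      # simulated tensor-parallel group size
    batch, in_features, out_features = 8, 16, 32
    x = torch.randn(batch, in_features)

    # Reference: one unsharded linear layer (bias omitted to keep the comparison simple).
    full_weight = torch.randn(out_features, in_features)
    reference = x @ full_weight.t()

    # Column parallel: each "rank" owns out_features // world_size rows of the weight,
    # computes its output shard from the full input, and the shards are concatenated
    # along the feature dimension (the role played by the gather in the real layer).
    column_shards = full_weight.chunk(world_size, dim=0)
    column_out = torch.cat([x @ w.t() for w in column_shards], dim=-1)

    # Row parallel: each "rank" owns in_features // world_size columns of the weight
    # plus the matching slice of the input; the partial products are summed
    # (the role played by all_reduce / reduce_scatter in the real layer).
    row_shards = full_weight.chunk(world_size, dim=1)
    x_shards = x.chunk(world_size, dim=-1)
    row_out = sum(xs @ w.t() for xs, w in zip(x_shards, row_shards))

    assert torch.allclose(column_out, reference, atol=1e-5)
    assert torch.allclose(row_out, reference, atol=1e-5)
    print("column- and row-parallel shards reproduce the dense result")

The divisibility checks kept in the patch (out_features % world_size for the column split, in_features % world_size for the row split) are exactly what makes the chunk calls above well defined.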