diff --git a/internlm/model/linear.py b/internlm/model/linear.py
index 7089cee..c0dfcf9 100644
--- a/internlm/model/linear.py
+++ b/internlm/model/linear.py
@@ -5,9 +5,9 @@ from typing import Optional
 
 import torch
 import torch.nn.functional as F
-from torch import nn
 from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear
-from flash_attn.utils.distributed import reduce_scatter, all_reduce
+from flash_attn.utils.distributed import all_reduce, reduce_scatter
+from torch import nn
 
 from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
 from internlm.core.context import global_context as gpc
@@ -111,7 +111,6 @@ class RewardModelLinear(ScaleColumnParallelLinear):
 
 
 class ColumnParallelLinearTorch(ColumnParallelLinear):
-
     def forward(self, x):
         # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
         # we do an all_gather of x before doing the matmul.
@@ -123,7 +122,6 @@ class ColumnParallelLinearTorch(ColumnParallelLinear):
 
 
 class RowParallelLinearTorch(RowParallelLinear):
-
     def forward(self, x):
         """
         We're doing Tensor Parallel with sequence parallelism: we do the matmul and then
diff --git a/internlm/model/utils.py b/internlm/model/utils.py
index ba12578..c0a8b19 100644
--- a/internlm/model/utils.py
+++ b/internlm/model/utils.py
@@ -6,10 +6,14 @@ from typing import Optional
 import torch
 import torch.nn.functional as F
 from flash_attn.ops.fused_dense import FusedDenseFunc
+from flash_attn.utils.distributed import (
+    all_gather_raw,
+    all_reduce_raw,
+    reduce_scatter_raw,
+)
 from torch import Tensor
 from torch.cuda.amp import custom_bwd
 from torch.distributed import ProcessGroup
-from flash_attn.utils.distributed import all_gather_raw, reduce_scatter_raw, all_reduce_raw
 
 from internlm.core.context import global_context as gpc
 
@@ -90,7 +94,6 @@ def linear_bias_wgrad_torch(input, grad_output, has_d_bias):
 
 # adpated from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/fused_dense.py
 class FusedDenseFuncTorch(FusedDenseFunc):
-
     @staticmethod
     @custom_bwd
     def backward(ctx, grad_output, *args):
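
For reference, the forward() comments touched above describe the tensor-parallel + sequence-parallel communication pattern: the column-parallel linear all-gathers the sequence-sharded input before the matmul, while the row-parallel linear does the matmul and then reduce-scatters (or, without sequence parallelism, all-reduces) the partial output. Below is a minimal sketch of that pattern using plain torch.distributed collectives; it assumes an already-initialized process group, does not use the fused flash_attn kernels, and the class and helper names (ColumnParallelLinearSketch, RowParallelLinearSketch, _all_gather_seq) are hypothetical, not part of internlm or flash_attn.

# Illustrative sketch only: mirrors the collective pattern described in the
# forward() comments of the diff, not the actual fused implementation.
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import nn


def _all_gather_seq(x: torch.Tensor, group) -> torch.Tensor:
    # Gather the sequence-sharded activation from all ranks along dim 0.
    parts = [torch.empty_like(x) for _ in range(dist.get_world_size(group))]
    dist.all_gather(parts, x.contiguous(), group=group)
    return torch.cat(parts, dim=0)


class ColumnParallelLinearSketch(nn.Module):
    # Weight is split along the output dimension; with sequence parallelism the
    # sequence-sharded input is all-gathered before the matmul.
    def __init__(self, in_features, out_features_per_rank, group, sequence_parallel=True):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_features_per_rank, in_features) * 0.02)
        self.group = group
        self.sequence_parallel = sequence_parallel

    def forward(self, x):
        if self.sequence_parallel:
            x = _all_gather_seq(x, self.group)
        return F.linear(x, self.weight)


class RowParallelLinearSketch(nn.Module):
    # Weight is split along the input dimension; each rank produces a partial
    # output that is reduce-scattered (sequence parallel) or all-reduced.
    def __init__(self, in_features_per_rank, out_features, group, sequence_parallel=True):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_features, in_features_per_rank) * 0.02)
        self.group = group
        self.sequence_parallel = sequence_parallel

    def forward(self, x):
        partial = F.linear(x, self.weight)
        if self.sequence_parallel:
            chunks = [c.contiguous() for c in partial.chunk(dist.get_world_size(self.group), dim=0)]
            out = torch.empty_like(chunks[0])
            dist.reduce_scatter(out, chunks, group=self.group)
            return out
        dist.all_reduce(partial, group=self.group)
        return partial

Chaining the two (column-parallel followed by row-parallel) gives the usual Megatron-style block in which, with sequence parallelism enabled, each rank holds only its own slice of the sequence outside the matmuls and a single reduce-scatter closes the block.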