mirror of https://github.com/InternLM/InternLM
fix lint
parent 3dabb6d308
commit cf7fcbbe58
@@ -5,9 +5,9 @@ from typing import Optional
 
 import torch
 import torch.nn.functional as F
-from torch import nn
 from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear
-from flash_attn.utils.distributed import reduce_scatter, all_reduce
+from flash_attn.utils.distributed import all_reduce, reduce_scatter
+from torch import nn
 
 from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
 from internlm.core.context import global_context as gpc

@@ -111,7 +111,6 @@ class RewardModelLinear(ScaleColumnParallelLinear):
 
 
 class ColumnParallelLinearTorch(ColumnParallelLinear):
-
     def forward(self, x):
         # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
         # we do an all_gather of x before doing the matmul.

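The comment in this hunk describes the column-parallel half of the sequence-parallel pattern: activations sharded along the sequence dimension are all-gathered to full length before the matmul against each rank's column shard of the weight. Below is a minimal illustrative sketch of that pattern, not InternLM's implementation: the class name, the choice of dim 0 as the sequence dimension, and the use of plain torch.distributed collectives (which, unlike the flash_attn helpers imported elsewhere in this diff, do not propagate gradients) are all assumptions.

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import nn


class SketchColumnParallelLinear(nn.Module):
    """Illustrative sketch only: column-parallel linear with sequence parallelism."""

    def __init__(self, in_features, out_features_per_rank, process_group, sequence_parallel=True):
        super().__init__()
        # Each rank holds a column shard of the full weight matrix.
        self.weight = nn.Parameter(torch.empty(out_features_per_rank, in_features))
        nn.init.kaiming_uniform_(self.weight)
        self.bias = nn.Parameter(torch.zeros(out_features_per_rank))
        self.process_group = process_group
        self.sequence_parallel = sequence_parallel

    def forward(self, x):
        if self.sequence_parallel:
            # All-gather the sequence-sharded activations so every rank sees
            # the full sequence before multiplying by its column shard.
            world_size = dist.get_world_size(self.process_group)
            gathered = [torch.empty_like(x) for _ in range(world_size)]
            dist.all_gather(gathered, x, group=self.process_group)
            x = torch.cat(gathered, dim=0)  # assumed sequence dimension
        # Each rank produces only its slice of the output features; a
        # row-parallel layer (see the next sketch) recombines them.
        return F.linear(x, self.weight, self.bias)
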
@@ -123,7 +122,6 @@ class ColumnParallelLinearTorch(ColumnParallelLinear):
 
 
 class RowParallelLinearTorch(RowParallelLinear):
-
     def forward(self, x):
         """
         We're doing Tensor Parallel with sequence parallelism: we do the matmul and then

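The truncated docstring refers to the mirror-image row-parallel pattern: each rank multiplies by its row shard of the weight, producing a partial sum over the sharded input dimension that must then be combined across ranks; with sequence parallelism the combine is a reduce-scatter (each rank keeps only its sequence shard), otherwise an all-reduce. A companion sketch under the same assumptions as above, additionally assuming torch.distributed.reduce_scatter_tensor (PyTorch >= 1.13):

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import nn


class SketchRowParallelLinear(nn.Module):
    """Illustrative sketch only: row-parallel linear, the pair of the sketch above."""

    def __init__(self, in_features_per_rank, out_features, process_group, sequence_parallel=True):
        super().__init__()
        # Each rank holds a row shard (a slice of the input dimension).
        self.weight = nn.Parameter(torch.empty(out_features, in_features_per_rank))
        nn.init.kaiming_uniform_(self.weight)
        self.process_group = process_group
        self.sequence_parallel = sequence_parallel

    def forward(self, x):
        # Partial result: the sum over input features is incomplete because
        # each rank only sees its shard of the input dimension.
        partial = F.linear(x, self.weight)
        if self.sequence_parallel:
            # Sum the partials while scattering the sequence dimension, so
            # each rank ends up with full features for its sequence shard.
            world_size = dist.get_world_size(self.process_group)
            out = torch.empty(
                (partial.shape[0] // world_size,) + tuple(partial.shape[1:]),
                dtype=partial.dtype,
                device=partial.device,
            )
            dist.reduce_scatter_tensor(out, partial, group=self.process_group)
            return out
        dist.all_reduce(partial, group=self.process_group)
        return partial
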
@@ -6,10 +6,14 @@ from typing import Optional
 
 import torch
 import torch.nn.functional as F
 from flash_attn.ops.fused_dense import FusedDenseFunc
+from flash_attn.utils.distributed import (
+    all_gather_raw,
+    all_reduce_raw,
+    reduce_scatter_raw,
+)
 from torch import Tensor
 from torch.cuda.amp import custom_bwd
 from torch.distributed import ProcessGroup
-from flash_attn.utils.distributed import all_gather_raw, reduce_scatter_raw, all_reduce_raw
 
 from internlm.core.context import global_context as gpc

@@ -90,7 +94,6 @@ def linear_bias_wgrad_torch(input, grad_output, has_d_bias):
 
 # adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/fused_dense.py
 class FusedDenseFuncTorch(FusedDenseFunc):
-
     @staticmethod
     @custom_bwd
     def backward(ctx, grad_output, *args):

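The hunk context shows only the signature of linear_bias_wgrad_torch. A plausible reading, given the name and the fact that FusedDenseFuncTorch overrides FusedDenseFunc's backward, is a pure-torch fallback computing the weight and bias gradients of y = x @ W^T + b; the sketch below is that guess under a hypothetical name, not the repository's code. The @custom_bwd decorator (from torch.cuda.amp) ensures the backward runs under the same autocast state as the forward.

import torch


def sketch_linear_bias_wgrad(input, grad_output, has_d_bias):
    # Collapse leading batch/sequence dims so the products below are 2-D.
    input = input.reshape(-1, input.shape[-1])
    grad_output = grad_output.reshape(-1, grad_output.shape[-1])
    grad_weight = grad_output.t() @ input  # dL/dW, shape (out_features, in_features)
    grad_bias = grad_output.sum(dim=0) if has_d_bias else None  # dL/db
    return grad_weight, grad_bias

As a sanity check, these values match what torch.autograd.grad returns for the weight and bias of F.linear on the same inputs.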