pull/155/head
yingtongxiong 2023-08-02 18:27:29 +08:00
parent 3dabb6d308
commit cf7fcbbe58
2 changed files with 7 additions and 6 deletions

@@ -5,9 +5,9 @@ from typing import Optional
 import torch
 import torch.nn.functional as F
-from torch import nn
 from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear
-from flash_attn.utils.distributed import reduce_scatter, all_reduce
+from flash_attn.utils.distributed import all_reduce, reduce_scatter
+from torch import nn
 from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
 from internlm.core.context import global_context as gpc
@@ -111,7 +111,6 @@ class RewardModelLinear(ScaleColumnParallelLinear):
 class ColumnParallelLinearTorch(ColumnParallelLinear):
     def forward(self, x):
         # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
         # we do an all_gather of x before doing the matmul.
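The comment describes the standard sequence-parallel column-parallel pattern. A minimal sketch of that pattern, assuming flash_attn's all_gather_raw helper; the free function column_parallel_forward is illustrative, not the actual ColumnParallelLinearTorch body:

import torch.nn.functional as F
from flash_attn.utils.distributed import all_gather_raw


def column_parallel_forward(x, weight, bias, process_group, sequence_parallel=True):
    # With sequence parallelism, each rank holds only a slice of the
    # sequence, so the input is all-gathered before the matmul.
    if sequence_parallel:
        total_x, handle = all_gather_raw(x, process_group, async_op=True)
        handle.wait()  # ensure the gathered input is ready
    else:
        total_x = x
    # Apply this rank's column shard of the weight; the output stays
    # split along the output-feature dimension.
    return F.linear(total_x, weight, bias)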
@@ -123,7 +122,6 @@ class ColumnParallelLinearTorch(ColumnParallelLinear):
 class RowParallelLinearTorch(RowParallelLinear):
     def forward(self, x):
         """
         We're doing Tensor Parallel with sequence parallelism: we do the matmul and then
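The docstring is truncated here; in upstream flash_attn it continues with a reduce_scatter of the result. A minimal sketch of the row-parallel pattern, assuming flash_attn's all_reduce/reduce_scatter wrappers; the free function row_parallel_forward and the post-reduction bias add are assumptions of the sketch:

import torch.nn.functional as F
from flash_attn.utils.distributed import all_reduce, reduce_scatter


def row_parallel_forward(x, weight, bias, process_group, sequence_parallel=True):
    # Each rank multiplies its input slice by its row shard of the weight,
    # yielding a partial sum of the full output.
    out = F.linear(x, weight)
    # reduce_scatter keeps the summed output split along the sequence
    # dimension (sequence parallelism); all_reduce gives every rank the
    # full output.
    reduce_fn = reduce_scatter if sequence_parallel else all_reduce
    out = reduce_fn(out, process_group)
    # Add the bias once, after the cross-rank reduction, so it is not
    # summed world_size times (a sketch assumption; real implementations
    # may instead keep the bias on a single rank).
    return out if bias is None else out + bias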

@@ -6,10 +6,14 @@ from typing import Optional
 import torch
 import torch.nn.functional as F
 from flash_attn.ops.fused_dense import FusedDenseFunc
+from flash_attn.utils.distributed import (
+    all_gather_raw,
+    all_reduce_raw,
+    reduce_scatter_raw,
+)
 from torch import Tensor
 from torch.cuda.amp import custom_bwd
 from torch.distributed import ProcessGroup
-from flash_attn.utils.distributed import all_gather_raw, reduce_scatter_raw, all_reduce_raw
 from internlm.core.context import global_context as gpc
@@ -90,7 +94,6 @@ def linear_bias_wgrad_torch(input, grad_output, has_d_bias):
 # adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/fused_dense.py
 class FusedDenseFuncTorch(FusedDenseFunc):
     @staticmethod
     @custom_bwd
     def backward(ctx, grad_output, *args):
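The hunk context names linear_bias_wgrad_torch(input, grad_output, has_d_bias), the plain-PyTorch fallback this class presumably calls in backward in place of flash_attn's fused CUDA wgrad kernel. The diff does not show its body; a sketch of what such a helper computes (the exact body is an assumption):

import torch


def linear_bias_wgrad_torch(input, grad_output, has_d_bias):
    # For y = x @ W.T + b, with input of shape (tokens, in_features) and
    # grad_output of shape (tokens, out_features):
    #   dW = dY.T @ X, and db = dY summed over the token dimension.
    assert input.dtype == grad_output.dtype
    grad_weight = torch.matmul(grad_output.t(), input)
    grad_bias = grad_output.sum(dim=0) if has_d_bias else None
    return grad_weight, grad_bias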