pull/155/head
yingtongxiong 2023-08-02 18:27:29 +08:00
parent 3dabb6d308
commit cf7fcbbe58
2 changed files with 7 additions and 6 deletions

File 1 of 2:

@@ -5,9 +5,9 @@ from typing import Optional

 import torch
 import torch.nn.functional as F
-from torch import nn
 from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear
-from flash_attn.utils.distributed import reduce_scatter, all_reduce
+from flash_attn.utils.distributed import all_reduce, reduce_scatter
+from torch import nn

 from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
 from internlm.core.context import global_context as gpc
@@ -111,7 +111,6 @@ class RewardModelLinear(ScaleColumnParallelLinear):


 class ColumnParallelLinearTorch(ColumnParallelLinear):
     def forward(self, x):
-
         # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
         # we do an all_gather of x before doing the matmul.
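The comment above summarizes the column-parallel pattern. A minimal sketch of the idea, assuming flash_attn's all_gather_raw helper; column_parallel_forward is a hypothetical stand-in for the real method, not the actual InternLM code:

import torch.nn.functional as F
from flash_attn.utils.distributed import all_gather_raw

def column_parallel_forward(x, weight, bias, process_group, sequence_parallel=True):
    # Hypothetical wrapper illustrating the pattern described in the comment.
    if sequence_parallel:
        # Under sequence parallelism each rank holds a slice of the sequence,
        # so the full input is gathered across the tensor-parallel group first.
        total_x, handle = all_gather_raw(x, process_group, async_op=True)
        handle.wait()
    else:
        total_x = x
    # Each rank then multiplies by its column shard of the weight and
    # produces its shard of the output features.
    return F.linear(total_x, weight, bias)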
@@ -123,7 +122,6 @@ class ColumnParallelLinearTorch(ColumnParallelLinear):


 class RowParallelLinearTorch(RowParallelLinear):
     def forward(self, x):
-
         """
         We're doing Tensor Parallel with sequence parallelism: we do the matmul and then
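The row-parallel class is the mirror image: matmul first, then a cross-rank reduction. A sketch under the same assumptions, using the all_reduce and reduce_scatter wrappers imported in this file; row_parallel_forward is again a hypothetical stand-in:

import torch.nn.functional as F
from flash_attn.utils.distributed import all_reduce, reduce_scatter

def row_parallel_forward(x, weight, bias, process_group, sequence_parallel=True):
    # Hypothetical wrapper, not the actual InternLM method.
    # Each rank sees only its shard of the input features and holds the
    # matching row shard of the weight, so this matmul yields a partial sum.
    partial = F.linear(x, weight)
    # reduce_scatter re-shards the summed result along the sequence dimension
    # (sequence parallel); all_reduce leaves every rank with the full output.
    if sequence_parallel:
        out = reduce_scatter(partial, process_group)
    else:
        out = all_reduce(partial, process_group)
    # Add the bias once, after the reduction, so it is not summed per rank.
    return out if bias is None else out + bias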

File 2 of 2:

@@ -6,10 +6,14 @@ from typing import Optional

 import torch
 import torch.nn.functional as F
 from flash_attn.ops.fused_dense import FusedDenseFunc
+from flash_attn.utils.distributed import (
+    all_gather_raw,
+    all_reduce_raw,
+    reduce_scatter_raw,
+)
 from torch import Tensor
 from torch.cuda.amp import custom_bwd
 from torch.distributed import ProcessGroup
-from flash_attn.utils.distributed import all_gather_raw, reduce_scatter_raw, all_reduce_raw

 from internlm.core.context import global_context as gpc
@@ -90,7 +94,6 @@ def linear_bias_wgrad_torch(input, grad_output, has_d_bias):

 # adpated from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/fused_dense.py
 class FusedDenseFuncTorch(FusedDenseFunc):
-
     @staticmethod
     @custom_bwd
     def backward(ctx, grad_output, *args):
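For context, linear_bias_wgrad_torch (named in the hunk header above) is the pure-PyTorch gradient helper that FusedDenseFuncTorch.backward uses in place of the fused CUDA kernel. A hedged sketch of what such a helper computes for a linear layer; the actual InternLM implementation may differ in detail:

import torch

def linear_bias_wgrad_torch(input, grad_output, has_d_bias):
    # Sketch only: weight and (optional) bias gradients for y = x @ W^T + b.
    # dW = dY^T @ X, with all leading (batch/sequence) dims flattened.
    grad_output = grad_output.reshape(-1, grad_output.shape[-1])
    grad_weight = grad_output.t().matmul(input.reshape(-1, input.shape[-1]))
    # db = dY summed over every position.
    grad_bias = grad_output.sum(dim=0) if has_d_bias else None
    return grad_weight, grad_bias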