[compatibility] fixed tensor parallel compatibility with torch 1.9 (#700)

pull/712/head
Frank Lee 3 years ago committed by GitHub
parent a9b8300d54
commit eda30a058e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -58,6 +58,7 @@ def matmul_2d(
class _Classifier2D(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(
@ -76,7 +77,7 @@ class _Classifier2D(torch.autograd.Function):
pipeline_parallel_size: int,
tensor_parallel_size: int,
) -> Tensor:
A = A.clone().detach()
A_shape = A.shape
A = A.reshape((-1, A_shape[-1]))
B_shape = B.shape
@ -181,6 +182,7 @@ class Matmul_AB_2D(torch.autograd.Function):
The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found
in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_
"""
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(
@ -308,6 +310,7 @@ class Matmul_ABT_2D(torch.autograd.Function):
The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found
in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_.
"""
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(
@ -440,6 +443,7 @@ class Matmul_ATB_2D(torch.autograd.Function):
The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found
in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_.
"""
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(
@ -552,6 +556,7 @@ class Matmul_ATB_2D(torch.autograd.Function):
class _Add_Bias_2D(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(
@ -633,6 +638,7 @@ def add_bias_2d(input_: Tensor, bias: Tensor, output_size_per_partition: int, ro
class _Layernorm_2D(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float32)
def forward(ctx: Any, input_: Tensor, E_x: Tensor, Var_x: Tensor, hidden_size: int, row_parallel_mode: ParallelMode,
@ -689,6 +695,7 @@ def layernorm_2d(input_: Tensor, E_x: Tensor, Var_x: Tensor, hidden_size: int, r
class _AllGatherTensor2D(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any, inputs: Tensor, dim: int, parallel_mode: ParallelMode) -> Tensor:
@ -742,6 +749,7 @@ def split_batch_2d(input_: Tensor, dim: int = 0) -> Tensor:
class _ReduceTensor2D(torch.autograd.Function):
@staticmethod
def forward(ctx, input_, parallel_mode):
return all_reduce(input_, parallel_mode)
@ -766,6 +774,7 @@ def reduce_tensor_2d(input_: Tensor, parallel_mode: ParallelMode) -> Tensor:
class _ReduceScatterTensor2D(torch.autograd.Function):
@staticmethod
def forward(ctx, input_, dim, parallel_mode):
ctx.dim = dim
@ -798,6 +807,7 @@ def reduce_scatter_tensor_2d(tensor: Tensor, dim: int, parallel_mode: ParallelMo
class _ReduceByBatch2D(torch.autograd.Function):
@staticmethod
def symbolic(graph, input_, reduce_mean: bool = False):
output = all_reduce(input_, ParallelMode.PARALLEL_2D_COL)

@ -23,25 +23,26 @@ def get_parallel_rank(parallel_mode: ParallelMode):
class _Classifier2p5D(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(
ctx: Any,
A: Tensor,
B: Tensor,
bias,
tesseract_dim: int,
out_shape: Tuple[int, ...],
row_rank: int,
col_rank: int,
row_parallel_mode: ParallelMode,
col_parallel_mode: ParallelMode,
data_parallel_rank: int,
pipeline_parallel_rank: int,
pipeline_parallel_size: int,
tensor_parallel_size: int,
ctx: Any,
A: Tensor,
B: Tensor,
bias,
tesseract_dim: int,
out_shape: Tuple[int, ...],
row_rank: int,
col_rank: int,
row_parallel_mode: ParallelMode,
col_parallel_mode: ParallelMode,
data_parallel_rank: int,
pipeline_parallel_rank: int,
pipeline_parallel_size: int,
tensor_parallel_size: int,
) -> Tensor:
A = A.clone().detach()
A_shape = A.shape
A = A.reshape((-1, A_shape[-1]))
B_shape = B.shape
@ -509,6 +510,7 @@ class Matmul_ATB_2p5D(torch.autograd.Function):
class _Add_Bias_2p5D(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any, input: Tensor, bias: Tensor, output_size_per_partition: int, tesseract_dim: int,
@ -689,6 +691,7 @@ def layernorm_2p5d(input: Tensor, E_x: Tensor, Var_x: Tensor, hidden_size: int,
class _AllGatherTensor2p5D(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any, inputs: Tensor, dim: int, col_parallel_mode: ParallelMode) -> Tensor:
@ -777,6 +780,7 @@ def split_batch_2p5d(input_: Tensor, dim: int = 0) -> Tensor:
class _ReduceTensor2p5D(torch.autograd.Function):
@staticmethod
def forward(ctx, input_, parallel_mode):
return all_reduce(input_, parallel_mode)
@ -801,6 +805,7 @@ def reduce_tensor_2p5d(input_: Tensor, parallel_mode: ParallelMode) -> Tensor:
class _ReduceScatterTensor2p5D(torch.autograd.Function):
@staticmethod
def forward(ctx, input_, dim, parallel_mode):
ctx.dim = dim
@ -833,6 +838,7 @@ def reduce_scatter_tensor_2p5d(input_: Tensor, dim: int, parallel_mode: Parallel
class _RreduceByBatch2p5D(torch.autograd.Function):
@staticmethod
def symbolic(graph, input_, reduce_mean: bool = False):
output = all_reduce(input_, ParallelMode.PARALLEL_2P5D_COL)

Loading…
Cancel
Save