mirror of https://github.com/hpcaitech/ColossalAI
[compatibility] fixed tensor parallel compatibility with torch 1.9 (#700)
parent
a9b8300d54
commit
eda30a058e
|
@ -58,6 +58,7 @@ def matmul_2d(
|
|||
|
||||
|
||||
class _Classifier2D(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float16)
|
||||
def forward(
|
||||
|
@ -76,7 +77,7 @@ class _Classifier2D(torch.autograd.Function):
|
|||
pipeline_parallel_size: int,
|
||||
tensor_parallel_size: int,
|
||||
) -> Tensor:
|
||||
|
||||
A = A.clone().detach()
|
||||
A_shape = A.shape
|
||||
A = A.reshape((-1, A_shape[-1]))
|
||||
B_shape = B.shape
|
||||
|
@ -181,6 +182,7 @@ class Matmul_AB_2D(torch.autograd.Function):
|
|||
The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found
|
||||
in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float16)
|
||||
def forward(
|
||||
|
@ -308,6 +310,7 @@ class Matmul_ABT_2D(torch.autograd.Function):
|
|||
The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found
|
||||
in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float16)
|
||||
def forward(
|
||||
|
@ -440,6 +443,7 @@ class Matmul_ATB_2D(torch.autograd.Function):
|
|||
The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found
|
||||
in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float16)
|
||||
def forward(
|
||||
|
@ -552,6 +556,7 @@ class Matmul_ATB_2D(torch.autograd.Function):
|
|||
|
||||
|
||||
class _Add_Bias_2D(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float16)
|
||||
def forward(
|
||||
|
@ -633,6 +638,7 @@ def add_bias_2d(input_: Tensor, bias: Tensor, output_size_per_partition: int, ro
|
|||
|
||||
|
||||
class _Layernorm_2D(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float32)
|
||||
def forward(ctx: Any, input_: Tensor, E_x: Tensor, Var_x: Tensor, hidden_size: int, row_parallel_mode: ParallelMode,
|
||||
|
@ -689,6 +695,7 @@ def layernorm_2d(input_: Tensor, E_x: Tensor, Var_x: Tensor, hidden_size: int, r
|
|||
|
||||
|
||||
class _AllGatherTensor2D(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float16)
|
||||
def forward(ctx: Any, inputs: Tensor, dim: int, parallel_mode: ParallelMode) -> Tensor:
|
||||
|
@ -742,6 +749,7 @@ def split_batch_2d(input_: Tensor, dim: int = 0) -> Tensor:
|
|||
|
||||
|
||||
class _ReduceTensor2D(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input_, parallel_mode):
|
||||
return all_reduce(input_, parallel_mode)
|
||||
|
@ -766,6 +774,7 @@ def reduce_tensor_2d(input_: Tensor, parallel_mode: ParallelMode) -> Tensor:
|
|||
|
||||
|
||||
class _ReduceScatterTensor2D(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input_, dim, parallel_mode):
|
||||
ctx.dim = dim
|
||||
|
@ -798,6 +807,7 @@ def reduce_scatter_tensor_2d(tensor: Tensor, dim: int, parallel_mode: ParallelMo
|
|||
|
||||
|
||||
class _ReduceByBatch2D(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def symbolic(graph, input_, reduce_mean: bool = False):
|
||||
output = all_reduce(input_, ParallelMode.PARALLEL_2D_COL)
|
||||
|
|
|
@ -23,6 +23,7 @@ def get_parallel_rank(parallel_mode: ParallelMode):
|
|||
|
||||
|
||||
class _Classifier2p5D(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float16)
|
||||
def forward(
|
||||
|
@ -41,7 +42,7 @@ class _Classifier2p5D(torch.autograd.Function):
|
|||
pipeline_parallel_size: int,
|
||||
tensor_parallel_size: int,
|
||||
) -> Tensor:
|
||||
|
||||
A = A.clone().detach()
|
||||
A_shape = A.shape
|
||||
A = A.reshape((-1, A_shape[-1]))
|
||||
B_shape = B.shape
|
||||
|
@ -509,6 +510,7 @@ class Matmul_ATB_2p5D(torch.autograd.Function):
|
|||
|
||||
|
||||
class _Add_Bias_2p5D(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float16)
|
||||
def forward(ctx: Any, input: Tensor, bias: Tensor, output_size_per_partition: int, tesseract_dim: int,
|
||||
|
@ -689,6 +691,7 @@ def layernorm_2p5d(input: Tensor, E_x: Tensor, Var_x: Tensor, hidden_size: int,
|
|||
|
||||
|
||||
class _AllGatherTensor2p5D(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float16)
|
||||
def forward(ctx: Any, inputs: Tensor, dim: int, col_parallel_mode: ParallelMode) -> Tensor:
|
||||
|
@ -777,6 +780,7 @@ def split_batch_2p5d(input_: Tensor, dim: int = 0) -> Tensor:
|
|||
|
||||
|
||||
class _ReduceTensor2p5D(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input_, parallel_mode):
|
||||
return all_reduce(input_, parallel_mode)
|
||||
|
@ -801,6 +805,7 @@ def reduce_tensor_2p5d(input_: Tensor, parallel_mode: ParallelMode) -> Tensor:
|
|||
|
||||
|
||||
class _ReduceScatterTensor2p5D(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input_, dim, parallel_mode):
|
||||
ctx.dim = dim
|
||||
|
@ -833,6 +838,7 @@ def reduce_scatter_tensor_2p5d(input_: Tensor, dim: int, parallel_mode: Parallel
|
|||
|
||||
|
||||
class _RreduceByBatch2p5D(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def symbolic(graph, input_, reduce_mean: bool = False):
|
||||
output = all_reduce(input_, ParallelMode.PARALLEL_2P5D_COL)
|
||||
|
|
Loading…
Reference in New Issue