@@ -58,6 +58,7 @@ def matmul_2d(
class _Classifier2D(torch.autograd.Function):

    @staticmethod
    @custom_fwd(cast_inputs=torch.float16)
    def forward(
@@ -76,7 +77,7 @@ class _Classifier2D(torch.autograd.Function):
        pipeline_parallel_size: int,
        tensor_parallel_size: int,
    ) -> Tensor:
        A = A.clone().detach()
        A_shape = A.shape
        A = A.reshape((-1, A_shape[-1]))
        B_shape = B.shape
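For readers unfamiliar with the ``@custom_fwd(cast_inputs=torch.float16)`` decorator used throughout this patch: it comes from ``torch.cuda.amp`` and, when autocast is active, casts floating-point tensor arguments to half precision before ``forward`` runs. A minimal, self-contained sketch of the pattern these Functions share (the class name and backward logic here are illustrative, not taken from the patch):

import torch
from torch.cuda.amp import custom_fwd, custom_bwd

class _FlattenedMatmul(torch.autograd.Function):
    """Hypothetical example of the casting pattern; not part of the patch."""

    @staticmethod
    @custom_fwd(cast_inputs=torch.float16)
    def forward(ctx, A, B):
        A_shape = A.shape
        A = A.reshape((-1, A_shape[-1]))  # flatten leading dims, as _Classifier2D does above
        ctx.save_for_backward(A, B)
        ctx.A_shape = A_shape
        C = torch.matmul(A, B)
        return C.reshape(A_shape[:-1] + (B.shape[-1],))

    @staticmethod
    @custom_bwd  # runs backward in the dtype autocast chose for forward
    def backward(ctx, grad_output):
        A, B = ctx.saved_tensors
        grad_output = grad_output.reshape((-1, grad_output.shape[-1]))
        grad_A = torch.matmul(grad_output, B.t()).reshape(ctx.A_shape)
        grad_B = torch.matmul(A.t(), grad_output)
        return grad_A, grad_B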
@@ -181,6 +182,7 @@ class Matmul_AB_2D(torch.autograd.Function):
    The parallel_mode should be one of the modes defined in ``ParallelMode``. More details about
    ``ParallelMode`` can be found in
    `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_.
    """

    @staticmethod
    @custom_fwd(cast_inputs=torch.float16)
    def forward(
@@ -308,6 +310,7 @@ class Matmul_ABT_2D(torch.autograd.Function):
    The parallel_mode should be one of the modes defined in ``ParallelMode``. More details about
    ``ParallelMode`` can be found in
    `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_.
    """

    @staticmethod
    @custom_fwd(cast_inputs=torch.float16)
    def forward(
@@ -440,6 +443,7 @@ class Matmul_ATB_2D(torch.autograd.Function):
    The parallel_mode should be one of the modes defined in ``ParallelMode``. More details about
    ``ParallelMode`` can be found in
    `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_.
    """

    @staticmethod
    @custom_fwd(cast_inputs=torch.float16)
    def forward(
@@ -552,6 +556,7 @@ class Matmul_ATB_2D(torch.autograd.Function):
class _Add_Bias_2D(torch.autograd.Function):

    @staticmethod
    @custom_fwd(cast_inputs=torch.float16)
    def forward(
@@ -633,6 +638,7 @@ def add_bias_2d(input_: Tensor, bias: Tensor, output_size_per_partition: int, ro
class _Layernorm_2D(torch.autograd.Function):

    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx: Any, input_: Tensor, E_x: Tensor, Var_x: Tensor, hidden_size: int, row_parallel_mode: ParallelMode,
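Note the ``cast_inputs=torch.float32`` here, in contrast to the float16 casts elsewhere in the patch: under autocast, the layer-norm statistics ``E_x`` and ``Var_x`` stay in full precision, the usual practice since mean/variance arithmetic is numerically fragile in half precision.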
@@ -689,6 +695,7 @@ def layernorm_2d(input_: Tensor, E_x: Tensor, Var_x: Tensor, hidden_size: int, r
class _AllGatherTensor2D(torch.autograd.Function):

    @staticmethod
    @custom_fwd(cast_inputs=torch.float16)
    def forward(ctx: Any, inputs: Tensor, dim: int, parallel_mode: ParallelMode) -> Tensor:
@@ -742,6 +749,7 @@ def split_batch_2d(input_: Tensor, dim: int = 0) -> Tensor:
class _ReduceTensor2D(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input_, parallel_mode):
        return all_reduce(input_, parallel_mode)
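Since the forward pass is a plain all-reduce, the matching backward is gradient-identity: after the sum, every rank holds the same output, so the incoming gradient passes through unchanged. A plausible completion of the Function (the backward is assumed; the hunk does not show it):

class _ReduceTensor2D(torch.autograd.Function):  # sketch
    @staticmethod
    def forward(ctx, input_, parallel_mode):
        return all_reduce(input_, parallel_mode)

    @staticmethod
    def backward(ctx, grad_output):
        # No collective needed on the way back; None matches the
        # non-tensor parallel_mode argument.
        return grad_output, None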
@@ -766,6 +774,7 @@ def reduce_tensor_2d(input_: Tensor, parallel_mode: ParallelMode) -> Tensor:
class _ReduceScatterTensor2D(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input_, dim, parallel_mode):
        ctx.dim = dim
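The forward here presumably finishes with a reduce-scatter; the standard duality is that a reduce-scatter in forward pairs with an all-gather in backward, making this Function the mirror image of _AllGatherTensor2D above. A sketch under that assumption, using ``reduce_scatter``/``all_gather`` helpers analogous to the ``all_reduce`` seen earlier (assumed imports):

class _ReduceScatterTensor2D(torch.autograd.Function):  # assumed completion
    @staticmethod
    def forward(ctx, input_, dim, parallel_mode):
        ctx.dim = dim
        ctx.parallel_mode = parallel_mode
        return reduce_scatter(input_, dim, parallel_mode)

    @staticmethod
    def backward(ctx, grad_output):
        # Dual collective: gather the scattered gradient shards back together.
        return all_gather(grad_output, ctx.dim, ctx.parallel_mode), None, None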
@@ -793,11 +802,12 @@ def reduce_scatter_tensor_2d(tensor: Tensor, dim: int, parallel_mode: ParallelMo
    world_size = gpc.get_world_size(parallel_mode)
    assert dim_size % world_size == 0, \
        f'The batch size ({dim_size}) is not a multiple of 2D size ({world_size}).'

    return _ReduceScatterTensor2D.apply(tensor, dim, parallel_mode)
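The assertion guards the shard math: the scattered dimension must divide evenly across the process group. For example, on a 2x2 mesh the column mode has world size 2, so a tensor with ``shape[0] == 8`` leaves each rank a shard with ``shape[0] == 4``, while ``shape[0] == 7`` fails the check before any communication starts.

# Hypothetical usage on a 2x2 mesh (column mode world size 2):
out = reduce_scatter_tensor_2d(x, dim=0, parallel_mode=ParallelMode.PARALLEL_2D_COL)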
class _ReduceByBatch2D(torch.autograd.Function):

    @staticmethod
    def symbolic(graph, input_, reduce_mean: bool = False):
        output = all_reduce(input_, ParallelMode.PARALLEL_2D_COL)
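``symbolic`` covers graph tracing; the actual ``forward`` presumably mirrors it, dividing by the column-parallel world size when ``reduce_mean`` is set. A sketch under that assumption (the backward scales the gradient by the same factor, since taking the mean is linear):

class _ReduceByBatch2D(torch.autograd.Function):  # assumed completion
    @staticmethod
    def forward(ctx, input_, reduce_mean=False):
        output = all_reduce(input_, ParallelMode.PARALLEL_2D_COL)
        if reduce_mean:
            ctx.reduce_size = gpc.get_world_size(ParallelMode.PARALLEL_2D_COL)
            return output / ctx.reduce_size
        ctx.reduce_size = None
        return output

    @staticmethod
    def backward(ctx, grad_output):
        if ctx.reduce_size is not None:
            return grad_output / ctx.reduce_size, None
        return grad_output, None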
@@ -834,4 +844,4 @@ def reduce_by_batch_2d(input_, reduce_mean: bool = False) -> Tensor:
        reduce_mean (bool, optional):
            If set to ``True``, the output will be divided by the column-parallel world size. Defaults to ``False``.
    """
    return _ReduceByBatch2D.apply(input_, reduce_mean)
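Hypothetical usage, e.g. combining per-shard metrics after a batch-split forward pass:

# `partial` is a value computed on this rank's shard of the batch (assumed variable)
total = reduce_by_batch_2d(partial)                    # sum across PARALLEL_2D_COL
mean = reduce_by_batch_2d(partial, reduce_mean=True)   # sum, then divide by world size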