@@ -244,7 +244,7 @@ class _Layernorm3D(torch.autograd.Function):
 def layernorm_3d(input_: Tensor, weight: Tensor, bias: Tensor, normalized_shape: int, eps: float,
                  input_parallel_mode: ParallelMode, weight_parallel_mode: ParallelMode,
                  output_parallel_mode: ParallelMode) -> Tensor:
-    """
+    r"""
     3D parallel Layernorm
 
     :param input_: input matrix
@@ -253,8 +253,9 @@ def layernorm_3d(input_: Tensor, weight: Tensor, bias: Tensor, normalized_shape:
     :type weight: torch.tensor
     :param bias: matrix of bias
     :type bias: torch.tensor
-    :param normalized_shape: input shape from an expected input
-        of size. :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] \times \ldots \times \text{normalized_shape}[-1]]`
+    :param normalized_shape: input shape from an expected input of size.
+        :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1]
+        \times \ldots \times \text{normalized_shape}[-1]]`
         If a single integer is used, it is treated as a singleton list, and this module will
         normalize over the last dimension which is expected to be of that specific size.
     :type normalized_shape: int
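For reference, the unparallelized computation that layernorm_3d distributes is ordinary layer normalization over the last dimension. A minimal single-process sketch (layernorm_reference is an illustrative name, not part of this patch):

import torch

def layernorm_reference(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor,
                        eps: float = 1e-5) -> torch.Tensor:
    # Normalize over the last dimension, i.e. normalized_shape given as a single int.
    mean = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, unbiased=False, keepdim=True)
    return (x - mean) / torch.sqrt(var + eps) * weight + bias

x = torch.randn(4, 16, 256)                      # (batch, sequence, hidden)
w, b = torch.ones(256), torch.zeros(256)
ref = layernorm_reference(x, w, b)
out = torch.nn.functional.layer_norm(x, (256,), w, b)
print((ref - out).abs().max())                   # ~0: matches F.layer_norm

The 3D variant shards the hidden dimension, so the mean/variance statistics also have to be combined across tensor-parallel process groups; the extra parallel-mode arguments in the signature select those groups.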
@@ -282,7 +283,7 @@ def split_tensor_3d(tensor: Tensor, dim: int, parallel_mode: ParallelMode) -> Te
     :type tensor: torch.Tensor
     :type dim: int
     :type parallel_mode: colossalai.context.parallel_mode.ParallelMode
-
+
     :return output: Split tensor
     :rtype output: torch.Tensor
     """
@@ -294,9 +295,9 @@ def split_tensor_3d(tensor: Tensor, dim: int, parallel_mode: ParallelMode) -> Te
 
 
 def split_batch_3d(input_: Tensor,
-                   dim: int = 0,
-                   input_parallel_mode: ParallelMode = ParallelMode.PARALLEL_3D_INPUT,
-                   weight_parallel_mode: ParallelMode = ParallelMode.PARALLEL_3D_WEIGHT) -> Tensor:
+                   dim: int = 0,
+                   input_parallel_mode: ParallelMode = ParallelMode.PARALLEL_3D_INPUT,
+                   weight_parallel_mode: ParallelMode = ParallelMode.PARALLEL_3D_WEIGHT) -> Tensor:
     """Splits 3D tensor in batch
     :param input_: Input tensor
     :param dim: Specified dimension in which to split
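A hypothetical call site, assuming a 3D tensor-parallel context has already been launched and that the function is importable from the module this diff touches (the import path below is an assumption, not taken from the patch):

import torch
from colossalai.context.parallel_mode import ParallelMode
from colossalai.nn.layer.parallel_3d._operation import split_batch_3d  # path assumed

# Assumes colossalai.launch(...) has already set up a 3D tensor-parallel context.
x = torch.randn(32, 128, 512, device='cuda')     # full batch replicated on every rank
x_local = split_batch_3d(x, dim=0,
                         input_parallel_mode=ParallelMode.PARALLEL_3D_INPUT,
                         weight_parallel_mode=ParallelMode.PARALLEL_3D_WEIGHT)
# x_local now holds only this rank's shard of the batch dimension.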
@@ -333,8 +334,8 @@ class _ReduceTensor3D(torch.autograd.Function):
 
 def reduce_tensor_3d(tensor: Tensor, parallel_mode: ParallelMode) -> Tensor:
     """
-    All-reduce the input.
-
+    All-reduce the input
+
     :param tensor: Input tensor
     :param parallel_mode: Parallel mode
     """
@@ -359,7 +360,7 @@ class _AllGatherTensor3D(torch.autograd.Function):
 def all_gather_tensor_3d(tensor: Tensor, dim: int, parallel_mode: ParallelMode) -> Tensor:
     """
     All-reduce the gradient in backward pass.
-
+
     :param tensor: Input tensor
     :param parallel_mode: Parallel mode
     """
@@ -383,7 +384,7 @@ class _ReduceScatterTensor3D(torch.autograd.Function):
 def reduce_scatter_tensor_3d(tensor: Tensor, dim: int, parallel_mode: ParallelMode) -> Tensor:
     """
     Reduce-scatter the input.
-
+
     :param tensor: Input tensor
     :param dim: Dimension to scatter
     :param parallel_mode: Parallel mode
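Reduce-scatter is the inverse pattern: shards are summed across the group and each rank keeps only its chunk along dim. A rough sketch, assuming a sum reduction, an initialized process group, and that the size along dim divides evenly by the group size (not the library's code):

import torch
import torch.distributed as dist

def reduce_scatter_reference(tensor: torch.Tensor, dim: int, group=None) -> torch.Tensor:
    # Sum the full tensors across the group, then keep this rank's chunk along `dim`.
    world_size = dist.get_world_size(group)
    rank = dist.get_rank(group)
    inputs = [c.contiguous() for c in torch.chunk(tensor, world_size, dim=dim)]
    output = torch.empty_like(inputs[rank])
    dist.reduce_scatter(output, inputs, op=dist.ReduceOp.SUM, group=group)
    return output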
@@ -431,7 +432,8 @@ def reduce_by_batch_3d(tensor: Tensor,
     :type input_parallel_mode: colossalai.context.parallel_mode.ParallelMode
     :param weight_parallel_mode: weight parallel mode
     :type weight_parallel_mode: colossalai.context.parallel_mode.ParallelMode
-    :param reduce_mean: If set to ``True``, it will divide the output by (input parallel size * weight parallel size), default to False
+    :param reduce_mean: If set to ``True``, it will divide the output by (input parallel size * weight parallel size),
+        defaults to False
     :type reduce_mean: bool, optional
     """
     return _ReduceByBatch3D.apply(tensor, input_parallel_mode, weight_parallel_mode, reduce_mean)
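A hypothetical call site for the reduce_mean behaviour described above (import path assumed, as in the earlier sketch; requires a launched 3D-parallel context):

import torch
from colossalai.context.parallel_mode import ParallelMode
from colossalai.nn.layer.parallel_3d._operation import reduce_by_batch_3d  # path assumed

# Assumes a launched 3D-parallel context; `correct` is a per-rank scalar metric.
correct = torch.tensor(17.0, device='cuda')
total = reduce_by_batch_3d(correct,
                           ParallelMode.PARALLEL_3D_INPUT,
                           ParallelMode.PARALLEL_3D_WEIGHT,
                           reduce_mean=False)     # sum over both groups; True would average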