mirror of https://github.com/hpcaitech/ColossalAI
fix format for dir-[parallel_3d] (#333)
parent
eaac03ae1d
commit
cbb6436ff0
|
@ -244,7 +244,7 @@ class _Layernorm3D(torch.autograd.Function):
|
|||
def layernorm_3d(input_: Tensor, weight: Tensor, bias: Tensor, normalized_shape: int, eps: float,
|
||||
input_parallel_mode: ParallelMode, weight_parallel_mode: ParallelMode,
|
||||
output_parallel_mode: ParallelMode) -> Tensor:
|
||||
"""
|
||||
r"""
|
||||
3D parallel Layernorm
|
||||
|
||||
:param input_: input maxtrix
|
||||
|
@ -253,8 +253,9 @@ def layernorm_3d(input_: Tensor, weight: Tensor, bias: Tensor, normalized_shape:
|
|||
:type weight: torch.tensor
|
||||
:param bias: matrix of bias
|
||||
:type bias: torch.tensor
|
||||
:param normalized_shape: input shape from an expected input
|
||||
of size. :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] \times \ldots \times \text{normalized_shape}[-1]]`
|
||||
:param normalized_shape: input shape from an expected input of size.
|
||||
:math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1]
|
||||
\times \ldots \times \text{normalized_shape}[-1]]`
|
||||
If a single integer is used, it is treated as a singleton list, and this module will
|
||||
normalize over the last dimension which is expected to be of that specific size.
|
||||
:type normalized_shape: int
|
||||
|
@ -282,7 +283,7 @@ def split_tensor_3d(tensor: Tensor, dim: int, parallel_mode: ParallelMode) -> Te
|
|||
:type tensor: torch.Tensor
|
||||
:type dim: int
|
||||
:type parallel_mode: colossalai.context.parallel_mode.ParallelMode
|
||||
|
||||
|
||||
:return output: Splitted tensor
|
||||
:rtype output: torch.Tensor
|
||||
"""
|
||||
|
@ -294,9 +295,9 @@ def split_tensor_3d(tensor: Tensor, dim: int, parallel_mode: ParallelMode) -> Te
|
|||
|
||||
|
||||
def split_batch_3d(input_: Tensor,
|
||||
dim: int = 0,
|
||||
input_parallel_mode: ParallelMode = ParallelMode.PARALLEL_3D_INPUT,
|
||||
weight_parallel_mode: ParallelMode = ParallelMode.PARALLEL_3D_WEIGHT) -> Tensor:
|
||||
dim: int = 0,
|
||||
input_parallel_mode: ParallelMode = ParallelMode.PARALLEL_3D_INPUT,
|
||||
weight_parallel_mode: ParallelMode = ParallelMode.PARALLEL_3D_WEIGHT) -> Tensor:
|
||||
"""Splits 3D tensor in batch
|
||||
:param input_: Input tensor
|
||||
:param dim: Specified dimension in which to split
|
||||
|
@ -333,8 +334,8 @@ class _ReduceTensor3D(torch.autograd.Function):
|
|||
|
||||
def reduce_tensor_3d(tensor: Tensor, parallel_mode: ParallelMode) -> Tensor:
|
||||
"""
|
||||
All-reduce the input.
|
||||
|
||||
All-reduce the input
|
||||
|
||||
:param tensor: Input tensor
|
||||
:param parallel_mode: Parallel mode
|
||||
"""
|
||||
|
@ -359,7 +360,7 @@ class _AllGatherTensor3D(torch.autograd.Function):
|
|||
def all_gather_tensor_3d(tensor: Tensor, dim: int, parallel_mode: ParallelMode) -> Tensor:
|
||||
"""
|
||||
All-reduce the gradient in backward pass.
|
||||
|
||||
|
||||
:param tensor: Input tensor
|
||||
:param parallel_mode: Parallel mode
|
||||
"""
|
||||
|
@ -383,7 +384,7 @@ class _ReduceScatterTensor3D(torch.autograd.Function):
|
|||
def reduce_scatter_tensor_3d(tensor: Tensor, dim: int, parallel_mode: ParallelMode) -> Tensor:
|
||||
"""
|
||||
Reduce-scatter the input.
|
||||
|
||||
|
||||
:param tensor: Input tensor
|
||||
:param dim: Dimension to scatter
|
||||
:param parallel_mode: Parallel mode
|
||||
|
@ -431,7 +432,8 @@ def reduce_by_batch_3d(tensor: Tensor,
|
|||
:type input_parallel_mode: colossalai.context.parallel_mode.ParallelMode
|
||||
:param weight_parallel_mode: weight parallel mode
|
||||
:type weight_parallel_mode: colossalai.context.parallel_mode.ParallelMode
|
||||
:param reduce_mean: If set to ``True``, it will divide the output by (input parallel size * weight parallel size), default to False
|
||||
:param reduce_mean: If set to ``True``, it will divide the output by (input parallel size * weight parallel size),
|
||||
default to False
|
||||
:type reduce_mean: int, optional
|
||||
"""
|
||||
return _ReduceByBatch3D.apply(tensor, input_parallel_mode, weight_parallel_mode, reduce_mean)
|
||||
|
|
|
@ -17,7 +17,8 @@ from torch import Tensor
|
|||
from torch.nn import Parameter
|
||||
|
||||
from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
|
||||
from ._operation import *
|
||||
from ._operation import layernorm_3d, linear_3d, classifier_3d, split_tensor_3d
|
||||
from ._operation import all_gather_tensor_3d, reduce_scatter_tensor_3d, broadcast_weight_3d_from_diagonal
|
||||
from ._utils import get_depth_from_env, get_last_group, get_parallel_mode_from_env, swap_in_out_group
|
||||
|
||||
|
||||
|
@ -26,8 +27,9 @@ class LayerNorm3D(ParallelLayer):
|
|||
r"""
|
||||
Layer Normalization for 3D parallelism
|
||||
|
||||
:param normalized_shape: input shape from an expected input
|
||||
of size. :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] \times \ldots \times \text{normalized_shape}[-1]]`
|
||||
:param normalized_shape: input shape from an expected input of size.
|
||||
:math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1]
|
||||
\times \ldots \times \text{normalized_shape}[-1]]`
|
||||
If a single integer is used, it is treated as a singleton list, and this module will
|
||||
normalize over the last dimension which is expected to be of that specific size.
|
||||
:type normalized_shape: int
|
||||
|
@ -38,6 +40,7 @@ class LayerNorm3D(ParallelLayer):
|
|||
"""
|
||||
|
||||
def __init__(self, normalized_shape: int, eps: float = 1e-12, dtype=None):
|
||||
|
||||
super().__init__()
|
||||
self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
|
||||
self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
|
||||
|
@ -405,7 +408,7 @@ class PatchEmbedding3D(ParallelLayer):
|
|||
input_ = split_tensor_3d(input_, 0, self.input_parallel_mode)
|
||||
output = F.conv2d(input_, self.weight, self.bias, stride=self.patch_size)
|
||||
if self.flatten:
|
||||
output = output.flatten(2).transpose(1, 2) # BCHW -> BNC
|
||||
output = output.flatten(2).transpose(1, 2) # BCHW -> BNC
|
||||
|
||||
cls_token = self.cls_token.expand(output.shape[0], -1, -1)
|
||||
output = torch.cat((cls_token, output), dim=1)
|
||||
|
@ -549,7 +552,7 @@ class VocabParallelEmbedding3D(torch.nn.Module):
|
|||
|
||||
def _fill_padding_idx_with_zero(self) -> None:
|
||||
if self.padding_idx is not None and \
|
||||
self.padding_idx >= self.vocab_start_index and self.padding_idx < self.vocab_end_index:
|
||||
self.padding_idx >= self.vocab_start_index and self.padding_idx < self.vocab_end_index:
|
||||
with torch.no_grad():
|
||||
self.weight[self.padding_idx - self.vocab_start_index].fill_(0)
|
||||
|
||||
|
|
Loading…
Reference in New Issue