fix format for dir-[parallel_3d] (#333)

pull/394/head
DouJS 3 years ago committed by Frank Lee
parent eaac03ae1d
commit cbb6436ff0

@ -244,7 +244,7 @@ class _Layernorm3D(torch.autograd.Function):
def layernorm_3d(input_: Tensor, weight: Tensor, bias: Tensor, normalized_shape: int, eps: float,
input_parallel_mode: ParallelMode, weight_parallel_mode: ParallelMode,
output_parallel_mode: ParallelMode) -> Tensor:
"""
r"""
3D parallel Layernorm
:param input_: input maxtrix
@ -253,8 +253,9 @@ def layernorm_3d(input_: Tensor, weight: Tensor, bias: Tensor, normalized_shape:
:type weight: torch.tensor
:param bias: matrix of bias
:type bias: torch.tensor
:param normalized_shape: input shape from an expected input
of size. :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] \times \ldots \times \text{normalized_shape}[-1]]`
:param normalized_shape: input shape from an expected input of size.
:math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1]
\times \ldots \times \text{normalized_shape}[-1]]`
If a single integer is used, it is treated as a singleton list, and this module will
normalize over the last dimension which is expected to be of that specific size.
:type normalized_shape: int
@ -282,7 +283,7 @@ def split_tensor_3d(tensor: Tensor, dim: int, parallel_mode: ParallelMode) -> Te
:type tensor: torch.Tensor
:type dim: int
:type parallel_mode: colossalai.context.parallel_mode.ParallelMode
:return output: Splitted tensor
:rtype output: torch.Tensor
"""
@ -294,9 +295,9 @@ def split_tensor_3d(tensor: Tensor, dim: int, parallel_mode: ParallelMode) -> Te
def split_batch_3d(input_: Tensor,
dim: int = 0,
input_parallel_mode: ParallelMode = ParallelMode.PARALLEL_3D_INPUT,
weight_parallel_mode: ParallelMode = ParallelMode.PARALLEL_3D_WEIGHT) -> Tensor:
dim: int = 0,
input_parallel_mode: ParallelMode = ParallelMode.PARALLEL_3D_INPUT,
weight_parallel_mode: ParallelMode = ParallelMode.PARALLEL_3D_WEIGHT) -> Tensor:
"""Splits 3D tensor in batch
:param input_: Input tensor
:param dim: Specified dimension in which to split
@ -333,8 +334,8 @@ class _ReduceTensor3D(torch.autograd.Function):
def reduce_tensor_3d(tensor: Tensor, parallel_mode: ParallelMode) -> Tensor:
"""
All-reduce the input.
All-reduce the input
:param tensor: Input tensor
:param parallel_mode: Parallel mode
"""
@ -359,7 +360,7 @@ class _AllGatherTensor3D(torch.autograd.Function):
def all_gather_tensor_3d(tensor: Tensor, dim: int, parallel_mode: ParallelMode) -> Tensor:
"""
All-reduce the gradient in backward pass.
:param tensor: Input tensor
:param parallel_mode: Parallel mode
"""
@ -383,7 +384,7 @@ class _ReduceScatterTensor3D(torch.autograd.Function):
def reduce_scatter_tensor_3d(tensor: Tensor, dim: int, parallel_mode: ParallelMode) -> Tensor:
"""
Reduce-scatter the input.
:param tensor: Input tensor
:param dim: Dimension to scatter
:param parallel_mode: Parallel mode
@ -431,7 +432,8 @@ def reduce_by_batch_3d(tensor: Tensor,
:type input_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param weight_parallel_mode: weight parallel mode
:type weight_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param reduce_mean: If set to ``True``, it will divide the output by (input parallel size * weight parallel size), default to False
:param reduce_mean: If set to ``True``, it will divide the output by (input parallel size * weight parallel size),
default to False
:type reduce_mean: int, optional
"""
return _ReduceByBatch3D.apply(tensor, input_parallel_mode, weight_parallel_mode, reduce_mean)

@ -17,7 +17,8 @@ from torch import Tensor
from torch.nn import Parameter
from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
from ._operation import *
from ._operation import layernorm_3d, linear_3d, classifier_3d, split_tensor_3d
from ._operation import all_gather_tensor_3d, reduce_scatter_tensor_3d, broadcast_weight_3d_from_diagonal
from ._utils import get_depth_from_env, get_last_group, get_parallel_mode_from_env, swap_in_out_group
@ -26,8 +27,9 @@ class LayerNorm3D(ParallelLayer):
r"""
Layer Normalization for 3D parallelism
:param normalized_shape: input shape from an expected input
of size. :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] \times \ldots \times \text{normalized_shape}[-1]]`
:param normalized_shape: input shape from an expected input of size.
:math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1]
\times \ldots \times \text{normalized_shape}[-1]]`
If a single integer is used, it is treated as a singleton list, and this module will
normalize over the last dimension which is expected to be of that specific size.
:type normalized_shape: int
@ -38,6 +40,7 @@ class LayerNorm3D(ParallelLayer):
"""
def __init__(self, normalized_shape: int, eps: float = 1e-12, dtype=None):
super().__init__()
self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
@ -405,7 +408,7 @@ class PatchEmbedding3D(ParallelLayer):
input_ = split_tensor_3d(input_, 0, self.input_parallel_mode)
output = F.conv2d(input_, self.weight, self.bias, stride=self.patch_size)
if self.flatten:
output = output.flatten(2).transpose(1, 2) # BCHW -> BNC
output = output.flatten(2).transpose(1, 2) # BCHW -> BNC
cls_token = self.cls_token.expand(output.shape[0], -1, -1)
output = torch.cat((cls_token, output), dim=1)
@ -549,7 +552,7 @@ class VocabParallelEmbedding3D(torch.nn.Module):
def _fill_padding_idx_with_zero(self) -> None:
if self.padding_idx is not None and \
self.padding_idx >= self.vocab_start_index and self.padding_idx < self.vocab_end_index:
self.padding_idx >= self.vocab_start_index and self.padding_idx < self.vocab_end_index:
with torch.no_grad():
self.weight[self.padding_idx - self.vocab_start_index].fill_(0)

Loading…
Cancel
Save