diff --git a/colossalai/nn/layer/parallel_3d/_operation.py b/colossalai/nn/layer/parallel_3d/_operation.py
index 6ad442788..b1e12a908 100644
--- a/colossalai/nn/layer/parallel_3d/_operation.py
+++ b/colossalai/nn/layer/parallel_3d/_operation.py
@@ -244,7 +244,7 @@ class _Layernorm3D(torch.autograd.Function):
 def layernorm_3d(input_: Tensor, weight: Tensor, bias: Tensor, normalized_shape: int, eps: float,
                  input_parallel_mode: ParallelMode, weight_parallel_mode: ParallelMode,
                  output_parallel_mode: ParallelMode) -> Tensor:
-    """
+    r"""
     3D parallel Layernorm
 
     :param input_: input maxtrix
@@ -253,8 +253,9 @@ def layernorm_3d(input_: Tensor, weight: Tensor, bias: Tensor, normalized_shape:
     :type weight: torch.tensor
     :param bias: matrix of bias
     :type bias: torch.tensor
-    :param normalized_shape: input shape from an expected input
-        of size. :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] \times \ldots \times \text{normalized_shape}[-1]]`
+    :param normalized_shape: input shape from an expected input of size.
+        :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1]
+        \times \ldots \times \text{normalized_shape}[-1]]`
     If a single integer is used, it is treated as a singleton list, and this module will
     normalize over the last dimension which is expected to be of that specific size.
     :type normalized_shape: int
@@ -282,7 +283,7 @@ def split_tensor_3d(tensor: Tensor, dim: int, parallel_mode: ParallelMode) -> Te
     :type tensor: torch.Tensor
     :type dim: int
     :type parallel_mode: colossalai.context.parallel_mode.ParallelMode
-    
+
     :return output: Splitted tensor
     :rtype output: torch.Tensor
     """
@@ -294,9 +295,9 @@ def split_tensor_3d(tensor: Tensor, dim: int, parallel_mode: ParallelMode) -> Te
 
 
 def split_batch_3d(input_: Tensor,
-                  dim: int = 0,
-                  input_parallel_mode: ParallelMode = ParallelMode.PARALLEL_3D_INPUT,
-                  weight_parallel_mode: ParallelMode = ParallelMode.PARALLEL_3D_WEIGHT) -> Tensor:
+                   dim: int = 0,
+                   input_parallel_mode: ParallelMode = ParallelMode.PARALLEL_3D_INPUT,
+                   weight_parallel_mode: ParallelMode = ParallelMode.PARALLEL_3D_WEIGHT) -> Tensor:
     """Splits 3D tensor in batch
     :param input_: Input tensor
     :param dim: Specified dimension in which to split
@@ -333,8 +334,8 @@ class _ReduceTensor3D(torch.autograd.Function):
 
 def reduce_tensor_3d(tensor: Tensor, parallel_mode: ParallelMode) -> Tensor:
     """
-    All-reduce the input.
-    
+    All-reduce the input
+
     :param tensor: Input tensor
     :param parallel_mode: Parallel mode
     """
@@ -359,7 +360,7 @@ class _AllGatherTensor3D(torch.autograd.Function):
 def all_gather_tensor_3d(tensor: Tensor, dim: int, parallel_mode: ParallelMode) -> Tensor:
     """
     All-reduce the gradient in backward pass.
-    
+
     :param tensor: Input tensor
     :param parallel_mode: Parallel mode
     """
@@ -383,7 +384,7 @@ class _ReduceScatterTensor3D(torch.autograd.Function):
 def reduce_scatter_tensor_3d(tensor: Tensor, dim: int, parallel_mode: ParallelMode) -> Tensor:
     """
     Reduce-scatter the input.
-    
+
     :param tensor: Input tensor
     :param dim: Dimension to scatter
     :param parallel_mode: Parallel mode
@@ -431,7 +432,8 @@ def reduce_by_batch_3d(tensor: Tensor,
     :type input_parallel_mode: colossalai.context.parallel_mode.ParallelMode
     :param weight_parallel_mode: weight parallel mode
     :type weight_parallel_mode: colossalai.context.parallel_mode.ParallelMode
-    :param reduce_mean: If set to ``True``, it will divide the output by (input parallel size * weight parallel size), default to False
+    :param reduce_mean: If set to ``True``, it will divide the output by (input parallel size * weight parallel size),
+        default to False
     :type reduce_mean: int, optional
     """
     return _ReduceByBatch3D.apply(tensor, input_parallel_mode, weight_parallel_mode, reduce_mean)
diff --git a/colossalai/nn/layer/parallel_3d/layers.py b/colossalai/nn/layer/parallel_3d/layers.py
index 5164bc69a..996692c21 100644
--- a/colossalai/nn/layer/parallel_3d/layers.py
+++ b/colossalai/nn/layer/parallel_3d/layers.py
@@ -17,7 +17,8 @@ from torch import Tensor
 from torch.nn import Parameter
 
 from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
-from ._operation import *
+from ._operation import layernorm_3d, linear_3d, classifier_3d, split_tensor_3d
+from ._operation import all_gather_tensor_3d, reduce_scatter_tensor_3d, broadcast_weight_3d_from_diagonal
 from ._utils import get_depth_from_env, get_last_group, get_parallel_mode_from_env, swap_in_out_group
 
 
@@ -26,8 +27,9 @@ class LayerNorm3D(ParallelLayer):
     r"""
     Layer Normalization for 3D parallelism
 
-    :param normalized_shape: input shape from an expected input
-        of size. :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] \times \ldots \times \text{normalized_shape}[-1]]`
+    :param normalized_shape: input shape from an expected input of size.
+        :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1]
+        \times \ldots \times \text{normalized_shape}[-1]]`
     If a single integer is used, it is treated as a singleton list, and this module will
     normalize over the last dimension which is expected to be of that specific size.
     :type normalized_shape: int
@@ -38,6 +40,7 @@ class LayerNorm3D(ParallelLayer):
     """
 
     def __init__(self, normalized_shape: int, eps: float = 1e-12, dtype=None):
+        super().__init__()
         self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
         self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
         self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode)
@@ -405,7 +408,7 @@ class PatchEmbedding3D(ParallelLayer):
         input_ = split_tensor_3d(input_, 0, self.input_parallel_mode)
         output = F.conv2d(input_, self.weight, self.bias, stride=self.patch_size)
         if self.flatten:
-            output = output.flatten(2).transpose(1, 2) # BCHW -> BNC
+            output = output.flatten(2).transpose(1, 2)  # BCHW -> BNC
 
         cls_token = self.cls_token.expand(output.shape[0], -1, -1)
         output = torch.cat((cls_token, output), dim=1)
@@ -549,7 +552,7 @@ class VocabParallelEmbedding3D(torch.nn.Module):
 
     def _fill_padding_idx_with_zero(self) -> None:
         if self.padding_idx is not None and \
-            self.padding_idx >= self.vocab_start_index and self.padding_idx < self.vocab_end_index:
+                self.padding_idx >= self.vocab_start_index and self.padding_idx < self.vocab_end_index:
             with torch.no_grad():
                 self.weight[self.padding_idx - self.vocab_start_index].fill_(0)
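
Note on the super().__init__() call added to LayerNorm3D.__init__ above: torch.nn.Module.__setattr__ records Parameter attributes in self._parameters, a registry that is created only when Module.__init__ runs, so a subclass that assigns parameters before calling super().__init__() fails at construction time with an AttributeError. The sketch below reproduces the failure mode and the fix using plain torch.nn.Module; BrokenNorm and FixedNorm are illustrative names, not part of this codebase, and the sketch assumes ParallelLayer ultimately derives from torch.nn.Module.

import torch
import torch.nn as nn


class BrokenNorm(nn.Module):
    """Assigns a Parameter before nn.Module.__init__ has run (pre-fix shape of LayerNorm3D)."""

    def __init__(self, normalized_shape: int):
        # Missing super().__init__(): nn.Module's internal registries
        # (_parameters, _buffers, _modules) are never created.
        self.weight = nn.Parameter(torch.ones(normalized_shape))


class FixedNorm(nn.Module):
    """Same layer with the fix applied in this diff."""

    def __init__(self, normalized_shape: int):
        super().__init__()  # must run before any Parameter assignment
        self.weight = nn.Parameter(torch.ones(normalized_shape))


if __name__ == "__main__":
    try:
        BrokenNorm(8)
    except AttributeError as err:
        # PyTorch raises: "cannot assign parameters before Module.__init__() call"
        print("BrokenNorm failed:", err)

    layer = FixedNorm(8)
    print("FixedNorm parameters:", [p.shape for p in layer.parameters()])

The same reasoning applies to any layer in this file that registers weight or bias parameters in __init__.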