From 01a80cd86d442a5a9c147a25dd63832a6370443c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E3=82=A2=E3=83=9E=E3=83=87=E3=82=A6=E3=82=B9?= Date: Wed, 29 Dec 2021 23:32:10 +0800 Subject: [PATCH] Hotfix/Colossalai layers (#92) * optimized 1d layer apis; reorganized nn.layer modules; fixed tests * fixed 2.5d runtime issue * reworked split batch, now called in trainer.schedule.load_batch Co-authored-by: BoxiangW <45734921+BoxiangW@users.noreply.github.com> --- benchmark/cifar/configs/vit_1d.py | 2 +- benchmark/cifar/train.py | 8 +- benchmark/imagenet100/configs/vit_1d.py | 2 +- benchmark/imagenet100/train.py | 8 +- benchmark/imagenet1k/configs/vit_1d.py | 2 +- benchmark/imagenet1k/train.py | 8 +- colossalai/constants.py | 4 + colossalai/context/parallel_context.py | 6 +- .../initializer_1d.py | 4 +- colossalai/engine/schedule/_base_schedule.py | 8 +- colossalai/nn/layer/__init__.py | 8 +- colossalai/nn/layer/colossalai_layer.py | 231 ------------------ .../nn/layer/colossalai_layer/__init__.py | 7 + .../nn/layer/colossalai_layer/_utils.py | 19 ++ .../nn/layer/colossalai_layer/dropout.py | 23 ++ .../nn/layer/colossalai_layer/embedding.py | 107 ++++++++ .../nn/layer/colossalai_layer/linear.py | 97 ++++++++ .../layer/colossalai_layer/normalization.py | 35 +++ colossalai/nn/layer/fused_bias_gelu.py | 35 --- colossalai/nn/layer/parallel_1d/__init__.py | 4 +- colossalai/nn/layer/parallel_1d/_utils.py | 13 +- colossalai/nn/layer/parallel_1d/layers.py | 184 +++++++++++++- colossalai/nn/layer/parallel_2d/__init__.py | 4 +- colossalai/nn/layer/parallel_2d/_operation.py | 33 ++- colossalai/nn/layer/parallel_2d/layers.py | 8 +- colossalai/nn/layer/parallel_2p5d/__init__.py | 4 +- .../nn/layer/parallel_2p5d/_operation.py | 214 +++++++++++----- colossalai/nn/layer/parallel_2p5d/layers.py | 27 +- colossalai/nn/layer/parallel_3d/__init__.py | 4 +- colossalai/nn/layer/parallel_3d/_operation.py | 26 +- colossalai/nn/layer/parallel_3d/layers.py | 8 +- colossalai/nn/layer/utils/__init__.py | 7 + .../{_common_utils.py => utils/common.py} | 7 +- colossalai/nn/layer/vanilla/layers.py | 2 +- colossalai/nn/loss/__init__.py | 6 +- colossalai/nn/loss/loss_2d.py | 11 +- colossalai/nn/loss/loss_2p5d.py | 11 +- colossalai/nn/loss/loss_3d.py | 12 +- colossalai/nn/metric/__init__.py | 6 +- colossalai/nn/metric/accuracy_2d.py | 3 +- colossalai/nn/metric/accuracy_2p5d.py | 3 +- colossalai/nn/metric/accuracy_3d.py | 3 +- colossalai/trainer/hooks/_metric_hook.py | 21 +- colossalai/utils/__init__.py | 38 ++- colossalai/utils/common.py | 23 +- model_zoo/vit/vit.py | 152 ++++-------- tests/test_comm/test_comm.py | 9 +- tests/test_context/test_2d_init.py | 9 +- tests/test_context/test_2p5d_init.py | 4 +- tests/test_context/test_3d_init.py | 5 +- .../test_cifar_with_data_pipeline_tensor.py | 13 +- .../test_engine/test_engine_apex_amp.py | 28 +-- .../test_engine/test_engine_naive_amp.py | 29 +-- .../test_engine/test_engine_no_amp.py | 28 +-- .../test_engine/test_engine_torch_amp.py | 29 +-- tests/test_layers/test_1d/test_1d.py | 12 +- tests/test_layers/test_2d/test_2d.py | 13 +- tests/test_layers/test_2p5d/test_2p5d.py | 16 +- tests/test_layers/test_3d/test_3d.py | 7 +- .../test_sequence/test_sequence.py | 9 +- .../test_pipeline/resnet_config.py | 1 + tests/test_trainer/test_pipeline/test_p2p.py | 12 +- .../test_pipeline/test_partition.py | 9 +- .../test_pipeline/test_pipeline_schedule.py | 25 +- .../test_trainer_with_non_pipe_schedule.py | 8 +- .../test_trainer_with_pipe_schedule.py | 8 +- .../test_utils/test_gradient_accumluation.py | 24 +- .../test_zero_level_2.py | 18 +- .../test_zero_level_3.py | 18 +- .../test_vit_2d_level_2.py | 12 +- .../test_vit_2d_level_3.py | 12 +- 71 files changed, 1033 insertions(+), 773 deletions(-) delete mode 100644 colossalai/nn/layer/colossalai_layer.py create mode 100644 colossalai/nn/layer/colossalai_layer/__init__.py create mode 100644 colossalai/nn/layer/colossalai_layer/_utils.py create mode 100644 colossalai/nn/layer/colossalai_layer/dropout.py create mode 100644 colossalai/nn/layer/colossalai_layer/embedding.py create mode 100644 colossalai/nn/layer/colossalai_layer/linear.py create mode 100644 colossalai/nn/layer/colossalai_layer/normalization.py delete mode 100644 colossalai/nn/layer/fused_bias_gelu.py create mode 100644 colossalai/nn/layer/utils/__init__.py rename colossalai/nn/layer/{_common_utils.py => utils/common.py} (91%) diff --git a/benchmark/cifar/configs/vit_1d.py b/benchmark/cifar/configs/vit_1d.py index 34eb7d50a..1731abc1e 100644 --- a/benchmark/cifar/configs/vit_1d.py +++ b/benchmark/cifar/configs/vit_1d.py @@ -2,7 +2,7 @@ BATCH_SIZE = 512 LEARNING_RATE = 2e-3 WEIGHT_DECAY = 3e-2 -TENSOR_PARALLEL_SIZE = 4 +TENSOR_PARALLEL_SIZE = 2 TENSOR_PARALLEL_MODE = '1d' NUM_EPOCHS = 200 diff --git a/benchmark/cifar/train.py b/benchmark/cifar/train.py index 4a1d87758..4ffa22968 100644 --- a/benchmark/cifar/train.py +++ b/benchmark/cifar/train.py @@ -72,13 +72,11 @@ def train_cifar(): os.mkdir(log_path) logger.log_to_file(log_path) - tp = gpc.config.parallel.tensor.mode - - model = vit_lite_depth7_patch4_32(tensor_parallel=tp) + model = vit_lite_depth7_patch4_32() train_dataloader, test_dataloader = build_cifar(gpc.config.BATCH_SIZE // gpc.data_parallel_size) - criterion = CrossEntropyLoss(label_smoothing=0.1, tensor_parallel=tp) + criterion = CrossEntropyLoss(label_smoothing=0.1) optimizer = torch.optim.AdamW(model.parameters(), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY) @@ -107,7 +105,7 @@ def train_cifar(): LogMetricByStepHook(), # LogTimingByEpochHook(timer=timer, logger=logger), # LogMemoryByEpochHook(logger=logger), - AccuracyHook(accuracy_func=Accuracy(tensor_parallel=tp)), + AccuracyHook(accuracy_func=Accuracy()), LossHook(), ThroughputHook(), LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=False) diff --git a/benchmark/imagenet100/configs/vit_1d.py b/benchmark/imagenet100/configs/vit_1d.py index 07bb5fb66..bd90e1e84 100644 --- a/benchmark/imagenet100/configs/vit_1d.py +++ b/benchmark/imagenet100/configs/vit_1d.py @@ -4,7 +4,7 @@ TOTAL_BATCH_SIZE = 4096 LEARNING_RATE = 3e-3 WEIGHT_DECAY = 0.3 -TENSOR_PARALLEL_SIZE = 4 +TENSOR_PARALLEL_SIZE = 2 TENSOR_PARALLEL_MODE = '1d' NUM_EPOCHS = 300 diff --git a/benchmark/imagenet100/train.py b/benchmark/imagenet100/train.py index fece6d1a6..58ad3b15e 100644 --- a/benchmark/imagenet100/train.py +++ b/benchmark/imagenet100/train.py @@ -159,14 +159,12 @@ def train_imagenet(): os.mkdir(log_path) logger.log_to_file(log_path) - tp = gpc.config.parallel.tensor.mode - - model = vit_small_patch16_224(tensor_parallel=tp, num_classes=100, init_method='jax') + model = vit_small_patch16_224(num_classes=100, init_method='jax') train_dataloader = build_dali_train(gpc.config.BATCH_SIZE // gpc.data_parallel_size) test_dataloader = build_dali_test(gpc.config.BATCH_SIZE // gpc.data_parallel_size) - criterion = CrossEntropyLoss(label_smoothing=0.1, tensor_parallel=tp) + criterion = CrossEntropyLoss(label_smoothing=0.1) optimizer = torch.optim.AdamW(model.parameters(), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY) @@ -192,7 +190,7 @@ def train_imagenet(): LogMetricByStepHook(), # LogTimingByEpochHook(timer=timer, logger=logger), # LogMemoryByEpochHook(logger=logger), - AccuracyHook(accuracy_func=Accuracy(tensor_parallel=tp)), + AccuracyHook(accuracy_func=Accuracy()), LossHook(), ThroughputHook(), LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=True) diff --git a/benchmark/imagenet1k/configs/vit_1d.py b/benchmark/imagenet1k/configs/vit_1d.py index adddceb3a..d447d10b1 100644 --- a/benchmark/imagenet1k/configs/vit_1d.py +++ b/benchmark/imagenet1k/configs/vit_1d.py @@ -4,7 +4,7 @@ TOTAL_BATCH_SIZE = 4096 LEARNING_RATE = 3e-3 WEIGHT_DECAY = 0.3 -TENSOR_PARALLEL_SIZE = 4 +TENSOR_PARALLEL_SIZE = 2 TENSOR_PARALLEL_MODE = '1d' NUM_EPOCHS = 300 diff --git a/benchmark/imagenet1k/train.py b/benchmark/imagenet1k/train.py index 989dff2aa..d9b9ade99 100644 --- a/benchmark/imagenet1k/train.py +++ b/benchmark/imagenet1k/train.py @@ -159,14 +159,12 @@ def train_imagenet(): os.mkdir(log_path) logger.log_to_file(log_path) - tp = gpc.config.parallel.tensor.mode - - model = vit_small_patch16_224(tensor_parallel=tp, num_classes=1000, init_method='jax') + model = vit_small_patch16_224(num_classes=1000, init_method='jax') train_dataloader = build_dali_train(gpc.config.BATCH_SIZE // gpc.data_parallel_size) test_dataloader = build_dali_test(gpc.config.BATCH_SIZE // gpc.data_parallel_size) - criterion = CrossEntropyLoss(label_smoothing=0.1, tensor_parallel=tp) + criterion = CrossEntropyLoss(label_smoothing=0.1) optimizer = torch.optim.AdamW(model.parameters(), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY) @@ -192,7 +190,7 @@ def train_imagenet(): LogMetricByStepHook(), # LogTimingByEpochHook(timer=timer, logger=logger), # LogMemoryByEpochHook(logger=logger), - AccuracyHook(accuracy_func=Accuracy(tensor_parallel=tp)), + AccuracyHook(accuracy_func=Accuracy()), LossHook(), ThroughputHook(), LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=True) diff --git a/colossalai/constants.py b/colossalai/constants.py index 874c53d72..58a94437a 100644 --- a/colossalai/constants.py +++ b/colossalai/constants.py @@ -2,6 +2,7 @@ # -*- encoding: utf-8 -*- ALLOWED_MODES = [None, '1d', '2d', '2.5d', '3d', 'sequence'] +TENSOR_PARALLEL_MODE = 'tensor_parallel_mode' # intializer INITIALIZER_MAPPING = { @@ -16,6 +17,9 @@ INITIALIZER_MAPPING = { 'sequence': 'Initializer_Sequence' } +# 1D parallel +PARALLEL_INPUT_1D = 'parallel_input_1d' + # 2D paralllel SUMMA_DIM = 'SUMMA_DIM' diff --git a/colossalai/context/parallel_context.py b/colossalai/context/parallel_context.py index f3ebb1eaa..f76f4d60e 100644 --- a/colossalai/context/parallel_context.py +++ b/colossalai/context/parallel_context.py @@ -1,17 +1,18 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- +import os import random from typing import Union import numpy as np import torch import torch.distributed as dist - -from colossalai.constants import ALLOWED_MODES, INITIALIZER_MAPPING +from colossalai.constants import ALLOWED_MODES, INITIALIZER_MAPPING, TENSOR_PARALLEL_MODE from colossalai.context.config import Config from colossalai.logging import get_dist_logger from colossalai.registry import DIST_GROUP_INITIALIZER + from .parallel_mode import ParallelMode from .random import add_seed, get_seeds, set_mode @@ -386,6 +387,7 @@ class ParallelContext: if parallel_config is not None and 'tensor' in parallel_config and 'mode' in parallel_config['tensor']: tensor_parallel_mode = parallel_config['tensor']['mode'] assert tensor_parallel_mode in ALLOWED_MODES, f"mode in the parallel config must be set to one of {ALLOWED_MODES}" + os.environ[TENSOR_PARALLEL_MODE] = str(tensor_parallel_mode) self.check_sanity() pg_init = [] diff --git a/colossalai/context/process_group_initializer/initializer_1d.py b/colossalai/context/process_group_initializer/initializer_1d.py index 1b487aba1..edd60c085 100644 --- a/colossalai/context/process_group_initializer/initializer_1d.py +++ b/colossalai/context/process_group_initializer/initializer_1d.py @@ -1,12 +1,13 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- - +import os import torch.distributed as dist from colossalai.context import Config from colossalai.registry import DIST_GROUP_INITIALIZER from .process_group_initializer import ProcessGroupInitializer from ..parallel_mode import ParallelMode +from colossalai.constants import PARALLEL_INPUT_1D @DIST_GROUP_INITIALIZER.register_module @@ -29,6 +30,7 @@ class Initializer_1D(ProcessGroupInitializer): process_group = None group_world_size = None mode = ParallelMode.PARALLEL_1D + os.environ[PARALLEL_INPUT_1D] = '' for i in range(self.num_group): ranks = [i * self.tensor_parallel_size + j for j in range(self.tensor_parallel_size)] diff --git a/colossalai/engine/schedule/_base_schedule.py b/colossalai/engine/schedule/_base_schedule.py index aceee4e6c..411f1861b 100644 --- a/colossalai/engine/schedule/_base_schedule.py +++ b/colossalai/engine/schedule/_base_schedule.py @@ -10,7 +10,7 @@ from typing import Iterable, Union, List, Callable from .._base_engine import Engine from colossalai.logging import get_dist_logger from colossalai.utils import get_current_device - +from colossalai.nn.layer import split_batch class BaseSchedule(ABC): """A basic helper class to control the process of training or evaluation. @@ -59,7 +59,11 @@ class BaseSchedule(ABC): else: data, label = batch_data - data, label = self._to_list(data), self._to_list(label) + if isinstance(label, (tuple, list)): + self.batch_size = label[0].size(0) + else: + self.batch_size = label.size(0) + data, label = self._to_list(split_batch(data)), self._to_list(split_batch(label)) return self._move_to_device(data), self._move_to_device(label) def pre_processing(self, engine: Engine): diff --git a/colossalai/nn/layer/__init__.py b/colossalai/nn/layer/__init__.py index a04dece91..86961dd93 100644 --- a/colossalai/nn/layer/__init__.py +++ b/colossalai/nn/layer/__init__.py @@ -1,3 +1,9 @@ from .colossalai_layer import * -from .fused_bias_gelu import bias_gelu_impl +from .parallel_1d import * +from .parallel_2d import * +from .parallel_2p5d import * +from .parallel_3d import * +from .parallel_sequence import * +from .utils import * +from .vanilla import * from .wrapper import * diff --git a/colossalai/nn/layer/colossalai_layer.py b/colossalai/nn/layer/colossalai_layer.py deleted file mode 100644 index 3a185ae15..000000000 --- a/colossalai/nn/layer/colossalai_layer.py +++ /dev/null @@ -1,231 +0,0 @@ -import math -from typing import Callable, Optional - -from colossalai.utils import get_current_device -from torch import dtype, nn -from torch.nn.modules.activation import * -from torch.nn.modules.adaptive import * -from torch.nn.modules.batchnorm import * -from torch.nn.modules.channelshuffle import * -from torch.nn.modules.conv import * -from torch.nn.modules.distance import * -from torch.nn.modules.dropout import * -from torch.nn.modules.flatten import * -from torch.nn.modules.fold import * -from torch.nn.modules.instancenorm import * -from torch.nn.modules.linear import * -from torch.nn.modules.normalization import * -from torch.nn.modules.padding import * -from torch.nn.modules.pixelshuffle import * -from torch.nn.modules.pooling import * -from torch.nn.modules.rnn import * -from torch.nn.modules.sparse import * -from torch.nn.modules.transformer import * -from torch.nn.modules.upsampling import * - -from .. import init as init - -from .vanilla import * -from .parallel_1d import * -from .parallel_2d import * -from .parallel_2p5d import * -from .parallel_3d import * -from .parallel_sequence import * - -_parallel_linear = {'1d_col': Linear1D_Col, '1d_row': Linear1D_Row, '2d': Linear2D, '2.5d': Linear2p5D, '3d': Linear3D} - -_parallel_classifier = { - None: VanillaClassifier, - '1d': VanillaClassifier, - '2d': Classifier2D, - '2.5d': Classifier2p5D, - '3d': Classifier3D -} - -_parallel_layernorm = {'2d': LayerNorm2D, '2.5d': LayerNorm2p5D, '3d': LayerNorm3D} - -_parallel_embedding = {'3d': Embedding3D} - -_parallel_patchembedding = { - None: VanillaPatchEmbedding, - '1d': VanillaPatchEmbedding, - '2d': PatchEmbedding2D, - '2.5d': PatchEmbedding2p5D, - '3d': PatchEmbedding3D -} - - -class Linear(nn.Module): - def __init__(self, - in_features: int, - out_features: int, - bias: bool = True, - dtype: dtype = None, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), - tensor_parallel: Optional[str] = None, - **kwargs) -> None: - super().__init__() - if tensor_parallel is None: - self.layer = nn.Linear(in_features, out_features, bias=bias, device=get_current_device(), dtype=dtype) - weight_initializer(self.layer.weight, fan_in=in_features, fan_out=out_features) - if bias: - bias_initializer(self.layer.bias, fan_in=in_features) - else: - self.layer = _parallel_linear[tensor_parallel]( - in_features, - out_features, - bias=bias, - dtype=dtype, - weight_initializer=weight_initializer, - bias_initializer=bias_initializer, - **kwargs, - ) - - @property - def weight(self): - return self.layer.weight - - @property - def bias(self): - return self.layer.bias - - def forward(self, *args): - return self.layer(*args) - - -class LayerNorm(nn.Module): - def __init__(self, normalized_shape: int, eps=1e-05, dtype=None, tensor_parallel: Optional[str] = None) -> None: - super().__init__() - if tensor_parallel in [None, '1d']: - self.norm = nn.LayerNorm(normalized_shape, eps=eps, device=get_current_device(), dtype=dtype) - else: - self.norm = _parallel_layernorm[tensor_parallel](normalized_shape, eps=eps, dtype=dtype) - - @property - def weight(self): - return self.norm.weight - - @property - def bias(self): - return self.norm.bias - - def forward(self, *args): - return self.norm(*args) - - -class Embedding(nn.Module): - def __init__(self, - num_embeddings: int, - embedding_dim: int, - padding_idx: int = None, - dtype: dtype = None, - weight_initializer: Callable = init.normal_(), - tensor_parallel: Optional[str] = None, - *args, - **kwargs) -> None: - super().__init__() - if tensor_parallel in [None, '1d']: - self.embed = nn.Embedding(num_embeddings, - embedding_dim, - padding_idx=padding_idx, - device=get_current_device(), - dtype=dtype, - *args, - **kwargs) - weight_initializer(self.embed.weight, fan_in=num_embeddings, fan_out=embedding_dim) - else: - self.embed = _parallel_embedding[tensor_parallel]( - num_embeddings, - embedding_dim, - padding_idx=padding_idx, - dtype=dtype, - weight_initializer=weight_initializer, - *args, - **kwargs, - ) - - @property - def weight(self): - return self.embed.weight - - def forward(self, *args): - return self.embed(*args) - - -class PatchEmbedding(nn.Module): - def __init__(self, - img_size: int, - patch_size: int, - in_chans: int, - embed_size: int, - dtype: dtype = None, - flatten: bool = True, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), - position_embed_initializer: Callable = init.zeros_(), - tensor_parallel: Optional[str] = None) -> None: - super().__init__() - self.embed = _parallel_patchembedding[tensor_parallel]( - img_size, - patch_size, - in_chans, - embed_size, - dtype=dtype, - flatten=flatten, - weight_initializer=weight_initializer, - bias_initializer=bias_initializer, - position_embed_initializer=position_embed_initializer, - ) - - @property - def weight(self): - return self.embed.weight - - @property - def bias(self): - return self.embed.bias - - @property - def pos_embed(self): - return self.embed.pos_embed - - @property - def cls_token(self): - return self.embed.cls_token - - def forward(self, *args): - return self.embed(*args) - - -class Classifier(nn.Module): - def __init__(self, - in_features: int, - num_classes: int, - weight: nn.Parameter = None, - bias: bool = True, - dtype: dtype = None, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), - tensor_parallel: Optional[str] = None) -> None: - super().__init__() - self.layer = _parallel_classifier[tensor_parallel]( - in_features, - num_classes, - weight=weight, - bias=bias, - dtype=dtype, - weight_initializer=weight_initializer, - bias_initializer=bias_initializer, - ) - - @property - def weight(self): - return self.layer.weight - - @property - def bias(self): - return self.layer.bias - - def forward(self, *args): - return self.layer(*args) diff --git a/colossalai/nn/layer/colossalai_layer/__init__.py b/colossalai/nn/layer/colossalai_layer/__init__.py new file mode 100644 index 000000000..54ed567eb --- /dev/null +++ b/colossalai/nn/layer/colossalai_layer/__init__.py @@ -0,0 +1,7 @@ +from ._utils import split_batch +from .dropout import Dropout +from .embedding import Embedding, PatchEmbedding +from .linear import Classifier, Linear +from .normalization import LayerNorm + +__all__ = ['Linear', 'Classifier', 'Embedding', 'PatchEmbedding', 'LayerNorm', 'Dropout', 'split_batch'] diff --git a/colossalai/nn/layer/colossalai_layer/_utils.py b/colossalai/nn/layer/colossalai_layer/_utils.py new file mode 100644 index 000000000..8b996c860 --- /dev/null +++ b/colossalai/nn/layer/colossalai_layer/_utils.py @@ -0,0 +1,19 @@ +from torch import Tensor + +from ..parallel_2d._operation import split_tensor_2d +from ..parallel_2p5d._operation import split_tensor_2p5d +from ..parallel_3d._operation import split_tensor_3d +from ..utils import get_tensor_parallel_mode + +_parallel_split_batch = {'2d': split_tensor_2d, '2.5d': split_tensor_2p5d, '3d': split_tensor_3d} + + +def split_batch(input_) -> Tensor: + tensor_parallel_mode = get_tensor_parallel_mode() + if tensor_parallel_mode in _parallel_split_batch: + if isinstance(input_, (tuple, list)): + return tuple(map(_parallel_split_batch[tensor_parallel_mode], input_)) + else: + return _parallel_split_batch[tensor_parallel_mode](input_) + else: + return input_ diff --git a/colossalai/nn/layer/colossalai_layer/dropout.py b/colossalai/nn/layer/colossalai_layer/dropout.py new file mode 100644 index 000000000..ff86e0745 --- /dev/null +++ b/colossalai/nn/layer/colossalai_layer/dropout.py @@ -0,0 +1,23 @@ +from contextlib import nullcontext + +import torch.nn as nn +from colossalai.context import ParallelMode, seed +from colossalai.utils import conditional_context + +from ..parallel_1d import * +from ..utils import get_tensor_parallel_mode + + +class Dropout(nn.Module): + def __init__(self, p: float = 0.5, inplace: bool = False) -> None: + super().__init__() + self.tensor_parallel = get_tensor_parallel_mode() + if self.tensor_parallel == '1d': + self.drop = Dropout1D(p, inplace) + else: + self.drop = nn.Dropout(p, inplace) + + def forward(self, *args): + cm = nullcontext() if self.tensor_parallel in ['None', '1d'] else seed(ParallelMode.TENSOR) + with cm: + return self.drop(*args) diff --git a/colossalai/nn/layer/colossalai_layer/embedding.py b/colossalai/nn/layer/colossalai_layer/embedding.py new file mode 100644 index 000000000..6a580a29d --- /dev/null +++ b/colossalai/nn/layer/colossalai_layer/embedding.py @@ -0,0 +1,107 @@ +import math +from typing import Callable, Optional + +from colossalai.utils import get_current_device +from torch import dtype, nn + +from ... import init as init +from ..parallel_1d import * +from ..parallel_2d import * +from ..parallel_2p5d import * +from ..parallel_3d import * +from ..utils import get_tensor_parallel_mode +from ..vanilla import * + +_parallel_embedding = {'1d': Embedding1D, '2d': Embedding2D, '2.5d': Embedding2p5D, '3d': Embedding3D} + +_parallel_patchembedding = { + 'None': VanillaPatchEmbedding, + '1d': VanillaPatchEmbedding, + '2d': PatchEmbedding2D, + '2.5d': PatchEmbedding2p5D, + '3d': PatchEmbedding3D +} + + +class Embedding(nn.Module): + def __init__(self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int = None, + dtype: dtype = None, + weight_initializer: Callable = init.normal_(), + *args, + **kwargs) -> None: + super().__init__() + tensor_parallel = get_tensor_parallel_mode() + if tensor_parallel == 'None': + self.embed = nn.Embedding(num_embeddings, + embedding_dim, + padding_idx=padding_idx, + device=get_current_device(), + dtype=dtype, + *args, + **kwargs) + weight_initializer(self.embed.weight, fan_in=num_embeddings, fan_out=embedding_dim) + else: + self.embed = _parallel_embedding[tensor_parallel]( + num_embeddings, + embedding_dim, + padding_idx=padding_idx, + dtype=dtype, + weight_initializer=weight_initializer, + *args, + **kwargs, + ) + + @property + def weight(self): + return self.embed.weight + + def forward(self, *args): + return self.embed(*args) + + +class PatchEmbedding(nn.Module): + def __init__(self, + img_size: int, + patch_size: int, + in_chans: int, + embed_size: int, + dtype: dtype = None, + flatten: bool = True, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + position_embed_initializer: Callable = init.zeros_()) -> None: + super().__init__() + tensor_parallel = get_tensor_parallel_mode() + self.embed = _parallel_patchembedding[tensor_parallel]( + img_size, + patch_size, + in_chans, + embed_size, + dtype=dtype, + flatten=flatten, + weight_initializer=weight_initializer, + bias_initializer=bias_initializer, + position_embed_initializer=position_embed_initializer, + ) + + @property + def weight(self): + return self.embed.weight + + @property + def bias(self): + return self.embed.bias + + @property + def pos_embed(self): + return self.embed.pos_embed + + @property + def cls_token(self): + return self.embed.cls_token + + def forward(self, *args): + return self.embed(*args) diff --git a/colossalai/nn/layer/colossalai_layer/linear.py b/colossalai/nn/layer/colossalai_layer/linear.py new file mode 100644 index 000000000..7c78941a2 --- /dev/null +++ b/colossalai/nn/layer/colossalai_layer/linear.py @@ -0,0 +1,97 @@ +import math +from typing import Callable, Optional + +from colossalai.nn.layer.parallel_1d.layers import Classifier1D +from colossalai.utils import get_current_device +from torch import dtype, nn + +from ... import init as init +from ..parallel_1d import * +from ..parallel_2d import * +from ..parallel_2p5d import * +from ..parallel_3d import * +from ..utils import get_tensor_parallel_mode +from ..vanilla import * + +_parallel_linear = {'1d': Linear1D, '2d': Linear2D, '2.5d': Linear2p5D, '3d': Linear3D} + +_parallel_classifier = { + 'None': VanillaClassifier, + '1d': Classifier1D, + '2d': Classifier2D, + '2.5d': Classifier2p5D, + '3d': Classifier3D +} + + +class Linear(nn.Module): + def __init__(self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: dtype = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + **kwargs) -> None: + super().__init__() + tensor_parallel = get_tensor_parallel_mode() + if tensor_parallel == 'None': + self.layer = nn.Linear(in_features, out_features, bias=bias, device=get_current_device(), dtype=dtype) + weight_initializer(self.layer.weight, fan_in=in_features, fan_out=out_features) + if bias: + bias_initializer(self.layer.bias, fan_in=in_features) + else: + self.layer = _parallel_linear[tensor_parallel]( + in_features, + out_features, + bias=bias, + dtype=dtype, + weight_initializer=weight_initializer, + bias_initializer=bias_initializer, + **kwargs, + ) + + @property + def weight(self): + return self.layer.weight + + @property + def bias(self): + return self.layer.bias + + def forward(self, *args): + return self.layer(*args) + + +class Classifier(nn.Module): + def __init__( + self, + in_features: int, + num_classes: int, + weight: nn.Parameter = None, + bias: bool = True, + dtype: dtype = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1) + ) -> None: + super().__init__() + self.layer = _parallel_classifier[get_tensor_parallel_mode()]( + in_features, + num_classes, + weight=weight, + bias=bias, + dtype=dtype, + weight_initializer=weight_initializer, + bias_initializer=bias_initializer, + ) + + @property + def weight(self): + return self.layer.weight + + @property + def bias(self): + return self.layer.bias + + def forward(self, *args): + return self.layer(*args) diff --git a/colossalai/nn/layer/colossalai_layer/normalization.py b/colossalai/nn/layer/colossalai_layer/normalization.py new file mode 100644 index 000000000..f1dab93f9 --- /dev/null +++ b/colossalai/nn/layer/colossalai_layer/normalization.py @@ -0,0 +1,35 @@ +from typing import Optional + +from colossalai.utils import get_current_device +from torch import nn + +from ... import init as init +from ..parallel_1d import * +from ..parallel_2d import * +from ..parallel_2p5d import * +from ..parallel_3d import * +from ..utils import get_tensor_parallel_mode +from ..vanilla import * + +_parallel_layernorm = {'2d': LayerNorm2D, '2.5d': LayerNorm2p5D, '3d': LayerNorm3D} + + +class LayerNorm(nn.Module): + def __init__(self, normalized_shape: int, eps=1e-05, dtype=None) -> None: + super().__init__() + tensor_parallel = get_tensor_parallel_mode() + if tensor_parallel in ['None', '1d']: + self.norm = nn.LayerNorm(normalized_shape, eps=eps, device=get_current_device(), dtype=dtype) + else: + self.norm = _parallel_layernorm[tensor_parallel](normalized_shape, eps=eps, dtype=dtype) + + @property + def weight(self): + return self.norm.weight + + @property + def bias(self): + return self.norm.bias + + def forward(self, *args): + return self.norm(*args) diff --git a/colossalai/nn/layer/fused_bias_gelu.py b/colossalai/nn/layer/fused_bias_gelu.py deleted file mode 100644 index e92041534..000000000 --- a/colossalai/nn/layer/fused_bias_gelu.py +++ /dev/null @@ -1,35 +0,0 @@ -# adapted from Megatron-LM -# https://github.com/NVIDIA/Megatron-LM/blob/b31e1296354e979722627a6c4dedafe19b51fa97/megatron/model/fused_bias_gelu.py - -import torch - -@torch.jit.script -def bias_gelu(bias, y): - x = bias + y - return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))) - -# gradient of tanh approximation of gelu -# gradient of actual gelu is: -# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x) -@torch.jit.script -def bias_gelu_back(g, bias, y): - x = bias + y - tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) - # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243 - ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out) - return ff*g - -class GeLUFunction(torch.autograd.Function): - @staticmethod - # bias is an optional argument - def forward(ctx, input, bias): - ctx.save_for_backward(input, bias) - return bias_gelu(bias, input) - - @staticmethod - def backward(ctx, grad_output): - input, bias = ctx.saved_tensors - tmp = bias_gelu_back(grad_output, bias, input) - return tmp, tmp - -bias_gelu_impl = GeLUFunction.apply \ No newline at end of file diff --git a/colossalai/nn/layer/parallel_1d/__init__.py b/colossalai/nn/layer/parallel_1d/__init__.py index 8fcd82aab..6f2093a11 100644 --- a/colossalai/nn/layer/parallel_1d/__init__.py +++ b/colossalai/nn/layer/parallel_1d/__init__.py @@ -1,4 +1,4 @@ -from .layers import Linear1D_Col, Linear1D_Row +from .layers import Dropout1D, Embedding1D, Linear1D, Linear1D_Col, Linear1D_Row from .layers import MixedFusedLayerNorm1D as LayerNorm1D -__all__ = ['Linear1D_Col', 'Linear1D_Row', 'LayerNorm1D'] +__all__ = ['Linear1D', 'Linear1D_Col', 'Linear1D_Row', 'LayerNorm1D', 'Embedding1D', 'Dropout1D'] diff --git a/colossalai/nn/layer/parallel_1d/_utils.py b/colossalai/nn/layer/parallel_1d/_utils.py index b8b7bcceb..db589afe5 100644 --- a/colossalai/nn/layer/parallel_1d/_utils.py +++ b/colossalai/nn/layer/parallel_1d/_utils.py @@ -1,12 +1,21 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- +import os import torch import torch.distributed as dist - +from colossalai.constants import PARALLEL_INPUT_1D from colossalai.core import global_context as gpc -from .._common_utils import divide +from ..utils import divide + + +def set_parallel_input(input_parallel: bool): + os.environ[PARALLEL_INPUT_1D] = 'true' if input_parallel else '' + + +def get_parallel_input(): + return bool(os.environ[PARALLEL_INPUT_1D]) def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank): diff --git a/colossalai/nn/layer/parallel_1d/layers.py b/colossalai/nn/layer/parallel_1d/layers.py index 21764aca6..3a3fa6e00 100644 --- a/colossalai/nn/layer/parallel_1d/layers.py +++ b/colossalai/nn/layer/parallel_1d/layers.py @@ -3,10 +3,10 @@ import math import numbers +from contextlib import nullcontext from typing import Callable, Tuple import torch -import torch.distributed as dist import torch.nn.functional as F from colossalai.communication import broadcast from colossalai.context import ParallelMode, seed @@ -14,13 +14,122 @@ from colossalai.core import global_context as gpc from colossalai.nn import init as init from colossalai.registry import LAYERS from colossalai.utils import get_current_device -from torch import Tensor +from torch import Tensor, dtype from torch.nn.parameter import Parameter -from .._common_utils import divide, set_tensor_parallel_attribute_by_partition from ..base_layer import ParallelLayer +from ..utils import divide, set_tensor_parallel_attribute_by_partition from ._operation import FusedLayerNormAffineFunction1D -from ._utils import (gather_forward_split_backward, reduce_grad, reduce_input, split_forward_gather_backward) +from ._utils import (gather_forward_split_backward, get_parallel_input, reduce_grad, reduce_input, set_parallel_input, + split_forward_gather_backward) + + +@LAYERS.register_module +class Linear1D(torch.nn.Module): + def __init__(self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + gather_output: bool = False, + skip_bias_add: bool = False, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): + super().__init__() + parallel_input = get_parallel_input() + if not parallel_input: + self.layer = Linear1D_Col(in_features, + out_features, + bias=bias, + dtype=dtype, + gather_output=gather_output, + skip_bias_add=skip_bias_add, + weight_initializer=weight_initializer, + bias_initializer=bias_initializer) + else: + self.layer = Linear1D_Row(in_features, + out_features, + bias=bias, + dtype=dtype, + parallel_input=parallel_input, + skip_bias_add=skip_bias_add, + weight_initializer=weight_initializer, + bias_initializer=bias_initializer) + + @property + def weight(self): + return self.layer.weight + + @property + def bias(self): + return self.layer.bias + + def forward(self, input_: Tensor) -> Tensor: + return self.layer(input_) + + +@LAYERS.register_module +class Classifier1D(ParallelLayer): + """RowLinear with given weight""" + def __init__(self, + in_features: int, + num_classes: int, + weight: Parameter = None, + bias: bool = True, + dtype: dtype = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): + super().__init__() + self.in_features = in_features + self.num_classes = num_classes + self.parallel_input = get_parallel_input() + + # Divide the weight matrix along the last dimension. + self.input_size_per_partition = divide(in_features, gpc.tensor_parallel_size) + + # Parameters. + # Initialize weight. + factory_kwargs = {'device': get_current_device(), 'dtype': dtype} + if weight is not None: + self.weight = weight + self.has_weight = False + else: + self.weight = Parameter(torch.empty(self.num_classes, self.input_size_per_partition, **factory_kwargs)) + self.has_weight = True + if bias: + self.bias = Parameter(torch.empty(self.num_classes, **factory_kwargs)) + else: + self.bias = None + with seed(ParallelMode.TENSOR): + self.reset_parameters(weight_initializer, bias_initializer) + self._set_tensor_parallel_attributes() + set_parallel_input(False) + + def reset_parameters(self, weight_initializer, bias_initializer) -> None: + fan_in, fan_out = self.in_features, self.num_classes + if self.has_weight: + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + broadcast(self.bias, gpc.get_ranks_in_group(ParallelMode.PARALLEL_1D)[0], ParallelMode.PARALLEL_1D) + + def _set_tensor_parallel_attributes(self): + if self.has_weight: + num_partition = gpc.get_world_size(ParallelMode.TENSOR) + set_tensor_parallel_attribute_by_partition(self.weight, num_partition) + + def forward(self, input_: Tensor) -> Tensor: + # Set up backprop all-reduce. + if self.parallel_input: + input_ = input_ + else: + input_ = split_forward_gather_backward(input_, ParallelMode.PARALLEL_1D, dim=-1) + + output_parallel = F.linear(input_, self.weight) + output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D) + + output = output + self.bias + return output @LAYERS.register_module @@ -77,6 +186,7 @@ class Linear1D_Col(ParallelLayer): with seed(ParallelMode.TENSOR): self.reset_parameters(weight_initializer, bias_initializer) self._set_tensor_parallel_attributes() + set_parallel_input(True) def reset_parameters(self, weight_initializer, bias_initializer) -> None: fan_in, fan_out = self.in_features, self.out_features @@ -158,6 +268,7 @@ class Linear1D_Row(ParallelLayer): with seed(ParallelMode.TENSOR): self.reset_parameters(weight_initializer, bias_initializer) self._set_tensor_parallel_attributes() + set_parallel_input(False) def reset_parameters(self, weight_initializer, bias_initializer) -> None: fan_in, fan_out = self.in_features, self.out_features @@ -208,3 +319,68 @@ class MixedFusedLayerNorm1D(torch.nn.Module): def forward(self, input): return FusedLayerNormAffineFunction1D.apply(input, self.weight, self.bias, self.normalized_shape, self.eps) + + +@LAYERS.register_module +class Embedding1D(ParallelLayer): + def __init__(self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int = None, + dtype: dtype = None, + weight_initializer: Callable = init.normal_(), + *args, + **kwargs): + super().__init__() + + self.num_embeddings = num_embeddings + self.embed_dim = embedding_dim + embed_dim_per_partition = divide(embedding_dim, gpc.tensor_parallel_size) + + self.padding_idx = padding_idx + self.embed_args = args + self.embed_kwargs = kwargs + + self.weight = Parameter( + torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype)) + + self.reset_parameters(weight_initializer) + self._set_tensor_parallel_attributes() + set_parallel_input(False) + + def _set_tensor_parallel_attributes(self): + set_tensor_parallel_attribute_by_partition(self.weight, gpc.tensor_parallel_size) + + def reset_parameters(self, weight_initializer) -> None: + with seed(ParallelMode.TENSOR): + fan_in, fan_out = self.num_embeddings, self.embed_dim + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + self._fill_padding_idx_with_zero() + + def _fill_padding_idx_with_zero(self) -> None: + if self.padding_idx is not None: + with torch.no_grad(): + self.weight[self.padding_idx].fill_(0) + + def forward(self, input_: Tensor) -> Tensor: + + output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs) + + output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1) + + return output + + +@LAYERS.register_module +class Dropout1D(ParallelLayer): + def __init__(self, p: float = 0.5, inplace: bool = False): + super().__init__() + self.parallel_input = get_parallel_input() + self.p = p + self.inplace = inplace + + def forward(self, input_: Tensor) -> Tensor: + cm = nullcontext() if not self.parallel_input else seed(ParallelMode.TENSOR) + with cm: + output = F.dropout(input_, self.p, self.training, self.inplace) + return output diff --git a/colossalai/nn/layer/parallel_2d/__init__.py b/colossalai/nn/layer/parallel_2d/__init__.py index e54f3e7e4..2122a1bfe 100644 --- a/colossalai/nn/layer/parallel_2d/__init__.py +++ b/colossalai/nn/layer/parallel_2d/__init__.py @@ -1,6 +1,6 @@ -from ._operation import reduce_by_batch_2d, split_batch_2d +from ._operation import reduce_by_batch_2d, split_tensor_2d from .layers import Classifier2D, Embedding2D, LayerNorm2D, Linear2D, PatchEmbedding2D __all__ = [ - 'split_batch_2d', 'reduce_by_batch_2d', 'Linear2D', 'LayerNorm2D', 'Classifier2D', 'PatchEmbedding2D', 'Embedding2D' + 'split_tensor_2d', 'reduce_by_batch_2d', 'Linear2D', 'LayerNorm2D', 'Classifier2D', 'PatchEmbedding2D', 'Embedding2D' ] diff --git a/colossalai/nn/layer/parallel_2d/_operation.py b/colossalai/nn/layer/parallel_2d/_operation.py index 603b4dcfe..9955bcefe 100644 --- a/colossalai/nn/layer/parallel_2d/_operation.py +++ b/colossalai/nn/layer/parallel_2d/_operation.py @@ -2,7 +2,7 @@ from typing import Any, Optional, Tuple import torch import torch.distributed as dist -from colossalai.communication.collective import (all_gather, all_reduce, reduce_scatter) +from colossalai.communication.collective import (all_gather, all_reduce, reduce, reduce_scatter) from colossalai.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc from colossalai.utils import get_current_device @@ -595,7 +595,9 @@ class SplitFirst(torch.autograd.Function): return grad, None, None -def split_batch_2d(input_: Tensor, dim: int = 0) -> Tensor: +def split_tensor_2d(input_: Tensor, dim: int = 0) -> Tensor: + if input_.size(dim) <= 1: + return input_ return torch.chunk(input_, gpc.get_world_size(ParallelMode.PARALLEL_2D_COL), dim=dim)[gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)].contiguous() @@ -603,17 +605,28 @@ def split_batch_2d(input_: Tensor, dim: int = 0) -> Tensor: class reduce_by_batch_2d(torch.autograd.Function): """All-reduce the input from the model parallel region.""" @staticmethod - def symbolic(graph, input_): - dist.all_reduce(input_, group=gpc.get_group(ParallelMode.PARALLEL_2D_COL)) - return input_ + def symbolic(graph, input_, reduce_mean: bool = False): + output = all_reduce(input_, ParallelMode.PARALLEL_2D_COL) + if reduce_mean: + reduce_size = gpc.get_world_size(ParallelMode.PARALLEL_2D_COL) + return output / reduce_size + return output @staticmethod @custom_fwd(cast_inputs=torch.float32) - def forward(ctx, input_): - dist.all_reduce(input_, group=gpc.get_group(ParallelMode.PARALLEL_2D_COL)) - return input_.clone() + def forward(ctx, input_, reduce_mean: bool = False): + output = all_reduce(input_, ParallelMode.PARALLEL_2D_COL) + ctx.reduce_mean = reduce_mean + if reduce_mean: + reduce_size = gpc.get_world_size(ParallelMode.PARALLEL_2D_COL) + ctx.reduce_size = reduce_size + return output.clone() / reduce_size + return output.clone() @staticmethod @custom_bwd - def backward(ctx, grad_output): - return grad_output + def backward(ctx, output_grad): + if ctx.reduce_mean: + return output_grad / ctx.reduce_size, None + else: + return output_grad, None diff --git a/colossalai/nn/layer/parallel_2d/layers.py b/colossalai/nn/layer/parallel_2d/layers.py index 5b735aca5..d113ec94c 100644 --- a/colossalai/nn/layer/parallel_2d/layers.py +++ b/colossalai/nn/layer/parallel_2d/layers.py @@ -13,9 +13,9 @@ from colossalai.utils import get_current_device from torch import Tensor, dtype from torch.nn import Parameter -from .._common_utils import (divide, set_tensor_parallel_attribute_by_partition, to_2tuple) +from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple from ..base_layer import ParallelLayer -from ._operation import (Matmul_AB_2D, add_bias_2d, all_gather_weight_2d, classifier_2d, layernorm_2d, split_batch_2d) +from ._operation import Matmul_AB_2D, add_bias_2d, all_gather_weight_2d, classifier_2d, layernorm_2d from ._utils import assert_summa_initialization, get_summa_dim_from_env @@ -257,8 +257,6 @@ class PatchEmbedding2D(ParallelLayer): assert H == self.img_size[0] and W == self.img_size[1], \ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." - input_ = split_batch_2d(input_) - weight = all_gather_weight_2d.apply(self.weight, 0, self.summa_dim, ParallelMode.PARALLEL_2D_COL) bias = all_gather_weight_2d.apply(self.bias, 0, self.summa_dim, ParallelMode.PARALLEL_2D_COL) @@ -318,8 +316,6 @@ class Embedding2D(ParallelLayer): self.weight[self.padding_idx].fill_(0) def forward(self, input_: Tensor) -> Tensor: - input_ = split_batch_2d(input_) - weight = all_gather_weight_2d.apply(self.weight, -1, self.summa_dim, ParallelMode.PARALLEL_2D_COL) output = F.embedding(input_, weight, self.padding_idx, *self.embed_args, **self.embed_kwargs) diff --git a/colossalai/nn/layer/parallel_2p5d/__init__.py b/colossalai/nn/layer/parallel_2p5d/__init__.py index 5fc9666f8..202c948c5 100644 --- a/colossalai/nn/layer/parallel_2p5d/__init__.py +++ b/colossalai/nn/layer/parallel_2p5d/__init__.py @@ -1,7 +1,7 @@ -from ._operation import reduce_by_batch_2p5d, split_batch_2p5d +from ._operation import reduce_by_batch_2p5d, split_tensor_2p5d from .layers import Classifier2p5D, Embedding2p5D, LayerNorm2p5D, Linear2p5D, PatchEmbedding2p5D __all__ = [ - 'split_batch_2p5d', 'reduce_by_batch_2p5d', 'Linear2p5D', 'LayerNorm2p5D', 'Classifier2p5D', 'PatchEmbedding2p5D', + 'split_tensor_2p5d', 'reduce_by_batch_2p5d', 'Linear2p5D', 'LayerNorm2p5D', 'Classifier2p5D', 'PatchEmbedding2p5D', 'Embedding2p5D' ] diff --git a/colossalai/nn/layer/parallel_2p5d/_operation.py b/colossalai/nn/layer/parallel_2p5d/_operation.py index 5a38c5d37..a1dbcd3cd 100644 --- a/colossalai/nn/layer/parallel_2p5d/_operation.py +++ b/colossalai/nn/layer/parallel_2p5d/_operation.py @@ -22,7 +22,7 @@ def get_parallel_rank(parallel_mode: ParallelMode): return gpc.get_local_rank(parallel_mode) -def split_batch_2p5d(input_: Tensor, dim: int = 0) -> Tensor: +def split_tensor_2p5d(input_: Tensor, dim: int = 0) -> Tensor: return torch.chunk(input_, gpc.get_world_size(ParallelMode.PARALLEL_2P5D_COL), dim=dim)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)].contiguous() @@ -120,30 +120,53 @@ class Matmul_AB_2p5D(torch.autograd.Function): ctx.save_for_backward(A, B) A_shape = A.shape - A = A.reshape((-1, A_shape[-1])).contiguous() + A = A.reshape((-1, A_shape[-1])) B_shape = B.shape - B = B.reshape((-1, B_shape[-1])).contiguous() + B = B.reshape((-1, B_shape[-1])) C_shape = (A.shape[0], B.shape[-1]) C = torch.zeros(C_shape, dtype=A.dtype, device=get_current_device()) - A_list = [torch.empty_like(A) for _ in range(gpc.get_world_size(row_parallel_mode) - 1)] - B_list = [torch.empty_like(B) for _ in range(gpc.get_world_size(col_parallel_mode) - 1)] - A_list.insert(gpc.get_local_rank(row_parallel_mode), A) - B_list.insert(gpc.get_local_rank(col_parallel_mode), B) - op_a = dist.all_gather(A_list, A, group=gpc.get_group(row_parallel_mode), async_op=True) - op_a.wait() - op_b = dist.all_gather(B_list, B, group=gpc.get_group(col_parallel_mode), async_op=True) - for op in [op_a, op_b]: - op.wait() + # use circular buffer to store the communication tensor + # 2 is enough for all cases + A_list = [torch.empty_like(A) for _ in range(2)] + B_list = [torch.empty_like(B) for _ in range(2)] + + row_group = gpc.get_group(row_parallel_mode) + col_group = gpc.get_group(col_parallel_mode) + + src_a = tesseract_dim * row_rank + tesseract_dim ** 2 * dep_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ + pipeline_parallel_rank * tensor_parallel_size + src_b = col_rank + tesseract_dim ** 2 * dep_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ + pipeline_parallel_rank * tensor_parallel_size + + opa = [None] * 2 + opb = [None] * 2 + + A_list[0].copy_(A) + B_list[0].copy_(B) + opa[0] = dist.broadcast(A_list[0], src=src_a, group=row_group, async_op=True) + opb[0] = dist.broadcast(B_list[0], src=src_b, group=col_group, async_op=True) + cur = 0 for i in range(tesseract_dim): - src_a = i + tesseract_dim * row_rank - src_b = i + tesseract_dim * col_rank - src_a = src_a % tesseract_dim - src_b = src_b % tesseract_dim - A_temp = A_list[src_a] - B_temp = B_list[src_b] - torch.addmm(C, A_temp, B_temp, out=C) + if i != tesseract_dim - 1: + A_list[1 - cur].copy_(A) + opa[1 - cur] = dist.broadcast(A_list[1 - cur], src=src_a + 1, group=row_group, async_op=True) + B_list[1 - cur].copy_(B) + opb[1 - cur] = dist.broadcast(B_list[1 - cur], + src=src_b + tesseract_dim, + group=col_group, + async_op=True) + + if opa[cur] is not None: + opa[cur].wait() + if opb[cur] is not None: + opb[cur].wait() + + torch.addmm(C, A_list[cur], B_list[cur], out=C) + cur = 1 - cur + src_a += 1 + src_b += tesseract_dim out = C.reshape(out_shape) if ctx: @@ -201,20 +224,55 @@ class Matmul_ABT_2p5D(torch.autograd.Function): C_shape = (A.shape[0], B.shape[0]) C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device()) - for i in range(tesseract_dim): - B_temp = B.clone() - src_b = col_rank + i * tesseract_dim + dep_rank * ( - tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ - pipeline_parallel_rank * tensor_parallel_size - dist.broadcast(B_temp, src=src_b, group=gpc.get_group(col_parallel_mode)) - C_temp = torch.matmul(A, B_temp.transpose(0, 1)) - src_c = i + row_rank * tesseract_dim + dep_rank * ( - tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ - pipeline_parallel_rank * tensor_parallel_size - dist.reduce(C_temp, dst=src_c, group=gpc.get_group(row_parallel_mode)) - if i == col_rank: - C = C_temp.clone() + # use circular buffer to store the communication tensor + # 2 is enough for all cases + B_list = [torch.empty_like(B) for _ in range(2)] + C_list = [torch.empty_like(C) for _ in range(2)] + row_group = gpc.get_group(row_parallel_mode) + col_group = gpc.get_group(col_parallel_mode) + + src_b = col_rank + tesseract_dim ** 2 * dep_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ + pipeline_parallel_rank * tensor_parallel_size + src_c = tesseract_dim * row_rank + tesseract_dim ** 2 * dep_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ + pipeline_parallel_rank * tensor_parallel_size + + opb = [None] * 2 + opr = [None] * 2 + + B_list[0].copy_(B) + opb[0] = dist.broadcast(B_list[0], src=src_b, group=col_group, async_op=True) + cur = 0 + + for i in range(tesseract_dim): + if i != tesseract_dim - 1: + B_list[1 - cur].copy_(B) + opb[1 - cur] = dist.broadcast(B_list[1 - cur], + src=src_b + tesseract_dim, + group=col_group, + async_op=True) + + if opr[cur] is not None: + opr[cur].wait() + if i - 2 == col_rank: + C.copy_(C_list[cur]) + + if opb[cur] is not None: + opb[cur].wait() + + torch.matmul(A, B_list[cur].transpose(0, 1), out=C_list[cur]) + opr[cur] = dist.reduce(C_list[cur], dst=src_c, group=row_group, async_op=True) + cur = 1 - cur + src_b += tesseract_dim + src_c += 1 + + for op in opr: + op.wait() + + if tesseract_dim - 2 == col_rank: + C.copy_(C_list[cur]) + if tesseract_dim - 1 == col_rank: + C.copy_(C_list[1 - cur]) out = C.reshape(out_shape) if ctx: @@ -272,20 +330,52 @@ class Matmul_ATB_2p5D(torch.autograd.Function): C_shape = (A.shape[-1], B.shape[-1]) C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device()) - for i in range(tesseract_dim): - A_temp = A.clone() - src_a = i + row_rank * tesseract_dim + dep_rank * ( - tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ - pipeline_parallel_rank * tensor_parallel_size - dist.broadcast(A_temp, src=src_a, group=get_parallel_group(row_parallel_mode)) - C_temp = torch.matmul(A_temp.transpose(0, 1), B) - src_c = col_rank + i * tesseract_dim + dep_rank * ( - tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ - pipeline_parallel_rank * tensor_parallel_size - dist.reduce(C_temp, dst=src_c, group=get_parallel_group(col_parallel_mode)) - if i == row_rank: - C = C_temp.clone() + # use circular buffer to store the communication tensor + # 2 is enough for all cases + A_list = [torch.empty_like(A) for _ in range(2)] + C_list = [torch.empty_like(C) for _ in range(2)] + row_group = gpc.get_group(row_parallel_mode) + col_group = gpc.get_group(col_parallel_mode) + + src_a = tesseract_dim * row_rank + tesseract_dim ** 2 * dep_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ + pipeline_parallel_rank * tensor_parallel_size + src_c = col_rank + tesseract_dim ** 2 * dep_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ + pipeline_parallel_rank * tensor_parallel_size + + opa = [None] * 2 + opr = [None] * 2 + + A_list[0].copy_(A) + opa[0] = dist.broadcast(A_list[0], src=src_a, group=row_group, async_op=True) + cur = 0 + + for i in range(tesseract_dim): + if i != tesseract_dim - 1: + A_list[1 - cur].copy_(A) + opa[1 - cur] = dist.broadcast(A_list[1 - cur], src=src_a + 1, group=row_group, async_op=True) + + if opr[cur] is not None: + opr[cur].wait() + if i - 2 == row_rank: + C.copy_(C_list[cur]) + + if opa[cur] is not None: + opa[cur].wait() + + torch.matmul(A_list[cur].transpose(0, 1), B, out=C_list[cur]) + opr[cur] = dist.reduce(C_list[cur], dst=src_c, group=col_group, async_op=True) + cur = 1 - cur + src_a += 1 + src_c += tesseract_dim + + for op in opr: + op.wait() + + if tesseract_dim - 2 == row_rank: + C.copy_(C_list[cur]) + if tesseract_dim - 1 == row_rank: + C.copy_(C_list[1 - cur]) out = C.reshape(out_shape) if ctx: @@ -333,8 +423,7 @@ class Add_Bias_2p5D(torch.autograd.Function): bias_temp = bias.clone() else: bias_temp = torch.zeros(output_size_per_partition, dtype=bias.dtype, device=get_current_device()) - src_rank = col_rank + dep_rank * ( - tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ + src_rank = col_rank + dep_rank * tesseract_dim ** 2 + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ pipeline_parallel_rank * tensor_parallel_size dist.broadcast(bias_temp, src=src_rank, group=get_parallel_group(col_parallel_mode)) @@ -469,7 +558,9 @@ class SplitFirst(torch.autograd.Function): return grad, None, None -def split_batch_2p5d(input_: Tensor, dim: int = 0) -> Tensor: +def split_tensor_2p5d(input_: Tensor, dim: int = 0) -> Tensor: + if input_.size(dim) <= 1: + return input_ return torch.chunk(input_, gpc.get_world_size(ParallelMode.PARALLEL_2P5D_COL), dim=dim)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)].contiguous() @@ -477,17 +568,28 @@ def split_batch_2p5d(input_: Tensor, dim: int = 0) -> Tensor: class reduce_by_batch_2p5d(torch.autograd.Function): """All-reduce the input from the model parallel region.""" @staticmethod - def symbolic(graph, input_): - dist.all_reduce(input_, group=gpc.get_group(ParallelMode.PARALLEL_2P5D_COL)) - return input_ + def symbolic(graph, input_, reduce_mean: bool = False): + output = all_reduce(input_, ParallelMode.PARALLEL_2P5D_COL) + if reduce_mean: + reduce_size = gpc.get_world_size(ParallelMode.PARALLEL_2P5D_COL) + return output / reduce_size + return output @staticmethod @custom_fwd(cast_inputs=torch.float32) - def forward(ctx, input_): - dist.all_reduce(input_, group=gpc.get_group(ParallelMode.PARALLEL_2P5D_COL)) - return input_.clone() + def forward(ctx, input_, reduce_mean: bool = False): + output = all_reduce(input_, ParallelMode.PARALLEL_2P5D_COL) + ctx.reduce_mean = reduce_mean + if reduce_mean: + reduce_size = gpc.get_world_size(ParallelMode.PARALLEL_2P5D_COL) + ctx.reduce_size = reduce_size + return output.clone() / reduce_size + return output.clone() @staticmethod @custom_bwd - def backward(ctx, grad_output): - return grad_output + def backward(ctx, output_grad): + if ctx.reduce_mean: + return output_grad / ctx.reduce_size, None + else: + return output_grad, None diff --git a/colossalai/nn/layer/parallel_2p5d/layers.py b/colossalai/nn/layer/parallel_2p5d/layers.py index 963a1e8b2..d7bd265bd 100644 --- a/colossalai/nn/layer/parallel_2p5d/layers.py +++ b/colossalai/nn/layer/parallel_2p5d/layers.py @@ -13,10 +13,9 @@ from colossalai.utils import get_current_device from torch import Tensor, dtype from torch.nn import Parameter -from .._common_utils import (divide, set_tensor_parallel_attribute_by_partition, to_2tuple) from ..base_layer import ParallelLayer -from ._operation import (Add_Bias_2p5D, Matmul_AB_2p5D, all_gather_weight_2p5d, classifier_2p5d, layernorm_2p5d, - split_batch_2p5d) +from ..utils import (divide, set_tensor_parallel_attribute_by_partition, to_2tuple) +from ._operation import (Add_Bias_2p5D, Matmul_AB_2p5D, all_gather_weight_2p5d, classifier_2p5d, layernorm_2p5d) from ._utils import (assert_tesseract_initialization, get_tesseract_dim_dep_from_env) @@ -231,7 +230,7 @@ class PatchEmbedding2p5D(ParallelLayer): self.num_patches = self.grid_size[0] * self.grid_size[1] self.flatten = flatten self.embed_size = embed_size - self.embed_size_per_partition = embed_size // (self.tesseract_dep * self.tesseract_dim**2) + self.embed_size_per_partition = embed_size // self.tesseract_dim**2 with seed(ParallelMode.TENSOR): self.weight = Parameter( @@ -251,10 +250,10 @@ class PatchEmbedding2p5D(ParallelLayer): self._set_tensor_parallel_attribute() def _set_tensor_parallel_attribute(self): - set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dep * self.tesseract_dim**2) - set_tensor_parallel_attribute_by_partition(self.bias, self.tesseract_dep * self.tesseract_dim**2) - set_tensor_parallel_attribute_by_partition(self.cls_token, self.tesseract_dep * self.tesseract_dim**2) - set_tensor_parallel_attribute_by_partition(self.pos_embed, self.tesseract_dep * self.tesseract_dim**2) + set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dim**2) + set_tensor_parallel_attribute_by_partition(self.bias, self.tesseract_dim**2) + set_tensor_parallel_attribute_by_partition(self.cls_token, self.tesseract_dim**2) + set_tensor_parallel_attribute_by_partition(self.pos_embed, self.tesseract_dim**2) def reset_parameters(self, weight_initializer, bias_initializer, position_embed_initializer): with seed(ParallelMode.TENSOR): @@ -269,8 +268,6 @@ class PatchEmbedding2p5D(ParallelLayer): assert H == self.img_size[0] and W == self.img_size[1], \ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." - input_ = split_batch_2p5d(input_) - weight = all_gather_weight_2p5d.apply(self.weight, 0, self.tesseract_dim, ParallelMode.PARALLEL_2P5D_COL) bias = all_gather_weight_2p5d.apply(self.bias, 0, self.tesseract_dim, ParallelMode.PARALLEL_2P5D_COL) @@ -303,7 +300,7 @@ class Embedding2p5D(ParallelLayer): self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env() self.num_embeddings = num_embeddings self.embed_dim = embedding_dim - embed_dim_per_partition = embedding_dim // (self.tesseract_dep * self.tesseract_dim**2) + embed_dim_per_partition = embedding_dim // self.tesseract_dim**2 self.padding_idx = padding_idx self.embed_args = args @@ -316,7 +313,7 @@ class Embedding2p5D(ParallelLayer): self._set_tensor_parallel_attributes() def _set_tensor_parallel_attributes(self): - set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dep * self.tesseract_dim**2) + set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dim**2) def reset_parameters(self, weight_initializer) -> None: with seed(ParallelMode.TENSOR): @@ -330,8 +327,6 @@ class Embedding2p5D(ParallelLayer): self.weight[self.padding_idx].fill_(0) def forward(self, input_: Tensor) -> Tensor: - input_ = split_batch_2p5d(input_) - weight = all_gather_weight_2p5d.apply(self.weight, -1, self.tesseract_dim, ParallelMode.PARALLEL_2P5D_COL) output = F.embedding(input_, weight, self.padding_idx, *self.embed_args, **self.embed_kwargs) @@ -359,7 +354,7 @@ class Classifier2p5D(ParallelLayer): self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env() # partitioning dimension - self.input_size_per_partition = divide(self.in_features, self.tesseract_dep * self.tesseract_dim**2) + self.input_size_per_partition = divide(self.in_features, self.tesseract_dim**2) if weight is not None: self.weight = weight @@ -378,7 +373,7 @@ class Classifier2p5D(ParallelLayer): def _set_tensor_parallel_attributes(self): if self.has_weight: - set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dep * self.tesseract_dim**2) + set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dim**2) def reset_parameters(self, weight_initializer, bias_initializer) -> None: with seed(ParallelMode.TENSOR): diff --git a/colossalai/nn/layer/parallel_3d/__init__.py b/colossalai/nn/layer/parallel_3d/__init__.py index feb30d462..46eeacda1 100644 --- a/colossalai/nn/layer/parallel_3d/__init__.py +++ b/colossalai/nn/layer/parallel_3d/__init__.py @@ -1,6 +1,6 @@ -from ._operation import reduce_by_batch_3d, split_batch_3d +from ._operation import reduce_by_batch_3d, split_tensor_3d from .layers import Classifier3D, Embedding3D, LayerNorm3D, Linear3D, PatchEmbedding3D __all__ = [ - 'reduce_by_batch_3d', 'split_batch_3d', 'Linear3D', 'LayerNorm3D', 'PatchEmbedding3D', 'Classifier3D', 'Embedding3D' + 'reduce_by_batch_3d', 'split_tensor_3d', 'Linear3D', 'LayerNorm3D', 'PatchEmbedding3D', 'Classifier3D', 'Embedding3D' ] diff --git a/colossalai/nn/layer/parallel_3d/_operation.py b/colossalai/nn/layer/parallel_3d/_operation.py index 5b3763c3a..96ed775ec 100644 --- a/colossalai/nn/layer/parallel_3d/_operation.py +++ b/colossalai/nn/layer/parallel_3d/_operation.py @@ -175,10 +175,12 @@ class layernorm_3d(torch.autograd.Function): return input_grad, weight_grad, bias_grad, None, None, None, None, None -def split_batch_3d(input_: Tensor, - input_parallel_mode: ParallelMode, - weight_parallel_mode: ParallelMode, - dim: int = 0) -> Tensor: +def split_tensor_3d(input_: Tensor, + dim: int = 0, + input_parallel_mode: ParallelMode = ParallelMode.PARALLEL_3D_INPUT, + weight_parallel_mode: ParallelMode = ParallelMode.PARALLEL_3D_WEIGHT) -> Tensor: + if input_.size(dim) <= 1: + return input_ output = torch.chunk(input_, gpc.get_world_size(weight_parallel_mode), dim=dim)[gpc.get_local_rank(weight_parallel_mode)].contiguous() output = torch.chunk(output, gpc.get_world_size(input_parallel_mode), @@ -189,15 +191,27 @@ def split_batch_3d(input_: Tensor, class reduce_by_batch_3d(torch.autograd.Function): @staticmethod @custom_fwd(cast_inputs=torch.float32) - def forward(ctx, input_: Tensor, input_parallel_mode: ParallelMode, weight_parallel_mode: ParallelMode) -> Tensor: + def forward(ctx, + input_: Tensor, + input_parallel_mode: ParallelMode, + weight_parallel_mode: ParallelMode, + reduce_mean: bool = False) -> Tensor: output = all_reduce(input_, input_parallel_mode) output = all_reduce(output, weight_parallel_mode) + ctx.reduce_mean = reduce_mean + if reduce_mean: + reduce_size = gpc.get_world_size(input_parallel_mode) * gpc.get_world_size(weight_parallel_mode) + ctx.reduce_size = reduce_size + return output.clone() / reduce_size return output.clone() @staticmethod @custom_bwd def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]: - return output_grad, None, None + if ctx.reduce_mean: + return output_grad / ctx.reduce_size, None, None, None + else: + return output_grad, None, None, None class broadcast_weight_3d_from_diagonal(torch.autograd.Function): diff --git a/colossalai/nn/layer/parallel_3d/layers.py b/colossalai/nn/layer/parallel_3d/layers.py index 59b449828..4871d1443 100644 --- a/colossalai/nn/layer/parallel_3d/layers.py +++ b/colossalai/nn/layer/parallel_3d/layers.py @@ -17,9 +17,9 @@ from colossalai.utils import get_current_device from torch import Tensor, dtype from torch.nn import Parameter -from .._common_utils import (divide, set_tensor_parallel_attribute_by_partition, to_2tuple) +from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple from ._operation import * -from ._utils import (get_depth_from_env, get_last_group, get_parallel_mode_from_env, swap_in_out_group) +from ._utils import get_depth_from_env, get_last_group, get_parallel_mode_from_env, swap_in_out_group @LAYERS.register_module @@ -241,8 +241,6 @@ class PatchEmbedding3D(ParallelLayer): self.pos_embed.register_hook(self._sync_grad_hook) def forward(self, input_: Tensor) -> Tensor: - input_ = split_batch_3d(input_, self.input_parallel_mode, self.weight_parallel_mode) - weight = broadcast_weight_3d_from_diagonal.apply(self.weight, self.input_parallel_mode, self.weight_parallel_mode, self.output_parallel_mode) output = F.conv2d(input_, weight, self.bias, stride=self.patch_size) @@ -302,8 +300,6 @@ class Embedding3D(ParallelLayer): self.weight[self.padding_idx].fill_(0) def forward(self, input_: Tensor) -> Tensor: - input_ = split_batch_3d(input_, self.input_parallel_mode, self.weight_parallel_mode) - weight = broadcast_weight_3d_from_diagonal.apply(self.weight, self.input_parallel_mode, self.weight_parallel_mode, self.output_parallel_mode) output = F.embedding(input_, weight, self.padding_idx, *self.embed_args, **self.embed_kwargs) diff --git a/colossalai/nn/layer/utils/__init__.py b/colossalai/nn/layer/utils/__init__.py new file mode 100644 index 000000000..7e999ee82 --- /dev/null +++ b/colossalai/nn/layer/utils/__init__.py @@ -0,0 +1,7 @@ +from .common import (ACT2FN, CheckpointModule, _ntuple, divide, get_tensor_parallel_mode, + set_tensor_parallel_attribute_by_partition, set_tensor_parallel_attribute_by_size, to_2tuple) + +__all__ = [ + 'CheckpointModule', 'divide', 'ACT2FN', 'set_tensor_parallel_attribute_by_size', + 'set_tensor_parallel_attribute_by_partition', 'get_tensor_parallel_mode', '_ntuple', 'to_2tuple' +] diff --git a/colossalai/nn/layer/_common_utils.py b/colossalai/nn/layer/utils/common.py similarity index 91% rename from colossalai/nn/layer/_common_utils.py rename to colossalai/nn/layer/utils/common.py index d38e74f95..734aa5bfa 100644 --- a/colossalai/nn/layer/_common_utils.py +++ b/colossalai/nn/layer/utils/common.py @@ -2,11 +2,12 @@ # -*- encoding: utf-8 -*- import collections.abc +import os from itertools import repeat import numpy as np import torch -from colossalai.constants import IS_TENSOR_PARALLEL, NUM_PARTITIONS +from colossalai.constants import (IS_TENSOR_PARALLEL, NUM_PARTITIONS, TENSOR_PARALLEL_MODE) from colossalai.utils import checkpoint from torch import Tensor, nn @@ -59,6 +60,10 @@ def set_tensor_parallel_attribute_by_partition(param, num_partitions): setattr(param, NUM_PARTITIONS, num_partitions) +def get_tensor_parallel_mode(): + return os.environ[TENSOR_PARALLEL_MODE] + + # From PyTorch internals diff --git a/colossalai/nn/layer/vanilla/layers.py b/colossalai/nn/layer/vanilla/layers.py index f19cca475..a89e5e1e9 100644 --- a/colossalai/nn/layer/vanilla/layers.py +++ b/colossalai/nn/layer/vanilla/layers.py @@ -9,7 +9,7 @@ from colossalai.utils import get_current_device from torch import Tensor, dtype from torch import nn as nn -from .._common_utils import to_2tuple +from ..utils import to_2tuple def drop_path(x, drop_prob: float = 0., training: bool = False): diff --git a/colossalai/nn/loss/__init__.py b/colossalai/nn/loss/__init__.py index 58a9d625a..65eef4a9e 100644 --- a/colossalai/nn/loss/__init__.py +++ b/colossalai/nn/loss/__init__.py @@ -2,6 +2,7 @@ from torch import nn from torch.nn.modules.loss import * from torch.nn.modules.loss import _Loss +from colossalai.nn.layer.utils import get_tensor_parallel_mode from .loss_2d import CrossEntropyLoss2D from .loss_2p5d import CrossEntropyLoss2p5D from .loss_3d import CrossEntropyLoss3D @@ -14,9 +15,10 @@ _parallel_cross_entropy = { class CrossEntropyLoss(_Loss): - def __init__(self, reduction: bool = True, tensor_parallel: str = None, *args, **kwargs): + def __init__(self, reduction: bool = True, *args, **kwargs): super().__init__() - if tensor_parallel in [None, '1d']: + tensor_parallel = get_tensor_parallel_mode() + if tensor_parallel in ['None', '1d']: reduction = 'mean' if reduction else 'none' self.loss = nn.CrossEntropyLoss(reduction=reduction, *args, **kwargs) else: diff --git a/colossalai/nn/loss/loss_2d.py b/colossalai/nn/loss/loss_2d.py index aeb798201..7aef949f6 100644 --- a/colossalai/nn/loss/loss_2d.py +++ b/colossalai/nn/loss/loss_2d.py @@ -1,4 +1,4 @@ -from colossalai.nn.layer.parallel_2d import reduce_by_batch_2d, split_batch_2d +from colossalai.nn.layer.parallel_2d import reduce_by_batch_2d from colossalai.nn.layer.parallel_2d._utils import assert_summa_initialization from colossalai.registry import LOSSES from torch.nn.functional import cross_entropy @@ -20,11 +20,8 @@ class CrossEntropyLoss2D(_Loss): self.loss_kwargs = kwargs def forward(self, logits, targets): - batch_size = targets.size(0) - targets = split_batch_2d(targets) - loss = cross_entropy(logits, targets, reduction='sum', *self.loss_args, **self.loss_kwargs) + loss = cross_entropy(logits, targets, reduction='none', *self.loss_args, **self.loss_kwargs) if self.reduction_mean: - loss = loss.sum() - loss = reduce_by_batch_2d.apply(loss) - loss /= batch_size + loss = loss.mean() + loss = reduce_by_batch_2d.apply(loss, True) return loss diff --git a/colossalai/nn/loss/loss_2p5d.py b/colossalai/nn/loss/loss_2p5d.py index 4f11b7175..d7596d924 100644 --- a/colossalai/nn/loss/loss_2p5d.py +++ b/colossalai/nn/loss/loss_2p5d.py @@ -1,4 +1,4 @@ -from colossalai.nn.layer.parallel_2p5d import reduce_by_batch_2p5d, split_batch_2p5d +from colossalai.nn.layer.parallel_2p5d import reduce_by_batch_2p5d from colossalai.nn.layer.parallel_2p5d._utils import assert_tesseract_initialization from colossalai.registry import LOSSES from torch.nn.functional import cross_entropy @@ -19,11 +19,8 @@ class CrossEntropyLoss2p5D(_Loss): self.loss_kwargs = kwargs def forward(self, logits, targets): - batch_size = targets.size(0) - targets = split_batch_2p5d(targets) - loss = cross_entropy(logits, targets, reduction='sum', *self.loss_args, **self.loss_kwargs) + loss = cross_entropy(logits, targets, reduction='none', *self.loss_args, **self.loss_kwargs) if self.reduction_mean: - loss = loss.sum() - loss = reduce_by_batch_2p5d.apply(loss) - loss /= batch_size + loss = loss.mean() + loss = reduce_by_batch_2p5d.apply(loss, True) return loss diff --git a/colossalai/nn/loss/loss_3d.py b/colossalai/nn/loss/loss_3d.py index d5431dabc..59b6ffeeb 100644 --- a/colossalai/nn/loss/loss_3d.py +++ b/colossalai/nn/loss/loss_3d.py @@ -1,11 +1,10 @@ from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D -from colossalai.nn.layer.parallel_3d import reduce_by_batch_3d, split_batch_3d +from colossalai.nn.layer.parallel_3d import reduce_by_batch_3d from colossalai.nn.layer.parallel_3d._utils import get_parallel_mode_from_env from colossalai.registry import LOSSES from torch.nn.functional import cross_entropy from torch.nn.modules.loss import _Loss - @LOSSES.register_module class CrossEntropyLoss3D(_Loss): """Cross entropy loss for 3D parallelism @@ -28,11 +27,8 @@ class CrossEntropyLoss3D(_Loss): self.loss_kwargs = kwargs def forward(self, logits, targets): - batch_size = targets.size(0) - targets = split_batch_3d(targets, self.input_parallel_mode, self.weight_parallel_mode) - loss = cross_entropy(logits, targets, reduction='sum', *self.loss_args, **self.loss_kwargs) + loss = cross_entropy(logits, targets, reduction='none', *self.loss_args, **self.loss_kwargs) if self.reduction_mean: - loss = loss.sum() - loss = reduce_by_batch_3d.apply(loss, self.input_parallel_mode, self.weight_parallel_mode) - loss /= batch_size + loss = loss.mean() + loss = reduce_by_batch_3d.apply(loss, self.input_parallel_mode, self.weight_parallel_mode, True) return loss diff --git a/colossalai/nn/metric/__init__.py b/colossalai/nn/metric/__init__.py index 036bcaa69..7ce17b08b 100644 --- a/colossalai/nn/metric/__init__.py +++ b/colossalai/nn/metric/__init__.py @@ -4,6 +4,7 @@ from ._utils import calc_acc from .accuracy_2d import Accuracy2D from .accuracy_2p5d import Accuracy2p5D from .accuracy_3d import Accuracy3D +from colossalai.nn.layer.utils import get_tensor_parallel_mode _parallel_accuracy = { '2d': Accuracy2D, @@ -13,9 +14,10 @@ _parallel_accuracy = { class Accuracy(nn.Module): - def __init__(self, tensor_parallel: str = None): + def __init__(self): super().__init__() - if tensor_parallel in [None, '1d']: + tensor_parallel = get_tensor_parallel_mode() + if tensor_parallel in ['None', '1d']: self.acc = calc_acc else: self.acc = _parallel_accuracy[tensor_parallel]() diff --git a/colossalai/nn/metric/accuracy_2d.py b/colossalai/nn/metric/accuracy_2d.py index 1026a52e2..cc207b02c 100644 --- a/colossalai/nn/metric/accuracy_2d.py +++ b/colossalai/nn/metric/accuracy_2d.py @@ -1,5 +1,5 @@ import torch -from colossalai.nn.layer.parallel_2d import reduce_by_batch_2d, split_batch_2d +from colossalai.nn.layer.parallel_2d import reduce_by_batch_2d from torch import nn from ._utils import calc_acc @@ -11,7 +11,6 @@ class Accuracy2D(nn.Module): def forward(self, logits, targets): with torch.no_grad(): - targets = split_batch_2d(targets) correct = calc_acc(logits, targets) correct = reduce_by_batch_2d.apply(correct) return correct diff --git a/colossalai/nn/metric/accuracy_2p5d.py b/colossalai/nn/metric/accuracy_2p5d.py index 98373cbfb..90dc4af26 100644 --- a/colossalai/nn/metric/accuracy_2p5d.py +++ b/colossalai/nn/metric/accuracy_2p5d.py @@ -1,5 +1,5 @@ import torch -from colossalai.nn.layer.parallel_2p5d import reduce_by_batch_2p5d, split_batch_2p5d +from colossalai.nn.layer.parallel_2p5d import reduce_by_batch_2p5d from torch import nn from ._utils import calc_acc @@ -11,7 +11,6 @@ class Accuracy2p5D(nn.Module): def forward(self, logits, targets): with torch.no_grad(): - targets = split_batch_2p5d(targets) correct = calc_acc(logits, targets) correct = reduce_by_batch_2p5d.apply(correct) return correct diff --git a/colossalai/nn/metric/accuracy_3d.py b/colossalai/nn/metric/accuracy_3d.py index f717b9fb2..576800510 100644 --- a/colossalai/nn/metric/accuracy_3d.py +++ b/colossalai/nn/metric/accuracy_3d.py @@ -1,6 +1,6 @@ import torch from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D -from colossalai.nn.layer.parallel_3d import reduce_by_batch_3d, split_batch_3d +from colossalai.nn.layer.parallel_3d import reduce_by_batch_3d from colossalai.nn.layer.parallel_3d._utils import get_parallel_mode_from_env from torch import nn @@ -15,7 +15,6 @@ class Accuracy3D(nn.Module): def forward(self, logits, targets): with torch.no_grad(): - targets = split_batch_3d(targets, self.input_parallel_mode, self.weight_parallel_mode) correct = calc_acc(logits, targets) correct = reduce_by_batch_3d.apply(correct, self.input_parallel_mode, self.weight_parallel_mode) return correct diff --git a/colossalai/trainer/hooks/_metric_hook.py b/colossalai/trainer/hooks/_metric_hook.py index bbf66a6fd..348929880 100644 --- a/colossalai/trainer/hooks/_metric_hook.py +++ b/colossalai/trainer/hooks/_metric_hook.py @@ -173,7 +173,7 @@ class AccuracyMetric(Metric): self.accumulated_sum.zero_() self.accumulated_correct.zero_() - def update(self, logits, targets) -> None: + def update(self, logits, targets, batch_size) -> None: """Updates last step accuracy and accumulated accuracy with current logits and labels. It expects the output has logits and labels. @@ -187,7 +187,7 @@ class AccuracyMetric(Metric): # update correct = self.acc(logits, targets) - self.last_step_sum.fill_(targets.size(0)) + self.last_step_sum.fill_(batch_size) self.last_step_correct.fill_(correct) self.accumulated_sum += self.last_step_sum self.accumulated_correct += self.last_step_correct @@ -296,7 +296,8 @@ class AccuracyHook(MetricHook): def after_test_iter(self, trainer, logits, targets, *args): if self._is_stage_to_compute: - self.metric.update(logits, targets) + batch_size = trainer.schedule.batch_size + self.metric.update(logits, targets, batch_size) class ThroughputMetric(Metric): @@ -313,10 +314,8 @@ class ThroughputMetric(Metric): self.last_step_num_samples.zero_() self.last_step_used_time.zero_() - def update(self, tensor, time) -> None: - if isinstance(tensor, (list, tuple)): - tensor = tensor[0] - self.last_step_num_samples.fill_(tensor.size(0)) + def update(self, num_samples, time) -> None: + self.last_step_num_samples.fill_(num_samples) self.last_step_used_time.fill_(time) self.accumulated_num_samples += self.last_step_num_samples self.accumulated_used_time += self.last_step_used_time @@ -354,11 +353,11 @@ class ThroughputHook(MetricHook): def before_train_epoch(self, trainer): self.metric.reset() - def after_train_iter(self, trainer, logits, targets, *args): - self.metric.update(targets, trainer._timer.get_timer('Train-step').get_elapsed_time()) + def after_train_iter(self, trainer, *args): + self.metric.update(trainer.schedule.batch_size, trainer._timer.get_timer('Train-step').get_elapsed_time()) def before_test(self, trainer): self.metric.reset() - def after_test_iter(self, trainer, logits, targets, *args): - self.metric.update(targets, trainer._timer.get_timer('Test-step').get_elapsed_time()) + def after_test_iter(self, trainer, *args): + self.metric.update(trainer.schedule.batch_size, trainer._timer.get_timer('Test-step').get_elapsed_time()) diff --git a/colossalai/utils/__init__.py b/colossalai/utils/__init__.py index 7430ab100..b346f1a57 100644 --- a/colossalai/utils/__init__.py +++ b/colossalai/utils/__init__.py @@ -1,27 +1,19 @@ from .activation_checkpoint import checkpoint -from .common import (print_rank_0, sync_model_param_in_dp, is_dp_rank_0, - is_tp_rank_0, is_no_pp_or_last_stage, is_using_ddp, - is_using_pp, conditional_context, is_model_parallel_parameter, - clip_grad_norm_fp32, count_zeros_fp32, copy_tensor_parallel_attributes, - param_is_not_tensor_parallel_duplicate, switch_virtual_pipeline_parallel_rank) -from .cuda import get_current_device, synchronize, empty_cache, set_to_cuda +from .common import (clip_grad_norm_fp32, conditional_context, copy_tensor_parallel_attributes, count_zeros_fp32, + free_port, is_dp_rank_0, is_model_parallel_parameter, is_no_pp_or_last_stage, is_tp_rank_0, + is_using_ddp, is_using_pp, multi_tensor_applier, param_is_not_tensor_parallel_duplicate, + print_rank_0, switch_virtual_pipeline_parallel_rank, sync_model_param_in_dp) +from .cuda import empty_cache, get_current_device, set_to_cuda, synchronize +from .data_sampler import DataParallelSampler, get_dataloader +from .gradient_accumulation import accumulate_gradient from .memory import report_memory_usage from .timer import MultiTimer, Timer -from .multi_tensor_apply import multi_tensor_applier -from .gradient_accumulation import accumulate_gradient -from .data_sampler import DataParallelSampler, get_dataloader -__all__ = ['checkpoint', - 'print_rank_0', 'sync_model_param_in_dp', 'is_dp_rank_0', - 'is_tp_rank_0', 'is_no_pp_or_last_stage', 'is_using_ddp', - 'is_using_pp', 'conditional_context', 'is_model_parallel_parameter', - 'clip_grad_norm_fp32', 'count_zeros_fp32', 'copy_tensor_parallel_attributes', - 'param_is_not_tensor_parallel_duplicate', - 'get_current_device', 'synchronize', 'empty_cache', 'set_to_cuda', - 'report_memory_usage', - 'Timer', 'MultiTimer', - 'multi_tensor_applier', - 'accumulate_gradient', - 'DataParallelSampler', 'get_dataloader', - 'switch_virtual_pipeline_parallel_rank' - ] +__all__ = [ + 'checkpoint', 'free_port', 'print_rank_0', 'sync_model_param_in_dp', 'is_dp_rank_0', 'is_tp_rank_0', + 'is_no_pp_or_last_stage', 'is_using_ddp', 'is_using_pp', 'conditional_context', 'is_model_parallel_parameter', + 'clip_grad_norm_fp32', 'count_zeros_fp32', 'copy_tensor_parallel_attributes', + 'param_is_not_tensor_parallel_duplicate', 'get_current_device', 'synchronize', 'empty_cache', 'set_to_cuda', + 'report_memory_usage', 'Timer', 'MultiTimer', 'multi_tensor_applier', 'accumulate_gradient', 'DataParallelSampler', + 'get_dataloader', 'switch_virtual_pipeline_parallel_rank' +] diff --git a/colossalai/utils/common.py b/colossalai/utils/common.py index 610986d03..6e3318172 100644 --- a/colossalai/utils/common.py +++ b/colossalai/utils/common.py @@ -1,5 +1,7 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- +import random +import socket import torch from torch._six import inf @@ -9,16 +11,15 @@ try: except: pass -import torch.distributed as dist from contextlib import contextmanager -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc -from .multi_tensor_apply import multi_tensor_applier -from colossalai.constants import IS_TENSOR_PARALLEL, TENSOR_PARALLEL_ATTRIBUTES, NUM_PARTITIONS + import torch.distributed as dist +from colossalai.constants import IS_TENSOR_PARALLEL, NUM_PARTITIONS, TENSOR_PARALLEL_ATTRIBUTES from colossalai.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc +from .multi_tensor_apply import multi_tensor_applier + def print_rank_0(msg: str, logger=None): '''Print messages and save logs(optional). This is executed only if you are the rank-0 gpu. @@ -33,6 +34,18 @@ def print_rank_0(msg: str, logger=None): logger.info(msg) +def free_port(): + while True: + try: + sock = socket.socket() + port = random.randint(20000, 65000) + sock.bind(('localhost', port)) + sock.close() + return port + except Exception: + continue + + def sync_model_param_in_dp(model): '''Make sure data parameters are consistent during Data Parallel Mode diff --git a/model_zoo/vit/vit.py b/model_zoo/vit/vit.py index 4e3209f2c..450f334a4 100644 --- a/model_zoo/vit/vit.py +++ b/model_zoo/vit/vit.py @@ -3,9 +3,8 @@ from typing import Callable import torch from colossalai import nn as col_nn -from colossalai.context import ParallelMode, seed +from colossalai.nn.layer.utils import CheckpointModule from colossalai.registry import LAYERS, MODELS -from colossalai.utils import checkpoint from torch import dtype, nn __all__ = [ @@ -72,8 +71,7 @@ class ViTEmbedding(nn.Module): dropout: float, dtype: dtype = None, flatten: bool = True, - init_method: str = 'torch', - tensor_parallel: str = None): + init_method: str = 'torch'): super().__init__() self.patch_embed = col_nn.PatchEmbedding(img_size, patch_size, @@ -81,19 +79,17 @@ class ViTEmbedding(nn.Module): embedding_dim, dtype=dtype, flatten=flatten, - tensor_parallel=tensor_parallel, **_init_rules[init_method]['embed']) - self.dropout = nn.Dropout(dropout) + self.dropout = col_nn.Dropout(dropout) def forward(self, x): x = self.patch_embed(x) - with seed(ParallelMode.TENSOR): - x = self.dropout(x) + x = self.dropout(x) return x @LAYERS.register_module -class ViTSelfAttention(nn.Module): +class ViTSelfAttention(CheckpointModule): def __init__(self, dim: int, num_heads: int, @@ -102,27 +98,17 @@ class ViTSelfAttention(nn.Module): bias: bool = True, dtype: dtype = None, checkpoint: bool = False, - init_method: str = 'torch', - tensor_parallel: str = None): - super().__init__() + init_method: str = 'torch'): + super().__init__(checkpoint) self.attention_head_size = dim // num_heads - self.checkpoint = checkpoint - self.tensor_parallel = tensor_parallel - self.query_key_value = col_nn.Linear(dim, 3 * dim, dtype=dtype, bias=bias, - tensor_parallel='1d_col' if tensor_parallel == '1d' else tensor_parallel, **_init_rules[init_method]['transformer']) - self.attention_dropout = nn.Dropout(attention_dropout) - self.dense = col_nn.Linear(dim, - dim, - dtype=dtype, - bias=True, - tensor_parallel='1d_row' if tensor_parallel == '1d' else tensor_parallel, - **_init_rules[init_method]['transformer']) - self.dropout = nn.Dropout(dropout) + self.attention_dropout = col_nn.Dropout(attention_dropout) + self.dense = col_nn.Linear(dim, dim, dtype=dtype, bias=True, **_init_rules[init_method]['transformer']) + self.dropout = col_nn.Dropout(dropout) self.softmax = nn.Softmax(dim=-1) def _forward(self, x): @@ -138,8 +124,7 @@ class ViTSelfAttention(nn.Module): x = torch.matmul(q, k.transpose(-1, -2)) x = x / math.sqrt(self.attention_head_size) x = self.softmax(x) - with seed(ParallelMode.TENSOR): - x = self.attention_dropout(x) + x = self.attention_dropout(x) x = torch.matmul(x, v) x = x.transpose(1, 2) @@ -147,26 +132,13 @@ class ViTSelfAttention(nn.Module): x = x.reshape(new_context_layer_shape) x = self.dense(x) - if self.tensor_parallel == '1d': - x = self.dropout(x) - else: - with seed(ParallelMode.TENSOR): - x = self.dropout(x) + x = self.dropout(x) return x - def _checkpoint_forward(self, x): - return checkpoint(self._forward, x) - - def forward(self, x): - if self.checkpoint: - return self._checkpoint_forward(x) - else: - return self._forward(x) - @LAYERS.register_module -class ViTMLP(nn.Module): +class ViTMLP(CheckpointModule): def __init__(self, dim: int, mlp_ratio: int, @@ -175,50 +147,30 @@ class ViTMLP(nn.Module): dtype: dtype = None, bias: bool = True, checkpoint: bool = False, - init_method: str = 'torch', - tensor_parallel: str = None): - super().__init__() - self.checkpoint = checkpoint - self.tensor_parallel = tensor_parallel - + init_method: str = 'torch'): + super().__init__(checkpoint) self.dense_1 = col_nn.Linear(dim, mlp_ratio * dim, dtype=dtype, bias=bias, - tensor_parallel='1d_col' if tensor_parallel == '1d' else tensor_parallel, **_init_rules[init_method]['transformer']) self.activation = activation + self.dropout_1 = col_nn.Dropout(dropout) self.dense_2 = col_nn.Linear(mlp_ratio * dim, dim, dtype=dtype, bias=bias, - tensor_parallel='1d_row' if tensor_parallel == '1d' else tensor_parallel, **_init_rules[init_method]['transformer']) - self.dropout = nn.Dropout(dropout) + self.dropout_2 = col_nn.Dropout(dropout) def _forward(self, x): x = self.dense_1(x) x = self.activation(x) - with seed(ParallelMode.TENSOR): - x = self.dropout(x) + x = self.dropout_1(x) x = self.dense_2(x) - if self.tensor_parallel == '1d': - x = self.dropout(x) - else: - with seed(ParallelMode.TENSOR): - x = self.dropout(x) - + x = self.dropout_2(x) return x - def _checkpoint_forward(self, x): - return checkpoint(self._forward, x) - - def forward(self, x): - if self.checkpoint: - return self._checkpoint_forward(x) - else: - return self._forward(x) - @LAYERS.register_module class ViTHead(nn.Module): @@ -228,19 +180,14 @@ class ViTHead(nn.Module): representation_size: int = None, dtype: dtype = None, bias: bool = True, - init_method: str = 'torch', - tensor_parallel: str = None): + init_method: str = 'torch'): super().__init__() if representation_size: - tensor_parallel_kwargs = {'tensor_parallel': '1d_col' if tensor_parallel == '1d' else tensor_parallel} - if tensor_parallel == '1d': - tensor_parallel_kwargs['gather_output'] = True self.representation = col_nn.Linear(dim, representation_size, bias=bias, dtype=dtype, - **_init_rules[init_method]['head'], - **tensor_parallel_kwargs) + **_init_rules[init_method]['head']) else: self.representation = None representation_size = dim @@ -249,7 +196,6 @@ class ViTHead(nn.Module): num_classes, dtype=dtype, bias=bias, - tensor_parallel=tensor_parallel, **_init_rules[init_method]['head']) def forward(self, x): @@ -273,10 +219,9 @@ class ViTBlock(nn.Module): dtype: dtype = None, bias: bool = True, checkpoint: bool = False, - init_method: str = 'torch', - tensor_parallel: str = None): + init_method: str = 'torch'): super().__init__() - self.norm1 = col_nn.LayerNorm(normalized_shape=dim, eps=1e-6, dtype=dtype, tensor_parallel=tensor_parallel) + self.norm1 = col_nn.LayerNorm(normalized_shape=dim, eps=1e-6, dtype=dtype) self.attn = ViTSelfAttention(dim=dim, num_heads=num_heads, attention_dropout=attention_dropout, @@ -284,10 +229,9 @@ class ViTBlock(nn.Module): bias=bias, dtype=dtype, checkpoint=checkpoint, - init_method=init_method, - tensor_parallel=tensor_parallel) + init_method=init_method) self.drop_path = col_nn.DropPath(drop_path) if drop_path > 0. else nn.Identity() - self.norm2 = col_nn.LayerNorm(normalized_shape=dim, eps=1e-6, dtype=dtype, tensor_parallel=tensor_parallel) + self.norm2 = col_nn.LayerNorm(normalized_shape=dim, eps=1e-6, dtype=dtype) self.mlp = ViTMLP(dim=dim, mlp_ratio=mlp_ratio, activation=activation, @@ -295,8 +239,7 @@ class ViTBlock(nn.Module): dtype=dtype, bias=bias, checkpoint=checkpoint, - init_method=init_method, - tensor_parallel=tensor_parallel) + init_method=init_method) def forward(self, x): x = x + self.drop_path(self.attn(self.norm1(x))) @@ -323,20 +266,16 @@ class VisionTransformer(nn.Module): dtype: dtype = None, bias: bool = True, checkpoint: bool = False, - init_method: str = 'torch', - tensor_parallel: str = None): + init_method: str = 'torch'): super().__init__() - embed = ViTEmbedding( - img_size=img_size, - patch_size=patch_size, - in_chans=in_chans, - embedding_dim=dim, - dropout=dropout, - dtype=dtype, - init_method=init_method, - tensor_parallel=tensor_parallel, - ) + embed = ViTEmbedding(img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embedding_dim=dim, + dropout=dropout, + dtype=dtype, + init_method=init_method) # stochastic depth decay rule dpr = [x.item() for x in torch.linspace(0, drop_path, depth)] @@ -353,26 +292,17 @@ class VisionTransformer(nn.Module): bias=bias, checkpoint=checkpoint, init_method=init_method, - tensor_parallel=tensor_parallel, ) for i in range(depth) ] - norm = col_nn.LayerNorm( - normalized_shape=dim, - eps=1e-6, - dtype=dtype, - tensor_parallel=tensor_parallel, - ) + norm = col_nn.LayerNorm(normalized_shape=dim, eps=1e-6, dtype=dtype) - head = ViTHead( - dim=dim, - num_classes=num_classes, - representation_size=representation_size, - dtype=dtype, - bias=bias, - init_method=init_method, - tensor_parallel=tensor_parallel, - ) + head = ViTHead(dim=dim, + num_classes=num_classes, + representation_size=representation_size, + dtype=dtype, + bias=bias, + init_method=init_method) self.layers = nn.Sequential( embed, diff --git a/tests/test_comm/test_comm.py b/tests/test_comm/test_comm.py index e2f981af5..4316e1a56 100644 --- a/tests/test_comm/test_comm.py +++ b/tests/test_comm/test_comm.py @@ -1,4 +1,3 @@ -import time from functools import partial import pytest @@ -9,7 +8,7 @@ from colossalai.communication import all_gather, all_reduce, reduce_scatter from colossalai.context import ParallelMode from colossalai.core import global_context as gpc from colossalai.initialize import launch -from colossalai.utils import get_current_device +from colossalai.utils import free_port, get_current_device CONFIG = dict(parallel=dict(data=8, pipeline=1, tensor=dict(mode=None, size=1))) @@ -49,8 +48,8 @@ def check_all_reduce(): torch.cuda.synchronize() -def check_layer(rank, world_size): - launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=30010, backend='nccl') +def check_layer(rank, world_size, port): + launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') assert dist.get_rank() == gpc.get_global_rank() print('Rank {} / {}'.format(dist.get_rank(), dist.get_world_size())) @@ -66,7 +65,7 @@ def check_layer(rank, world_size): @pytest.mark.dist def test_comm(): world_size = 4 - run_func = partial(check_layer, world_size=world_size) + run_func = partial(check_layer, world_size=world_size, port=free_port()) mp.spawn(run_func, nprocs=world_size) diff --git a/tests/test_context/test_2d_init.py b/tests/test_context/test_2d_init.py index 3ad376750..22826bf38 100644 --- a/tests/test_context/test_2d_init.py +++ b/tests/test_context/test_2d_init.py @@ -1,15 +1,16 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- +from functools import partial +from pathlib import Path + import pytest import torch import torch.multiprocessing as mp - from colossalai import launch from colossalai.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc -from functools import partial -from pathlib import Path +from colossalai.utils import free_port CONFIG_PATH = Path(__file__).parent.joinpath('configs/parallel_2d_init.py').absolute() @@ -87,7 +88,7 @@ def test_2d_init(): test_fn = partial(init_2d, world_size=world_size, backend='gloo', - port='29900', + port=free_port(), host='localhost' ) mp.spawn(test_fn, nprocs=world_size) diff --git a/tests/test_context/test_2p5d_init.py b/tests/test_context/test_2p5d_init.py index 1ce5f8ff4..3668c701e 100644 --- a/tests/test_context/test_2p5d_init.py +++ b/tests/test_context/test_2p5d_init.py @@ -7,10 +7,10 @@ from pathlib import Path import pytest import torch import torch.multiprocessing as mp - from colossalai.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc from colossalai.initialize import launch +from colossalai.utils import free_port CONFIG_PATH = Path(__file__).parent.joinpath('configs/parallel_2p5d_init.py').absolute() @@ -111,7 +111,7 @@ def test_2halfd_init(): test_fn = partial(init_2halfd, world_size=world_size, backend='gloo', - port='29901', + port=free_port(), host='localhost' ) mp.spawn(test_fn, nprocs=world_size) diff --git a/tests/test_context/test_3d_init.py b/tests/test_context/test_3d_init.py index 5c66ab6a0..c9395f868 100644 --- a/tests/test_context/test_3d_init.py +++ b/tests/test_context/test_3d_init.py @@ -7,11 +7,10 @@ from pathlib import Path import pytest import torch import torch.multiprocessing as mp - - from colossalai.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc from colossalai.initialize import launch +from colossalai.utils import free_port CONFIG_PATH = Path(__file__).parent.joinpath('configs/parallel_3d_init.py').absolute() @@ -104,7 +103,7 @@ def test_3d_init(): test_fn = partial(init_3d, world_size=world_size, backend='gloo', - port='29902', + port=free_port(), host='localhost' ) mp.spawn(test_fn, nprocs=world_size) diff --git a/tests/test_data_pipeline_tensor_parallel/test_cifar_with_data_pipeline_tensor.py b/tests/test_data_pipeline_tensor_parallel/test_cifar_with_data_pipeline_tensor.py index 8fd8a6ae9..a472bf0ee 100644 --- a/tests/test_data_pipeline_tensor_parallel/test_cifar_with_data_pipeline_tensor.py +++ b/tests/test_data_pipeline_tensor_parallel/test_cifar_with_data_pipeline_tensor.py @@ -13,7 +13,7 @@ from colossalai.logging import get_dist_logger from colossalai.nn import Accuracy, LinearWarmupLR from colossalai.nn.loss import CrossEntropyLoss from colossalai.trainer import Trainer, hooks -from colossalai.utils import MultiTimer, get_dataloader +from colossalai.utils import MultiTimer, free_port, get_dataloader from colossalai.utils.gradient_accumulation import GradAccumLrSchedulerByStep from model_zoo.vit import vit_tiny_patch4_32 from torchvision import transforms @@ -27,12 +27,12 @@ CONFIG = dict(parallel=dict(pipeline=2, tensor=dict(size=2, mode='1d')), gradient_accumulation=2) -def run_trainer(rank, world_size): - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=30000, backend='nccl') +def run_trainer(rank, world_size, port): + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') logger = get_dist_logger() - model = vit_tiny_patch4_32(tensor_parallel='1d') + model = vit_tiny_patch4_32() pipe_model = build_pipeline_model(model.layers, num_chunks=1) # build dataloaders @@ -54,7 +54,7 @@ def run_trainer(rank, world_size): test_dataloader = get_dataloader(dataset=test_dataset, batch_size=BATCH_SIZE, pin_memory=True) # build criterion - criterion = CrossEntropyLoss(tensor_parallel='1d') + criterion = CrossEntropyLoss() # optimizer optimizer = torch.optim.Adam(pipe_model.parameters(), lr=0.001, weight_decay=0) @@ -78,7 +78,6 @@ def run_trainer(rank, world_size): hook_list = [ hooks.LossHook(), hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=False), - hooks.AccuracyHook(accuracy_func=Accuracy(tensor_parallel='1d')), hooks.LogMetricByEpochHook(logger), ] @@ -95,7 +94,7 @@ def run_trainer(rank, world_size): # @pytest.mark.skip("This test requires more than 8 GPUs, you should invoke this test script using test.sh provided manually") def test_hybrid_parallel(): world_size = 8 - run_func = partial(run_trainer, world_size=world_size) + run_func = partial(run_trainer, world_size=world_size, port=free_port()) mp.spawn(run_func, nprocs=world_size) diff --git a/tests/test_engine/test_engine/test_engine_apex_amp.py b/tests/test_engine/test_engine/test_engine_apex_amp.py index c8ee13de1..164ae54bb 100644 --- a/tests/test_engine/test_engine/test_engine_apex_amp.py +++ b/tests/test_engine/test_engine/test_engine_apex_amp.py @@ -1,25 +1,23 @@ # !/usr/bin/env python # -*- encoding: utf-8 -*- -import colossalai import os +from functools import partial +from pathlib import Path + +import colossalai import pytest import torch -import os.path as osp -from pathlib import Path -import torch.nn as nn import torch.multiprocessing as mp - -from torchvision import transforms -from torch.optim import Adam -from colossalai.core import global_context as gpc +import torch.nn as nn from colossalai.amp import AMP_TYPE +from colossalai.core import global_context as gpc from colossalai.logging import get_dist_logger -from colossalai.utils import report_memory_usage, get_dataloader -from torchvision.models import resnet18 +from colossalai.utils import free_port, get_dataloader, report_memory_usage +from torch.optim import Adam +from torchvision import transforms from torchvision.datasets import CIFAR10 -from functools import partial - +from torchvision.models import resnet18 # Config BATCH_SIZE = 128 @@ -38,14 +36,14 @@ CONFIG = dict( ) -def run_engine(rank, world_size): +def run_engine(rank, world_size, port): # init dist env colossalai.launch( config=CONFIG, rank=rank, world_size=world_size, host='localhost', - port=29910, + port=port, backend='nccl' ) @@ -104,7 +102,7 @@ def run_engine(rank, world_size): @pytest.mark.dist def test_engine(): world_size = 4 - run_func = partial(run_engine, world_size=world_size) + run_func = partial(run_engine, world_size=world_size, port=free_port()) mp.spawn(run_func, nprocs=world_size) diff --git a/tests/test_engine/test_engine/test_engine_naive_amp.py b/tests/test_engine/test_engine/test_engine_naive_amp.py index e60b0bbe9..95c620368 100644 --- a/tests/test_engine/test_engine/test_engine_naive_amp.py +++ b/tests/test_engine/test_engine/test_engine_naive_amp.py @@ -1,23 +1,20 @@ -import colossalai import os +from functools import partial +from pathlib import Path + +import colossalai import pytest import torch -import os.path as osp -from pathlib import Path -import torch.nn as nn import torch.multiprocessing as mp - -from torchvision import transforms -from torch.optim import Adam -from colossalai.core import global_context as gpc +import torch.nn as nn from colossalai.amp import AMP_TYPE +from colossalai.core import global_context as gpc from colossalai.logging import get_dist_logger -from colossalai.utils import report_memory_usage, get_dataloader -from colossalai.initialize import get_default_parser -from torchvision.models import resnet18 +from colossalai.utils import free_port, get_dataloader, report_memory_usage +from torch.optim import Adam +from torchvision import transforms from torchvision.datasets import CIFAR10 -from functools import partial - +from torchvision.models import resnet18 # Config BATCH_SIZE = 128 @@ -38,14 +35,14 @@ CONFIG = dict( ) -def run_engine(rank, world_size): +def run_engine(rank, world_size, port): # init dist env colossalai.launch( config=CONFIG, rank=rank, world_size=world_size, host='localhost', - port=29911, + port=port, backend='nccl' ) @@ -104,7 +101,7 @@ def run_engine(rank, world_size): @pytest.mark.dist def test_engine(): world_size = 4 - run_func = partial(run_engine, world_size=world_size) + run_func = partial(run_engine, world_size=world_size, port=free_port()) mp.spawn(run_func, nprocs=world_size) diff --git a/tests/test_engine/test_engine/test_engine_no_amp.py b/tests/test_engine/test_engine/test_engine_no_amp.py index 8bf0baaea..13668e251 100644 --- a/tests/test_engine/test_engine/test_engine_no_amp.py +++ b/tests/test_engine/test_engine/test_engine_no_amp.py @@ -1,23 +1,19 @@ -import colossalai import os +from functools import partial +from pathlib import Path + +import colossalai import pytest import torch -import os.path as osp -from pathlib import Path -import torch.nn as nn import torch.multiprocessing as mp - -from torchvision import transforms -from torch.optim import Adam +import torch.nn as nn from colossalai.core import global_context as gpc -from colossalai.amp import AMP_TYPE from colossalai.logging import get_dist_logger -from colossalai.utils import report_memory_usage, get_dataloader -from colossalai.initialize import get_default_parser -from torchvision.models import resnet18 +from colossalai.utils import free_port, get_dataloader, report_memory_usage +from torch.optim import Adam +from torchvision import transforms from torchvision.datasets import CIFAR10 -from functools import partial - +from torchvision.models import resnet18 # Config BATCH_SIZE = 128 @@ -35,14 +31,14 @@ CONFIG = dict( ) -def run_engine(rank, world_size): +def run_engine(rank, world_size, port): # init dist env colossalai.launch( config=CONFIG, rank=rank, world_size=world_size, host='localhost', - port=29912, + port=port, backend='nccl' ) @@ -101,7 +97,7 @@ def run_engine(rank, world_size): @pytest.mark.dist def test_engine(): world_size = 4 - run_func = partial(run_engine, world_size=world_size) + run_func = partial(run_engine, world_size=world_size, port=free_port()) mp.spawn(run_func, nprocs=world_size) diff --git a/tests/test_engine/test_engine/test_engine_torch_amp.py b/tests/test_engine/test_engine/test_engine_torch_amp.py index 289a8f1b6..435df81dc 100644 --- a/tests/test_engine/test_engine/test_engine_torch_amp.py +++ b/tests/test_engine/test_engine/test_engine_torch_amp.py @@ -1,23 +1,20 @@ -import colossalai import os +from functools import partial +from pathlib import Path + +import colossalai import pytest import torch -import os.path as osp -from pathlib import Path -import torch.nn as nn import torch.multiprocessing as mp - -from torchvision import transforms -from torch.optim import Adam -from colossalai.core import global_context as gpc +import torch.nn as nn from colossalai.amp import AMP_TYPE +from colossalai.core import global_context as gpc from colossalai.logging import get_dist_logger -from colossalai.utils import report_memory_usage, get_dataloader -from colossalai.initialize import get_default_parser -from torchvision.models import resnet18 +from colossalai.utils import free_port, get_dataloader, report_memory_usage +from torch.optim import Adam +from torchvision import transforms from torchvision.datasets import CIFAR10 -from functools import partial - +from torchvision.models import resnet18 # Config BATCH_SIZE = 128 @@ -36,14 +33,14 @@ CONFIG = dict( ) -def run_engine(rank, world_size): +def run_engine(rank, world_size, port): # init dist env colossalai.launch( config=CONFIG, rank=rank, world_size=world_size, host='localhost', - port=29913, + port=port, backend='nccl' ) @@ -102,7 +99,7 @@ def run_engine(rank, world_size): @pytest.mark.dist def test_engine(): world_size = 4 - run_func = partial(run_engine, world_size=world_size) + run_func = partial(run_engine, world_size=world_size, port=free_port()) mp.spawn(run_func, nprocs=world_size) diff --git a/tests/test_layers/test_1d/test_1d.py b/tests/test_layers/test_1d/test_1d.py index f0f977bea..f4120dc53 100644 --- a/tests/test_layers/test_1d/test_1d.py +++ b/tests/test_layers/test_1d/test_1d.py @@ -1,13 +1,15 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- +from functools import partial + import pytest import torch import torch.multiprocessing as mp - from colossalai.core import global_context as gpc from colossalai.initialize import launch -from functools import partial +from colossalai.utils import free_port + from checks_1d.check_layer_1d import * CONFIG = dict( @@ -21,12 +23,12 @@ CONFIG = dict( ) -def check_layer(rank, world_size): +def check_layer(rank, world_size, port): launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', - port=29920, + port=port, backend='nccl') check_linear_col() @@ -39,7 +41,7 @@ def check_layer(rank, world_size): @pytest.mark.dist def test_1d(): world_size = 4 - run_func = partial(check_layer, world_size=world_size) + run_func = partial(check_layer, world_size=world_size, port=free_port()) mp.spawn(run_func, nprocs=world_size) diff --git a/tests/test_layers/test_2d/test_2d.py b/tests/test_layers/test_2d/test_2d.py index 02b0a9cf1..83dc80a95 100644 --- a/tests/test_layers/test_2d/test_2d.py +++ b/tests/test_layers/test_2d/test_2d.py @@ -1,16 +1,17 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- +from functools import partial + import pytest import torch import torch.multiprocessing as mp - from colossalai.core import global_context as gpc from colossalai.initialize import launch +from colossalai.utils import free_port + from checks_2d.check_layer_2d import * from checks_2d.check_operation_2d import * -from functools import partial - CONFIG = dict( parallel=dict( @@ -34,12 +35,12 @@ def check_layer(): check_layernorm() check_classifier() -def check_layer_and_operation(rank, world_size): +def check_layer_and_operation(rank, world_size, port): launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', - port=29921, + port=port, backend='nccl') # check_operations() @@ -51,7 +52,7 @@ def check_layer_and_operation(rank, world_size): @pytest.mark.dist def test_2d(): world_size = 4 - run_func = partial(check_layer_and_operation, world_size=world_size) + run_func = partial(check_layer_and_operation, world_size=world_size, port=free_port()) mp.spawn(run_func, nprocs=world_size) diff --git a/tests/test_layers/test_2p5d/test_2p5d.py b/tests/test_layers/test_2p5d/test_2p5d.py index f3a180e4d..4de4015bf 100644 --- a/tests/test_layers/test_2p5d/test_2p5d.py +++ b/tests/test_layers/test_2p5d/test_2p5d.py @@ -1,13 +1,15 @@ +from functools import partial + import pytest import torch import torch.multiprocessing as mp - from colossalai.core import global_context as gpc from colossalai.initialize import launch -from checks_2p5d.check_layer_2p5d import check_linear, check_layernorm, check_classifier -from checks_2p5d.check_operation_2p5d import check_AB, check_ABT, check_ATB -from functools import partial +from colossalai.utils import free_port +from checks_2p5d.check_layer_2p5d import (check_classifier, check_layernorm, + check_linear) +from checks_2p5d.check_operation_2p5d import check_AB, check_ABT, check_ATB CONFIG = dict( parallel=dict( @@ -29,12 +31,12 @@ def check_layer(): check_classifier() -def check_layer_and_operation(rank, world_size): +def check_layer_and_operation(rank, world_size, port): launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', - port=29922, + port=port, backend='nccl') check_operations() @@ -46,7 +48,7 @@ def check_layer_and_operation(rank, world_size): @pytest.mark.dist def test_2p5d(): world_size = 4 - run_func = partial(check_layer_and_operation, world_size=world_size) + run_func = partial(check_layer_and_operation, world_size=world_size, port=free_port()) mp.spawn(run_func, nprocs=world_size) diff --git a/tests/test_layers/test_3d/test_3d.py b/tests/test_layers/test_3d/test_3d.py index 39e5d8e45..73bdbb5bd 100644 --- a/tests/test_layers/test_3d/test_3d.py +++ b/tests/test_layers/test_3d/test_3d.py @@ -7,6 +7,7 @@ import torch import torch.multiprocessing as mp from colossalai.core import global_context as gpc from colossalai.initialize import launch +from colossalai.utils import free_port from checks_3d.check_layer_3d import * @@ -27,8 +28,8 @@ def check_layer(): # check_loss() -def check_layer_and_operation(rank, world_size): - launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=29923, backend='nccl') +def check_layer_and_operation(rank, world_size, port): + launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') check_layer() gpc.destroy() torch.cuda.empty_cache() @@ -37,7 +38,7 @@ def check_layer_and_operation(rank, world_size): @pytest.mark.dist def test_3d(): world_size = 8 - run_func = partial(check_layer_and_operation, world_size=world_size) + run_func = partial(check_layer_and_operation, world_size=world_size, port=free_port()) mp.spawn(run_func, nprocs=world_size) diff --git a/tests/test_layers/test_sequence/test_sequence.py b/tests/test_layers/test_sequence/test_sequence.py index 56148e673..1ee104eb2 100644 --- a/tests/test_layers/test_sequence/test_sequence.py +++ b/tests/test_layers/test_sequence/test_sequence.py @@ -4,10 +4,11 @@ import pytest import torch import torch.multiprocessing as mp -from colossalai.initialize import launch, get_default_parser +from colossalai.initialize import launch from colossalai.logging import get_dist_logger from checks_seq.check_layer_seq import * from functools import partial +from colossalai.utils import free_port CONFIG = dict( @@ -22,13 +23,13 @@ def check_layer(): check_selfattention() -def run_check_sequence(rank, world_size): +def run_check_sequence(rank, world_size, port): # init dist launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', - port=29924, + port=port, backend='nccl') logger = get_dist_logger() logger.info('Distributed environment is initialzied.', ranks=[0]) @@ -41,7 +42,7 @@ def run_check_sequence(rank, world_size): @pytest.mark.dist def test_sequence(): world_size = 4 - run_func = partial(run_check_sequence, world_size=world_size) + run_func = partial(run_check_sequence, world_size=world_size, port=free_port()) mp.spawn(run_func, nprocs=world_size) diff --git a/tests/test_trainer/test_pipeline/resnet_config.py b/tests/test_trainer/test_pipeline/resnet_config.py index b0bcc3860..cbf7dd266 100644 --- a/tests/test_trainer/test_pipeline/resnet_config.py +++ b/tests/test_trainer/test_pipeline/resnet_config.py @@ -1,4 +1,5 @@ import os +import model from pathlib import Path BATCH_SIZE = 128 diff --git a/tests/test_trainer/test_pipeline/test_p2p.py b/tests/test_trainer/test_pipeline/test_p2p.py index ce60955ae..283f49fa0 100644 --- a/tests/test_trainer/test_pipeline/test_p2p.py +++ b/tests/test_trainer/test_pipeline/test_p2p.py @@ -1,11 +1,12 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- +from functools import partial + import pytest import torch import torch.distributed as dist import torch.multiprocessing as mp - from colossalai.communication import (recv_backward, recv_forward, recv_tensor_meta, send_backward, send_backward_recv_forward, send_forward, @@ -15,8 +16,7 @@ from colossalai.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc from colossalai.initialize import launch from colossalai.logging import get_dist_logger -from colossalai.utils import get_current_device -from functools import partial +from colossalai.utils import free_port, get_current_device BATCH_SIZE = 16 SEQ_LENGTH = 64 @@ -123,13 +123,13 @@ def check_comm(size, rank, prev_rank, next_rank, up_group, down_group, logger): check_forward_backward(tensor, grad, rank, logger) -def run_check(rank, world_size): +def run_check(rank, world_size, port): launch( config=CONFIG, rank=rank, world_size=world_size, host='localhost', - port=29932, + port=port, backend='nccl' ) logger = get_dist_logger() @@ -154,7 +154,7 @@ def run_check(rank, world_size): @pytest.mark.dist def test_p2p(): world_size = 4 - run_func = partial(run_check, world_size=world_size) + run_func = partial(run_check, world_size=world_size, port=free_port()) mp.spawn(run_func, nprocs=world_size) diff --git a/tests/test_trainer/test_pipeline/test_partition.py b/tests/test_trainer/test_pipeline/test_partition.py index 9f011c0e2..61e7e707b 100644 --- a/tests/test_trainer/test_pipeline/test_partition.py +++ b/tests/test_trainer/test_pipeline/test_partition.py @@ -3,25 +3,24 @@ import os.path as osp import pytest import torch import torch.multiprocessing as mp -from torch.utils.data import DataLoader from colossalai.builder.pipeline import build_pipeline_model_from_cfg from colossalai.core import global_context from colossalai.initialize import launch from colossalai.logging import get_dist_logger from functools import partial -import model +from colossalai.utils import free_port DIR_PATH = osp.dirname(osp.realpath(__file__)) CONFIG_PATH = osp.join(DIR_PATH, 'resnet_config.py') -def run_partition(rank, world_size): +def run_partition(rank, world_size, port): launch(config=CONFIG_PATH, rank=rank, world_size=world_size, host='localhost', - port=29933, + port=port, backend='nccl' ) logger = get_dist_logger() @@ -40,7 +39,7 @@ def run_partition(rank, world_size): @pytest.mark.dist def test_partition(): world_size = 4 - run_func = partial(run_partition, world_size=world_size) + run_func = partial(run_partition, world_size=world_size, port=free_port()) mp.spawn(run_func, nprocs=world_size) diff --git a/tests/test_trainer/test_pipeline/test_pipeline_schedule.py b/tests/test_trainer/test_pipeline/test_pipeline_schedule.py index be2f7ab30..d3c876c9c 100644 --- a/tests/test_trainer/test_pipeline/test_pipeline_schedule.py +++ b/tests/test_trainer/test_pipeline/test_pipeline_schedule.py @@ -1,26 +1,23 @@ # referenced from Megatron and used to testify communication -import colossalai import os import os.path as osp +from functools import partial +from pathlib import Path + +import colossalai import pytest import torch import torch.multiprocessing as mp -import model - from colossalai.builder import build_pipeline_model_from_cfg -from colossalai.communication import p2p as p2p_communication -from colossalai.communication.utils import send_tensor_meta, recv_tensor_meta -from colossalai.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc -from colossalai.initialize import launch -from colossalai.utils import print_rank_0, get_current_device, get_dataloader from colossalai.engine.schedule import PipelineSchedule -from torchvision.datasets import CIFAR10 +from colossalai.initialize import launch +from colossalai.utils import free_port, get_dataloader, print_rank_0 from torchvision import transforms -from pathlib import Path -from functools import partial +from torchvision.datasets import CIFAR10 +import model BATCH_SIZE = 32 NUM_MICRO = 8 @@ -30,12 +27,12 @@ DIR_PATH = osp.dirname(osp.realpath(__file__)) CONFIG_PATH = osp.join(DIR_PATH, './resnet_config.py') -def run_schedule(rank, world_size): +def run_schedule(rank, world_size, port): launch(config=CONFIG_PATH, rank=rank, world_size=world_size, host='localhost', - port=29934, + port=port, backend='nccl') # build model @@ -86,7 +83,7 @@ def run_schedule(rank, world_size): @pytest.mark.dist def test_pipeline_schedule(): world_size = 4 - run_func = partial(run_schedule, world_size=world_size) + run_func = partial(run_schedule, world_size=world_size, port=free_port()) mp.spawn(run_func, nprocs=world_size) diff --git a/tests/test_trainer/test_trainer_with_non_pipe_schedule.py b/tests/test_trainer/test_trainer_with_non_pipe_schedule.py index af4180ade..599efd883 100644 --- a/tests/test_trainer/test_trainer_with_non_pipe_schedule.py +++ b/tests/test_trainer/test_trainer_with_non_pipe_schedule.py @@ -11,7 +11,7 @@ from colossalai.amp.amp_type import AMP_TYPE from colossalai.core import global_context as gpc from colossalai.logging import get_dist_logger from colossalai.trainer import Trainer -from colossalai.utils import MultiTimer, get_dataloader +from colossalai.utils import MultiTimer, free_port, get_dataloader from torch.optim import Adam from torchvision import transforms from torchvision.datasets import CIFAR10 @@ -26,8 +26,8 @@ CONFIG = dict( fp16=dict(mode=AMP_TYPE.TORCH)) -def run_trainer_no_pipeline(rank, world_size): - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=29930, backend='nccl') +def run_trainer_no_pipeline(rank, world_size, port): + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') # build model model = resnet18(num_classes=10) @@ -88,7 +88,7 @@ def run_trainer_no_pipeline(rank, world_size): @pytest.mark.dist def test_trainer_no_pipeline(): world_size = 4 - run_func = partial(run_trainer_no_pipeline, world_size=world_size) + run_func = partial(run_trainer_no_pipeline, world_size=world_size, port=free_port()) mp.spawn(run_func, nprocs=world_size) diff --git a/tests/test_trainer/test_trainer_with_pipe_schedule.py b/tests/test_trainer/test_trainer_with_pipe_schedule.py index c6bb5ad15..8dffc3fc7 100644 --- a/tests/test_trainer/test_trainer_with_pipe_schedule.py +++ b/tests/test_trainer/test_trainer_with_pipe_schedule.py @@ -12,7 +12,7 @@ from colossalai.core import global_context as gpc from colossalai.engine.schedule import PipelineSchedule from colossalai.logging import get_dist_logger from colossalai.trainer import Trainer -from colossalai.utils import MultiTimer, get_dataloader +from colossalai.utils import MultiTimer, free_port, get_dataloader from torch.optim import Adam from torchvision import transforms from torchvision.datasets import CIFAR10 @@ -25,8 +25,8 @@ NUM_EPOCHS = 200 CONFIG = dict(parallel=dict(pipeline=2, ), ) -def run_trainer_with_pipeline(rank, world_size): - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=29931, backend='nccl') +def run_trainer_with_pipeline(rank, world_size, port): + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') # build model model = resnet18(num_classes=10) @@ -99,7 +99,7 @@ def run_trainer_with_pipeline(rank, world_size): @pytest.mark.dist def test_trainer_with_pipeline(): world_size = 4 - run_func = partial(run_trainer_with_pipeline, world_size=world_size) + run_func = partial(run_trainer_with_pipeline, world_size=world_size, port=free_port()) mp.spawn(run_func, nprocs=world_size) diff --git a/tests/test_utils/test_gradient_accumluation.py b/tests/test_utils/test_gradient_accumluation.py index 6a709f7db..c7471d77c 100644 --- a/tests/test_utils/test_gradient_accumluation.py +++ b/tests/test_utils/test_gradient_accumluation.py @@ -1,21 +1,19 @@ -import colossalai import os +from functools import partial +from pathlib import Path + +import colossalai import pytest import torch import torch.multiprocessing as mp import torch.nn as nn - -from functools import partial -from pathlib import Path -from torchvision import transforms -from torch.optim import Adam from colossalai.core import global_context as gpc from colossalai.logging import get_dist_logger -from colossalai.utils import report_memory_usage, get_dataloader -from colossalai.initialize import get_default_parser -from torchvision.models import resnet18 +from colossalai.utils import free_port, get_dataloader +from torch.optim import Adam +from torchvision import transforms from torchvision.datasets import CIFAR10 - +from torchvision.models import resnet18 # Config BATCH_SIZE = 16 @@ -32,7 +30,7 @@ CONFIG = dict( ) -def run_no_pipeline(rank, world_size): +def run_no_pipeline(rank, world_size, port): # init dist env colossalai.launch( @@ -40,7 +38,7 @@ def run_no_pipeline(rank, world_size): rank=rank, world_size=world_size, host='localhost', - port=29500, + port=port, backend='nccl' ) @@ -110,7 +108,7 @@ def run_no_pipeline(rank, world_size): @pytest.mark.dist def test_engine(): world_size = 4 - func = partial(run_no_pipeline, world_size=world_size) + func = partial(run_no_pipeline, world_size=world_size, port=free_port()) mp.spawn(func, nprocs=world_size) diff --git a/tests/test_zero_data_parallel/test_zero_level_2.py b/tests/test_zero_data_parallel/test_zero_level_2.py index 5da282255..9bdd1b124 100644 --- a/tests/test_zero_data_parallel/test_zero_level_2.py +++ b/tests/test_zero_data_parallel/test_zero_level_2.py @@ -2,18 +2,18 @@ # -*- encoding: utf-8 -*- import os -import pytest -import torch -import torch.multiprocessing as mp +from functools import partial from pathlib import Path import colossalai +import pytest +import torch +import torch.multiprocessing as mp from colossalai.core import global_context as gpc -from colossalai.utils import get_dataloader +from colossalai.utils import free_port, get_dataloader from torchvision import transforms -from torchvision.models import resnet18 from torchvision.datasets import CIFAR10 -from functools import partial +from torchvision.models import resnet18 BATCH_SIZE = 16 IMG_SIZE = 224 @@ -34,12 +34,12 @@ CONFIG = dict( ) -def run_dist(rank, world_size): +def run_dist(rank, world_size, port): colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', - port=29940, + port=port, backend='nccl') # build model @@ -94,7 +94,7 @@ def run_dist(rank, world_size): @pytest.mark.dist def test_zero_level_2(): world_size = 4 - run_func = partial(run_dist, world_size=world_size) + run_func = partial(run_dist, world_size=world_size, port=free_port()) mp.spawn(run_func, nprocs=world_size) diff --git a/tests/test_zero_data_parallel/test_zero_level_3.py b/tests/test_zero_data_parallel/test_zero_level_3.py index f1fe45b2b..2655210db 100644 --- a/tests/test_zero_data_parallel/test_zero_level_3.py +++ b/tests/test_zero_data_parallel/test_zero_level_3.py @@ -2,18 +2,18 @@ # -*- encoding: utf-8 -*- import os -import pytest -import torch -import torch.multiprocessing as mp +from functools import partial from pathlib import Path import colossalai +import pytest +import torch +import torch.multiprocessing as mp from colossalai.core import global_context as gpc -from colossalai.utils import get_dataloader +from colossalai.utils import free_port, get_dataloader from torchvision import transforms -from torchvision.models import resnet18 from torchvision.datasets import CIFAR10 -from functools import partial +from torchvision.models import resnet18 BATCH_SIZE = 16 IMG_SIZE = 224 @@ -46,12 +46,12 @@ CONFIG = dict( ) -def run_dist(rank, world_size): +def run_dist(rank, world_size, port): colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', - port=29941, + port=port, backend='nccl') # build model @@ -106,7 +106,7 @@ def run_dist(rank, world_size): @pytest.mark.dist def test_zero_level_3(): world_size = 4 - run_func = partial(run_dist, world_size=world_size) + run_func = partial(run_dist, world_size=world_size, port=free_port()) mp.spawn(run_func, nprocs=world_size) diff --git a/tests/test_zero_tensor_parallel/test_vit_2d_level_2.py b/tests/test_zero_tensor_parallel/test_vit_2d_level_2.py index 58c1e98b9..9d215f5ae 100644 --- a/tests/test_zero_tensor_parallel/test_vit_2d_level_2.py +++ b/tests/test_zero_tensor_parallel/test_vit_2d_level_2.py @@ -13,7 +13,7 @@ import torch.multiprocessing as mp from colossalai.core import global_context as gpc from colossalai.logging import get_dist_logger from colossalai.nn import CrossEntropyLoss -from colossalai.utils import get_dataloader +from colossalai.utils import free_port, get_dataloader from model_zoo.vit import vit_lite_depth7_patch4_32 from torchvision import transforms from torchvision.datasets import CIFAR10 @@ -40,11 +40,11 @@ def train_epoch(engine, train_dataloader): return avg_loss -def run_2d_parallel_vision_transformer_level_2(rank, world_size): - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=29950, backend='nccl') +def run_2d_parallel_vision_transformer_level_2(rank, world_size, port): + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') # build model - model = vit_lite_depth7_patch4_32(tensor_parallel='2d') + model = vit_lite_depth7_patch4_32() # build dataloader# build dataloaders train_dataset = CIFAR10(root=Path(os.environ['DATA']), @@ -62,7 +62,7 @@ def run_2d_parallel_vision_transformer_level_2(rank, world_size): # build optimizer and loss optimizer = torch.optim.Adam(model.parameters(), lr=0.001) - criterion = CrossEntropyLoss(tensor_parallel='2d') + criterion = CrossEntropyLoss() engine, train_dataloader, *args = colossalai.initialize(model=model, optimizer=optimizer, @@ -90,7 +90,7 @@ def run_2d_parallel_vision_transformer_level_2(rank, world_size): @pytest.mark.dist def test_2d_vit_zero_level_2(): world_size = 8 - run_func = partial(run_2d_parallel_vision_transformer_level_2, world_size=world_size) + run_func = partial(run_2d_parallel_vision_transformer_level_2, world_size=world_size, port=free_port()) mp.spawn(run_func, nprocs=world_size) diff --git a/tests/test_zero_tensor_parallel/test_vit_2d_level_3.py b/tests/test_zero_tensor_parallel/test_vit_2d_level_3.py index 0b08a58f2..149fefb72 100644 --- a/tests/test_zero_tensor_parallel/test_vit_2d_level_3.py +++ b/tests/test_zero_tensor_parallel/test_vit_2d_level_3.py @@ -13,7 +13,7 @@ import torch.multiprocessing as mp from colossalai.core import global_context as gpc from colossalai.logging import get_dist_logger from colossalai.nn import CrossEntropyLoss -from colossalai.utils import get_dataloader +from colossalai.utils import free_port, get_dataloader from model_zoo.vit import vit_lite_depth7_patch4_32 from torchvision import transforms from torchvision.datasets import CIFAR10 @@ -40,11 +40,11 @@ def train_epoch(engine, train_dataloader): return avg_loss -def run_2d_parallel_vision_transformer_level_3(rank, world_size): - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=29951, backend='nccl') +def run_2d_parallel_vision_transformer_level_3(rank, world_size, port): + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') # build model - model = vit_lite_depth7_patch4_32(tensor_parallel='2d') + model = vit_lite_depth7_patch4_32() # build dataloader# build dataloaders train_dataset = CIFAR10(root=Path(os.environ['DATA']), @@ -62,7 +62,7 @@ def run_2d_parallel_vision_transformer_level_3(rank, world_size): # build optimizer and loss optimizer = torch.optim.Adam(model.parameters(), lr=0.001) - criterion = CrossEntropyLoss(tensor_parallel='2d') + criterion = CrossEntropyLoss() engine, train_dataloader, *args = colossalai.initialize(model=model, optimizer=optimizer, @@ -91,7 +91,7 @@ def run_2d_parallel_vision_transformer_level_3(rank, world_size): @pytest.mark.skip("Level 3 has unknown bug so skip this test for now") def test_3d_vit_zero_level_3(): world_size = 8 - run_func = partial(run_2d_parallel_vision_transformer_level_3, world_size=world_size) + run_func = partial(run_2d_parallel_vision_transformer_level_3, world_size=world_size, port=free_port()) mp.spawn(run_func, nprocs=world_size)