Hotfix/Colossalai layers (#92)

* optimized 1d layer apis; reorganized nn.layer modules; fixed tests

* fixed 2.5d runtime issue

* reworked split batch, now called in trainer.schedule.load_batch

Co-authored-by: BoxiangW <45734921+BoxiangW@users.noreply.github.com>
アマデウス 2021-12-29 23:32:10 +08:00 committed by GitHub
parent 0fedef4f3c
commit 01a80cd86d
71 changed files with 1033 additions and 773 deletions
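In practice the change means the tensor-parallel mode is declared once in the launch config and picked up globally; layers, losses and metrics no longer take a tensor_parallel argument. A minimal config sketch (the parallel dict is what ParallelContext reads; the values follow the CIFAR example below):

BATCH_SIZE = 512
LEARNING_RATE = 2e-3
WEIGHT_DECAY = 3e-2
TENSOR_PARALLEL_SIZE = 2
TENSOR_PARALLEL_MODE = '1d'
NUM_EPOCHS = 200

# ParallelContext reads parallel['tensor']['mode'], validates it against
# ALLOWED_MODES and exports it via os.environ, so that colossalai.nn layers
# can look it up without an explicit argument.
parallel = dict(
    tensor=dict(size=TENSOR_PARALLEL_SIZE, mode=TENSOR_PARALLEL_MODE),
)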

View File

@@ -2,7 +2,7 @@ BATCH_SIZE = 512
 LEARNING_RATE = 2e-3
 WEIGHT_DECAY = 3e-2
-TENSOR_PARALLEL_SIZE = 4
+TENSOR_PARALLEL_SIZE = 2
 TENSOR_PARALLEL_MODE = '1d'
 NUM_EPOCHS = 200

View File

@@ -72,13 +72,11 @@ def train_cifar():
 os.mkdir(log_path)
 logger.log_to_file(log_path)
-tp = gpc.config.parallel.tensor.mode
-model = vit_lite_depth7_patch4_32(tensor_parallel=tp)
+model = vit_lite_depth7_patch4_32()
 train_dataloader, test_dataloader = build_cifar(gpc.config.BATCH_SIZE // gpc.data_parallel_size)
-criterion = CrossEntropyLoss(label_smoothing=0.1, tensor_parallel=tp)
+criterion = CrossEntropyLoss(label_smoothing=0.1)
 optimizer = torch.optim.AdamW(model.parameters(), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY)
@@ -107,7 +105,7 @@ def train_cifar():
 LogMetricByStepHook(),
 # LogTimingByEpochHook(timer=timer, logger=logger),
 # LogMemoryByEpochHook(logger=logger),
-AccuracyHook(accuracy_func=Accuracy(tensor_parallel=tp)),
+AccuracyHook(accuracy_func=Accuracy()),
 LossHook(),
 ThroughputHook(),
 LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=False)

View File

@@ -4,7 +4,7 @@ TOTAL_BATCH_SIZE = 4096
 LEARNING_RATE = 3e-3
 WEIGHT_DECAY = 0.3
-TENSOR_PARALLEL_SIZE = 4
+TENSOR_PARALLEL_SIZE = 2
 TENSOR_PARALLEL_MODE = '1d'
 NUM_EPOCHS = 300

View File

@@ -159,14 +159,12 @@ def train_imagenet():
 os.mkdir(log_path)
 logger.log_to_file(log_path)
-tp = gpc.config.parallel.tensor.mode
-model = vit_small_patch16_224(tensor_parallel=tp, num_classes=100, init_method='jax')
+model = vit_small_patch16_224(num_classes=100, init_method='jax')
 train_dataloader = build_dali_train(gpc.config.BATCH_SIZE // gpc.data_parallel_size)
 test_dataloader = build_dali_test(gpc.config.BATCH_SIZE // gpc.data_parallel_size)
-criterion = CrossEntropyLoss(label_smoothing=0.1, tensor_parallel=tp)
+criterion = CrossEntropyLoss(label_smoothing=0.1)
 optimizer = torch.optim.AdamW(model.parameters(), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY)
@@ -192,7 +190,7 @@ def train_imagenet():
 LogMetricByStepHook(),
 # LogTimingByEpochHook(timer=timer, logger=logger),
 # LogMemoryByEpochHook(logger=logger),
-AccuracyHook(accuracy_func=Accuracy(tensor_parallel=tp)),
+AccuracyHook(accuracy_func=Accuracy()),
 LossHook(),
 ThroughputHook(),
 LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=True)

View File

@@ -4,7 +4,7 @@ TOTAL_BATCH_SIZE = 4096
 LEARNING_RATE = 3e-3
 WEIGHT_DECAY = 0.3
-TENSOR_PARALLEL_SIZE = 4
+TENSOR_PARALLEL_SIZE = 2
 TENSOR_PARALLEL_MODE = '1d'
 NUM_EPOCHS = 300

View File

@@ -159,14 +159,12 @@ def train_imagenet():
 os.mkdir(log_path)
 logger.log_to_file(log_path)
-tp = gpc.config.parallel.tensor.mode
-model = vit_small_patch16_224(tensor_parallel=tp, num_classes=1000, init_method='jax')
+model = vit_small_patch16_224(num_classes=1000, init_method='jax')
 train_dataloader = build_dali_train(gpc.config.BATCH_SIZE // gpc.data_parallel_size)
 test_dataloader = build_dali_test(gpc.config.BATCH_SIZE // gpc.data_parallel_size)
-criterion = CrossEntropyLoss(label_smoothing=0.1, tensor_parallel=tp)
+criterion = CrossEntropyLoss(label_smoothing=0.1)
 optimizer = torch.optim.AdamW(model.parameters(), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY)
@@ -192,7 +190,7 @@ def train_imagenet():
 LogMetricByStepHook(),
 # LogTimingByEpochHook(timer=timer, logger=logger),
 # LogMemoryByEpochHook(logger=logger),
-AccuracyHook(accuracy_func=Accuracy(tensor_parallel=tp)),
+AccuracyHook(accuracy_func=Accuracy()),
 LossHook(),
 ThroughputHook(),
 LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=True)

View File

@@ -2,6 +2,7 @@
 # -*- encoding: utf-8 -*-
 ALLOWED_MODES = [None, '1d', '2d', '2.5d', '3d', 'sequence']
+TENSOR_PARALLEL_MODE = 'tensor_parallel_mode'
 # intializer
 INITIALIZER_MAPPING = {
@@ -16,6 +17,9 @@ INITIALIZER_MAPPING = {
 'sequence': 'Initializer_Sequence'
 }
+# 1D parallel
+PARALLEL_INPUT_1D = 'parallel_input_1d'
 # 2D paralllel
 SUMMA_DIM = 'SUMMA_DIM'
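The new constant is the key of an environment variable: ParallelContext writes the configured mode there (see the next file), and the reworked layer wrappers read it back through a small helper in colossalai.nn.layer.utils. That helper is not part of this diff; a sketch of what it presumably looks like:

import os

from colossalai.constants import TENSOR_PARALLEL_MODE   # == 'tensor_parallel_mode'

def get_tensor_parallel_mode() -> str:
    # returns the string written by ParallelContext:
    # 'None', '1d', '2d', '2.5d', '3d' or 'sequence'
    return os.environ[TENSOR_PARALLEL_MODE]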

View File

@@ -1,17 +1,18 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
+import os
 import random
 from typing import Union
 import numpy as np
 import torch
 import torch.distributed as dist
-from colossalai.constants import ALLOWED_MODES, INITIALIZER_MAPPING
+from colossalai.constants import ALLOWED_MODES, INITIALIZER_MAPPING, TENSOR_PARALLEL_MODE
 from colossalai.context.config import Config
 from colossalai.logging import get_dist_logger
 from colossalai.registry import DIST_GROUP_INITIALIZER
 from .parallel_mode import ParallelMode
 from .random import add_seed, get_seeds, set_mode
@@ -386,6 +387,7 @@ class ParallelContext:
 if parallel_config is not None and 'tensor' in parallel_config and 'mode' in parallel_config['tensor']:
 tensor_parallel_mode = parallel_config['tensor']['mode']
 assert tensor_parallel_mode in ALLOWED_MODES, f"mode in the parallel config must be set to one of {ALLOWED_MODES}"
+os.environ[TENSOR_PARALLEL_MODE] = str(tensor_parallel_mode)
 self.check_sanity()
 pg_init = []

View File

@@ -1,12 +1,13 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
+import os
 import torch.distributed as dist
 from colossalai.context import Config
 from colossalai.registry import DIST_GROUP_INITIALIZER
 from .process_group_initializer import ProcessGroupInitializer
 from ..parallel_mode import ParallelMode
+from colossalai.constants import PARALLEL_INPUT_1D
 @DIST_GROUP_INITIALIZER.register_module
@@ -29,6 +30,7 @@ class Initializer_1D(ProcessGroupInitializer):
 process_group = None
 group_world_size = None
 mode = ParallelMode.PARALLEL_1D
+os.environ[PARALLEL_INPUT_1D] = ''
 for i in range(self.num_group):
 ranks = [i * self.tensor_parallel_size + j for j in range(self.tensor_parallel_size)]

View File

@@ -10,7 +10,7 @@ from typing import Iterable, Union, List, Callable
 from .._base_engine import Engine
 from colossalai.logging import get_dist_logger
 from colossalai.utils import get_current_device
+from colossalai.nn.layer import split_batch
 class BaseSchedule(ABC):
 """A basic helper class to control the process of training or evaluation.
@@ -59,7 +59,11 @@ class BaseSchedule(ABC):
 else:
 data, label = batch_data
-data, label = self._to_list(data), self._to_list(label)
+if isinstance(label, (tuple, list)):
+self.batch_size = label[0].size(0)
+else:
+self.batch_size = label.size(0)
+data, label = self._to_list(split_batch(data)), self._to_list(split_batch(label))
 return self._move_to_device(data), self._move_to_device(label)
 def pre_processing(self, engine: Engine):
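A worked example of the resulting batch accounting (hypothetical sizes; the 2D case assumes a 2x2 SUMMA mesh, i.e. TENSOR_PARALLEL_SIZE = 4):

BATCH_SIZE = 512                                  # global batch from the config
data_parallel_size = 2
per_dp_rank = BATCH_SIZE // data_parallel_size    # 256, what the dataloader yields

# load_batch() records self.batch_size = 256 from the un-split labels, then calls
# split_batch(): a no-op for None/'1d'/'sequence', while '2d' chunks the batch
# across the 2 column ranks of the mesh.
per_device_2d = per_dp_rank // 2                  # 128 samples actually seen per device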

View File

@@ -1,3 +1,9 @@
 from .colossalai_layer import *
-from .fused_bias_gelu import bias_gelu_impl
+from .parallel_1d import *
+from .parallel_2d import *
+from .parallel_2p5d import *
+from .parallel_3d import *
+from .parallel_sequence import *
+from .utils import *
+from .vanilla import *
 from .wrapper import *

View File

@@ -1,231 +0,0 @@
import math
from typing import Callable, Optional
from colossalai.utils import get_current_device
from torch import dtype, nn
from torch.nn.modules.activation import *
from torch.nn.modules.adaptive import *
from torch.nn.modules.batchnorm import *
from torch.nn.modules.channelshuffle import *
from torch.nn.modules.conv import *
from torch.nn.modules.distance import *
from torch.nn.modules.dropout import *
from torch.nn.modules.flatten import *
from torch.nn.modules.fold import *
from torch.nn.modules.instancenorm import *
from torch.nn.modules.linear import *
from torch.nn.modules.normalization import *
from torch.nn.modules.padding import *
from torch.nn.modules.pixelshuffle import *
from torch.nn.modules.pooling import *
from torch.nn.modules.rnn import *
from torch.nn.modules.sparse import *
from torch.nn.modules.transformer import *
from torch.nn.modules.upsampling import *
from .. import init as init
from .vanilla import *
from .parallel_1d import *
from .parallel_2d import *
from .parallel_2p5d import *
from .parallel_3d import *
from .parallel_sequence import *
_parallel_linear = {'1d_col': Linear1D_Col, '1d_row': Linear1D_Row, '2d': Linear2D, '2.5d': Linear2p5D, '3d': Linear3D}
_parallel_classifier = {
None: VanillaClassifier,
'1d': VanillaClassifier,
'2d': Classifier2D,
'2.5d': Classifier2p5D,
'3d': Classifier3D
}
_parallel_layernorm = {'2d': LayerNorm2D, '2.5d': LayerNorm2p5D, '3d': LayerNorm3D}
_parallel_embedding = {'3d': Embedding3D}
_parallel_patchembedding = {
None: VanillaPatchEmbedding,
'1d': VanillaPatchEmbedding,
'2d': PatchEmbedding2D,
'2.5d': PatchEmbedding2p5D,
'3d': PatchEmbedding3D
}
class Linear(nn.Module):
def __init__(self,
in_features: int,
out_features: int,
bias: bool = True,
dtype: dtype = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
tensor_parallel: Optional[str] = None,
**kwargs) -> None:
super().__init__()
if tensor_parallel is None:
self.layer = nn.Linear(in_features, out_features, bias=bias, device=get_current_device(), dtype=dtype)
weight_initializer(self.layer.weight, fan_in=in_features, fan_out=out_features)
if bias:
bias_initializer(self.layer.bias, fan_in=in_features)
else:
self.layer = _parallel_linear[tensor_parallel](
in_features,
out_features,
bias=bias,
dtype=dtype,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer,
**kwargs,
)
@property
def weight(self):
return self.layer.weight
@property
def bias(self):
return self.layer.bias
def forward(self, *args):
return self.layer(*args)
class LayerNorm(nn.Module):
def __init__(self, normalized_shape: int, eps=1e-05, dtype=None, tensor_parallel: Optional[str] = None) -> None:
super().__init__()
if tensor_parallel in [None, '1d']:
self.norm = nn.LayerNorm(normalized_shape, eps=eps, device=get_current_device(), dtype=dtype)
else:
self.norm = _parallel_layernorm[tensor_parallel](normalized_shape, eps=eps, dtype=dtype)
@property
def weight(self):
return self.norm.weight
@property
def bias(self):
return self.norm.bias
def forward(self, *args):
return self.norm(*args)
class Embedding(nn.Module):
def __init__(self,
num_embeddings: int,
embedding_dim: int,
padding_idx: int = None,
dtype: dtype = None,
weight_initializer: Callable = init.normal_(),
tensor_parallel: Optional[str] = None,
*args,
**kwargs) -> None:
super().__init__()
if tensor_parallel in [None, '1d']:
self.embed = nn.Embedding(num_embeddings,
embedding_dim,
padding_idx=padding_idx,
device=get_current_device(),
dtype=dtype,
*args,
**kwargs)
weight_initializer(self.embed.weight, fan_in=num_embeddings, fan_out=embedding_dim)
else:
self.embed = _parallel_embedding[tensor_parallel](
num_embeddings,
embedding_dim,
padding_idx=padding_idx,
dtype=dtype,
weight_initializer=weight_initializer,
*args,
**kwargs,
)
@property
def weight(self):
return self.embed.weight
def forward(self, *args):
return self.embed(*args)
class PatchEmbedding(nn.Module):
def __init__(self,
img_size: int,
patch_size: int,
in_chans: int,
embed_size: int,
dtype: dtype = None,
flatten: bool = True,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
position_embed_initializer: Callable = init.zeros_(),
tensor_parallel: Optional[str] = None) -> None:
super().__init__()
self.embed = _parallel_patchembedding[tensor_parallel](
img_size,
patch_size,
in_chans,
embed_size,
dtype=dtype,
flatten=flatten,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer,
position_embed_initializer=position_embed_initializer,
)
@property
def weight(self):
return self.embed.weight
@property
def bias(self):
return self.embed.bias
@property
def pos_embed(self):
return self.embed.pos_embed
@property
def cls_token(self):
return self.embed.cls_token
def forward(self, *args):
return self.embed(*args)
class Classifier(nn.Module):
def __init__(self,
in_features: int,
num_classes: int,
weight: nn.Parameter = None,
bias: bool = True,
dtype: dtype = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
tensor_parallel: Optional[str] = None) -> None:
super().__init__()
self.layer = _parallel_classifier[tensor_parallel](
in_features,
num_classes,
weight=weight,
bias=bias,
dtype=dtype,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer,
)
@property
def weight(self):
return self.layer.weight
@property
def bias(self):
return self.layer.bias
def forward(self, *args):
return self.layer(*args)

View File

@@ -0,0 +1,7 @@
from ._utils import split_batch
from .dropout import Dropout
from .embedding import Embedding, PatchEmbedding
from .linear import Classifier, Linear
from .normalization import LayerNorm
__all__ = ['Linear', 'Classifier', 'Embedding', 'PatchEmbedding', 'LayerNorm', 'Dropout', 'split_batch']

View File

@@ -0,0 +1,19 @@
from torch import Tensor
from ..parallel_2d._operation import split_tensor_2d
from ..parallel_2p5d._operation import split_tensor_2p5d
from ..parallel_3d._operation import split_tensor_3d
from ..utils import get_tensor_parallel_mode
_parallel_split_batch = {'2d': split_tensor_2d, '2.5d': split_tensor_2p5d, '3d': split_tensor_3d}
def split_batch(input_) -> Tensor:
tensor_parallel_mode = get_tensor_parallel_mode()
if tensor_parallel_mode in _parallel_split_batch:
if isinstance(input_, (tuple, list)):
return tuple(map(_parallel_split_batch[tensor_parallel_mode], input_))
else:
return _parallel_split_batch[tensor_parallel_mode](input_)
else:
return input_
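split_batch is what BaseSchedule.load_batch now calls on both data and labels. A usage sketch (assumes an initialised context; under '2d' with a 2x2 mesh the column group holds 2 ranks):

import torch
from colossalai.nn.layer import split_batch

images = torch.randn(64, 3, 32, 32)      # full batch as produced by the dataloader
labels = torch.randint(0, 10, (64,))

images, labels = split_batch(images), split_batch(labels)
# mode '2d', 2 column ranks: images.shape[0] == 32 on every rank
# mode None, '1d' or 'sequence': both tensors come back unchanged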

View File

@@ -0,0 +1,23 @@
from contextlib import nullcontext
import torch.nn as nn
from colossalai.context import ParallelMode, seed
from colossalai.utils import conditional_context
from ..parallel_1d import *
from ..utils import get_tensor_parallel_mode
class Dropout(nn.Module):
def __init__(self, p: float = 0.5, inplace: bool = False) -> None:
super().__init__()
self.tensor_parallel = get_tensor_parallel_mode()
if self.tensor_parallel == '1d':
self.drop = Dropout1D(p, inplace)
else:
self.drop = nn.Dropout(p, inplace)
def forward(self, *args):
cm = nullcontext() if self.tensor_parallel in ['None', '1d'] else seed(ParallelMode.TENSOR)
with cm:
return self.drop(*args)
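Usage is identical to torch.nn.Dropout; the only mode-dependent part is which RNG state the mask is drawn from. A short sketch (assumes an initialised run so that get_tensor_parallel_mode() can resolve the mode):

import torch
from colossalai.nn.layer import Dropout

drop = Dropout(p=0.1)
out = drop(torch.randn(4, 8))
# 'None' / '1d': plain nn.Dropout or Dropout1D under the default RNG state;
# '2d' / '2.5d' / '3d': the mask is drawn inside seed(ParallelMode.TENSOR).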

View File

@@ -0,0 +1,107 @@
import math
from typing import Callable, Optional
from colossalai.utils import get_current_device
from torch import dtype, nn
from ... import init as init
from ..parallel_1d import *
from ..parallel_2d import *
from ..parallel_2p5d import *
from ..parallel_3d import *
from ..utils import get_tensor_parallel_mode
from ..vanilla import *
_parallel_embedding = {'1d': Embedding1D, '2d': Embedding2D, '2.5d': Embedding2p5D, '3d': Embedding3D}
_parallel_patchembedding = {
'None': VanillaPatchEmbedding,
'1d': VanillaPatchEmbedding,
'2d': PatchEmbedding2D,
'2.5d': PatchEmbedding2p5D,
'3d': PatchEmbedding3D
}
class Embedding(nn.Module):
def __init__(self,
num_embeddings: int,
embedding_dim: int,
padding_idx: int = None,
dtype: dtype = None,
weight_initializer: Callable = init.normal_(),
*args,
**kwargs) -> None:
super().__init__()
tensor_parallel = get_tensor_parallel_mode()
if tensor_parallel == 'None':
self.embed = nn.Embedding(num_embeddings,
embedding_dim,
padding_idx=padding_idx,
device=get_current_device(),
dtype=dtype,
*args,
**kwargs)
weight_initializer(self.embed.weight, fan_in=num_embeddings, fan_out=embedding_dim)
else:
self.embed = _parallel_embedding[tensor_parallel](
num_embeddings,
embedding_dim,
padding_idx=padding_idx,
dtype=dtype,
weight_initializer=weight_initializer,
*args,
**kwargs,
)
@property
def weight(self):
return self.embed.weight
def forward(self, *args):
return self.embed(*args)
class PatchEmbedding(nn.Module):
def __init__(self,
img_size: int,
patch_size: int,
in_chans: int,
embed_size: int,
dtype: dtype = None,
flatten: bool = True,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
position_embed_initializer: Callable = init.zeros_()) -> None:
super().__init__()
tensor_parallel = get_tensor_parallel_mode()
self.embed = _parallel_patchembedding[tensor_parallel](
img_size,
patch_size,
in_chans,
embed_size,
dtype=dtype,
flatten=flatten,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer,
position_embed_initializer=position_embed_initializer,
)
@property
def weight(self):
return self.embed.weight
@property
def bias(self):
return self.embed.bias
@property
def pos_embed(self):
return self.embed.pos_embed
@property
def cls_token(self):
return self.embed.cls_token
def forward(self, *args):
return self.embed(*args)
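Construction sketch for the two embedding wrappers (assumes an initialised context; argument names as in the constructors above):

from colossalai.nn.layer import Embedding, PatchEmbedding

# 'None' -> torch.nn.Embedding; '1d'/'2d'/'2.5d'/'3d' -> Embedding1D/2D/2p5D/3D
word_embed = Embedding(num_embeddings=50304, embedding_dim=768)

# 'None'/'1d' -> VanillaPatchEmbedding; '2d'/'2.5d'/'3d' -> PatchEmbedding2D/2p5D/3D
patch_embed = PatchEmbedding(img_size=32, patch_size=4, in_chans=3, embed_size=512)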

View File

@@ -0,0 +1,97 @@
import math
from typing import Callable, Optional
from colossalai.nn.layer.parallel_1d.layers import Classifier1D
from colossalai.utils import get_current_device
from torch import dtype, nn
from ... import init as init
from ..parallel_1d import *
from ..parallel_2d import *
from ..parallel_2p5d import *
from ..parallel_3d import *
from ..utils import get_tensor_parallel_mode
from ..vanilla import *
_parallel_linear = {'1d': Linear1D, '2d': Linear2D, '2.5d': Linear2p5D, '3d': Linear3D}
_parallel_classifier = {
'None': VanillaClassifier,
'1d': Classifier1D,
'2d': Classifier2D,
'2.5d': Classifier2p5D,
'3d': Classifier3D
}
class Linear(nn.Module):
def __init__(self,
in_features: int,
out_features: int,
bias: bool = True,
dtype: dtype = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
**kwargs) -> None:
super().__init__()
tensor_parallel = get_tensor_parallel_mode()
if tensor_parallel == 'None':
self.layer = nn.Linear(in_features, out_features, bias=bias, device=get_current_device(), dtype=dtype)
weight_initializer(self.layer.weight, fan_in=in_features, fan_out=out_features)
if bias:
bias_initializer(self.layer.bias, fan_in=in_features)
else:
self.layer = _parallel_linear[tensor_parallel](
in_features,
out_features,
bias=bias,
dtype=dtype,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer,
**kwargs,
)
@property
def weight(self):
return self.layer.weight
@property
def bias(self):
return self.layer.bias
def forward(self, *args):
return self.layer(*args)
class Classifier(nn.Module):
def __init__(
self,
in_features: int,
num_classes: int,
weight: nn.Parameter = None,
bias: bool = True,
dtype: dtype = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)
) -> None:
super().__init__()
self.layer = _parallel_classifier[get_tensor_parallel_mode()](
in_features,
num_classes,
weight=weight,
bias=bias,
dtype=dtype,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer,
)
@property
def weight(self):
return self.layer.weight
@property
def bias(self):
return self.layer.bias
def forward(self, *args):
return self.layer(*args)
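The same pattern applies to the linear wrappers; a construction sketch assuming an initialised context:

from colossalai.nn.layer import Classifier, Linear

# 'None' -> torch.nn.Linear; '1d' -> Linear1D (column/row chosen automatically);
# '2d'/'2.5d'/'3d' -> Linear2D / Linear2p5D / Linear3D
fc = Linear(768, 3072)

# classification head; pass weight= to reuse an existing Parameter (e.g. tied embeddings)
head = Classifier(in_features=768, num_classes=1000)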

View File

@@ -0,0 +1,35 @@
from typing import Optional
from colossalai.utils import get_current_device
from torch import nn
from ... import init as init
from ..parallel_1d import *
from ..parallel_2d import *
from ..parallel_2p5d import *
from ..parallel_3d import *
from ..utils import get_tensor_parallel_mode
from ..vanilla import *
_parallel_layernorm = {'2d': LayerNorm2D, '2.5d': LayerNorm2p5D, '3d': LayerNorm3D}
class LayerNorm(nn.Module):
def __init__(self, normalized_shape: int, eps=1e-05, dtype=None) -> None:
super().__init__()
tensor_parallel = get_tensor_parallel_mode()
if tensor_parallel in ['None', '1d']:
self.norm = nn.LayerNorm(normalized_shape, eps=eps, device=get_current_device(), dtype=dtype)
else:
self.norm = _parallel_layernorm[tensor_parallel](normalized_shape, eps=eps, dtype=dtype)
@property
def weight(self):
return self.norm.weight
@property
def bias(self):
return self.norm.bias
def forward(self, *args):
return self.norm(*args)

View File

@ -1,35 +0,0 @@
# adapted from Megatron-LM
# https://github.com/NVIDIA/Megatron-LM/blob/b31e1296354e979722627a6c4dedafe19b51fa97/megatron/model/fused_bias_gelu.py
import torch
@torch.jit.script
def bias_gelu(bias, y):
x = bias + y
return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
# gradient of tanh approximation of gelu
# gradient of actual gelu is:
# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
@torch.jit.script
def bias_gelu_back(g, bias, y):
x = bias + y
tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
# sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
return ff*g
class GeLUFunction(torch.autograd.Function):
@staticmethod
# bias is an optional argument
def forward(ctx, input, bias):
ctx.save_for_backward(input, bias)
return bias_gelu(bias, input)
@staticmethod
def backward(ctx, grad_output):
input, bias = ctx.saved_tensors
tmp = bias_gelu_back(grad_output, bias, input)
return tmp, tmp
bias_gelu_impl = GeLUFunction.apply

View File

@@ -1,4 +1,4 @@
-from .layers import Linear1D_Col, Linear1D_Row
+from .layers import Dropout1D, Embedding1D, Linear1D, Linear1D_Col, Linear1D_Row
 from .layers import MixedFusedLayerNorm1D as LayerNorm1D
-__all__ = ['Linear1D_Col', 'Linear1D_Row', 'LayerNorm1D']
+__all__ = ['Linear1D', 'Linear1D_Col', 'Linear1D_Row', 'LayerNorm1D', 'Embedding1D', 'Dropout1D']

View File

@@ -1,12 +1,21 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
+import os
 import torch
 import torch.distributed as dist
+from colossalai.constants import PARALLEL_INPUT_1D
 from colossalai.core import global_context as gpc
-from .._common_utils import divide
+from ..utils import divide
+def set_parallel_input(input_parallel: bool):
+os.environ[PARALLEL_INPUT_1D] = 'true' if input_parallel else ''
+def get_parallel_input():
+return bool(os.environ[PARALLEL_INPUT_1D])
 def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank):

View File

@@ -3,10 +3,10 @@
 import math
 import numbers
+from contextlib import nullcontext
 from typing import Callable, Tuple
 import torch
-import torch.distributed as dist
 import torch.nn.functional as F
 from colossalai.communication import broadcast
 from colossalai.context import ParallelMode, seed
@@ -14,13 +14,122 @@ from colossalai.core import global_context as gpc
 from colossalai.nn import init as init
 from colossalai.registry import LAYERS
 from colossalai.utils import get_current_device
-from torch import Tensor
+from torch import Tensor, dtype
 from torch.nn.parameter import Parameter
-from .._common_utils import divide, set_tensor_parallel_attribute_by_partition
 from ..base_layer import ParallelLayer
+from ..utils import divide, set_tensor_parallel_attribute_by_partition
 from ._operation import FusedLayerNormAffineFunction1D
-from ._utils import (gather_forward_split_backward, reduce_grad, reduce_input, split_forward_gather_backward)
+from ._utils import (gather_forward_split_backward, get_parallel_input, reduce_grad, reduce_input, set_parallel_input,
+split_forward_gather_backward)
@LAYERS.register_module
class Linear1D(torch.nn.Module):
def __init__(self,
in_features: int,
out_features: int,
bias: bool = True,
dtype: torch.dtype = None,
gather_output: bool = False,
skip_bias_add: bool = False,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
super().__init__()
parallel_input = get_parallel_input()
if not parallel_input:
self.layer = Linear1D_Col(in_features,
out_features,
bias=bias,
dtype=dtype,
gather_output=gather_output,
skip_bias_add=skip_bias_add,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer)
else:
self.layer = Linear1D_Row(in_features,
out_features,
bias=bias,
dtype=dtype,
parallel_input=parallel_input,
skip_bias_add=skip_bias_add,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer)
@property
def weight(self):
return self.layer.weight
@property
def bias(self):
return self.layer.bias
def forward(self, input_: Tensor) -> Tensor:
return self.layer(input_)
@LAYERS.register_module
class Classifier1D(ParallelLayer):
"""RowLinear with given weight"""
def __init__(self,
in_features: int,
num_classes: int,
weight: Parameter = None,
bias: bool = True,
dtype: dtype = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
super().__init__()
self.in_features = in_features
self.num_classes = num_classes
self.parallel_input = get_parallel_input()
# Divide the weight matrix along the last dimension.
self.input_size_per_partition = divide(in_features, gpc.tensor_parallel_size)
# Parameters.
# Initialize weight.
factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
if weight is not None:
self.weight = weight
self.has_weight = False
else:
self.weight = Parameter(torch.empty(self.num_classes, self.input_size_per_partition, **factory_kwargs))
self.has_weight = True
if bias:
self.bias = Parameter(torch.empty(self.num_classes, **factory_kwargs))
else:
self.bias = None
with seed(ParallelMode.TENSOR):
self.reset_parameters(weight_initializer, bias_initializer)
self._set_tensor_parallel_attributes()
set_parallel_input(False)
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
fan_in, fan_out = self.in_features, self.num_classes
if self.has_weight:
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
if self.bias is not None:
bias_initializer(self.bias, fan_in=fan_in)
broadcast(self.bias, gpc.get_ranks_in_group(ParallelMode.PARALLEL_1D)[0], ParallelMode.PARALLEL_1D)
def _set_tensor_parallel_attributes(self):
if self.has_weight:
num_partition = gpc.get_world_size(ParallelMode.TENSOR)
set_tensor_parallel_attribute_by_partition(self.weight, num_partition)
def forward(self, input_: Tensor) -> Tensor:
# Set up backprop all-reduce.
if self.parallel_input:
input_ = input_
else:
input_ = split_forward_gather_backward(input_, ParallelMode.PARALLEL_1D, dim=-1)
output_parallel = F.linear(input_, self.weight)
output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D)
output = output + self.bias
return output
 @LAYERS.register_module
@@ -77,6 +186,7 @@ class Linear1D_Col(ParallelLayer):
 with seed(ParallelMode.TENSOR):
 self.reset_parameters(weight_initializer, bias_initializer)
 self._set_tensor_parallel_attributes()
+set_parallel_input(True)
 def reset_parameters(self, weight_initializer, bias_initializer) -> None:
 fan_in, fan_out = self.in_features, self.out_features
@@ -158,6 +268,7 @@ class Linear1D_Row(ParallelLayer):
 with seed(ParallelMode.TENSOR):
 self.reset_parameters(weight_initializer, bias_initializer)
 self._set_tensor_parallel_attributes()
+set_parallel_input(False)
 def reset_parameters(self, weight_initializer, bias_initializer) -> None:
 fan_in, fan_out = self.in_features, self.out_features
@@ -208,3 +319,68 @@ class MixedFusedLayerNorm1D(torch.nn.Module):
 def forward(self, input):
 return FusedLayerNormAffineFunction1D.apply(input, self.weight, self.bias, self.normalized_shape, self.eps)
@LAYERS.register_module
class Embedding1D(ParallelLayer):
def __init__(self,
num_embeddings: int,
embedding_dim: int,
padding_idx: int = None,
dtype: dtype = None,
weight_initializer: Callable = init.normal_(),
*args,
**kwargs):
super().__init__()
self.num_embeddings = num_embeddings
self.embed_dim = embedding_dim
embed_dim_per_partition = divide(embedding_dim, gpc.tensor_parallel_size)
self.padding_idx = padding_idx
self.embed_args = args
self.embed_kwargs = kwargs
self.weight = Parameter(
torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype))
self.reset_parameters(weight_initializer)
self._set_tensor_parallel_attributes()
set_parallel_input(False)
def _set_tensor_parallel_attributes(self):
set_tensor_parallel_attribute_by_partition(self.weight, gpc.tensor_parallel_size)
def reset_parameters(self, weight_initializer) -> None:
with seed(ParallelMode.TENSOR):
fan_in, fan_out = self.num_embeddings, self.embed_dim
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
self._fill_padding_idx_with_zero()
def _fill_padding_idx_with_zero(self) -> None:
if self.padding_idx is not None:
with torch.no_grad():
self.weight[self.padding_idx].fill_(0)
def forward(self, input_: Tensor) -> Tensor:
output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1)
return output
@LAYERS.register_module
class Dropout1D(ParallelLayer):
def __init__(self, p: float = 0.5, inplace: bool = False):
super().__init__()
self.parallel_input = get_parallel_input()
self.p = p
self.inplace = inplace
def forward(self, input_: Tensor) -> Tensor:
cm = nullcontext() if not self.parallel_input else seed(ParallelMode.TENSOR)
with cm:
output = F.dropout(input_, self.p, self.training, self.inplace)
return output
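What the PARALLEL_INPUT_1D bookkeeping buys: stacked Linear1D layers alternate between column- and row-parallel automatically, so a Megatron-style MLP needs no explicit gather in between. A sketch, assuming mode '1d' and an initialised 1D group:

import torch.nn as nn
from colossalai.nn.layer import Linear1D

# Initializer_1D clears the flag, so the first Linear1D builds a Linear1D_Col
# (which sets the flag); the second sees parallel input and builds a Linear1D_Row
# (which clears it again), combining the partial results back into a full tensor.
mlp = nn.Sequential(
    Linear1D(768, 3072),   # -> Linear1D_Col, output stays split along the hidden dim
    nn.GELU(),
    Linear1D(3072, 768),   # -> Linear1D_Row
)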

View File

@@ -1,6 +1,6 @@
-from ._operation import reduce_by_batch_2d, split_batch_2d
+from ._operation import reduce_by_batch_2d, split_tensor_2d
 from .layers import Classifier2D, Embedding2D, LayerNorm2D, Linear2D, PatchEmbedding2D
 __all__ = [
-'split_batch_2d', 'reduce_by_batch_2d', 'Linear2D', 'LayerNorm2D', 'Classifier2D', 'PatchEmbedding2D', 'Embedding2D'
+'split_tensor_2d', 'reduce_by_batch_2d', 'Linear2D', 'LayerNorm2D', 'Classifier2D', 'PatchEmbedding2D', 'Embedding2D'
 ]

View File

@@ -2,7 +2,7 @@ from typing import Any, Optional, Tuple
 import torch
 import torch.distributed as dist
-from colossalai.communication.collective import (all_gather, all_reduce, reduce_scatter)
+from colossalai.communication.collective import (all_gather, all_reduce, reduce, reduce_scatter)
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.utils import get_current_device
@@ -595,7 +595,9 @@ class SplitFirst(torch.autograd.Function):
 return grad, None, None
-def split_batch_2d(input_: Tensor, dim: int = 0) -> Tensor:
+def split_tensor_2d(input_: Tensor, dim: int = 0) -> Tensor:
+if input_.size(dim) <= 1:
+return input_
 return torch.chunk(input_, gpc.get_world_size(ParallelMode.PARALLEL_2D_COL),
 dim=dim)[gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)].contiguous()
@@ -603,17 +605,28 @@ def split_batch_2d(input_: Tensor, dim: int = 0) -> Tensor:
 class reduce_by_batch_2d(torch.autograd.Function):
 """All-reduce the input from the model parallel region."""
 @staticmethod
-def symbolic(graph, input_):
-dist.all_reduce(input_, group=gpc.get_group(ParallelMode.PARALLEL_2D_COL))
-return input_
+def symbolic(graph, input_, reduce_mean: bool = False):
+output = all_reduce(input_, ParallelMode.PARALLEL_2D_COL)
+if reduce_mean:
+reduce_size = gpc.get_world_size(ParallelMode.PARALLEL_2D_COL)
+return output / reduce_size
+return output
 @staticmethod
 @custom_fwd(cast_inputs=torch.float32)
-def forward(ctx, input_):
-dist.all_reduce(input_, group=gpc.get_group(ParallelMode.PARALLEL_2D_COL))
-return input_.clone()
+def forward(ctx, input_, reduce_mean: bool = False):
+output = all_reduce(input_, ParallelMode.PARALLEL_2D_COL)
+ctx.reduce_mean = reduce_mean
+if reduce_mean:
+reduce_size = gpc.get_world_size(ParallelMode.PARALLEL_2D_COL)
+ctx.reduce_size = reduce_size
+return output.clone() / reduce_size
+return output.clone()
 @staticmethod
 @custom_bwd
-def backward(ctx, grad_output):
-return grad_output
+def backward(ctx, output_grad):
+if ctx.reduce_mean:
+return output_grad / ctx.reduce_size, None
+else:
+return output_grad, None
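The reduce_mean flag added to reduce_by_batch_2d (and its 2.5D/3D counterparts below) lets metric code sum and average over the batch-splitting group in one differentiable step. A hedged usage sketch:

import torch
from colossalai.nn.layer import reduce_by_batch_2d

def batch_accuracy_2d(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    # logits/targets hold this rank's share of the batch (already split by load_batch)
    acc = (logits.argmax(dim=-1) == targets).float().mean()
    # average over the 2D column group instead of dividing by the world size by hand
    return reduce_by_batch_2d.apply(acc, True)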

View File

@@ -13,9 +13,9 @@ from colossalai.utils import get_current_device
 from torch import Tensor, dtype
 from torch.nn import Parameter
-from .._common_utils import (divide, set_tensor_parallel_attribute_by_partition, to_2tuple)
+from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
 from ..base_layer import ParallelLayer
-from ._operation import (Matmul_AB_2D, add_bias_2d, all_gather_weight_2d, classifier_2d, layernorm_2d, split_batch_2d)
+from ._operation import Matmul_AB_2D, add_bias_2d, all_gather_weight_2d, classifier_2d, layernorm_2d
 from ._utils import assert_summa_initialization, get_summa_dim_from_env
@@ -257,8 +257,6 @@ class PatchEmbedding2D(ParallelLayer):
 assert H == self.img_size[0] and W == self.img_size[1], \
 f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
-input_ = split_batch_2d(input_)
 weight = all_gather_weight_2d.apply(self.weight, 0, self.summa_dim, ParallelMode.PARALLEL_2D_COL)
 bias = all_gather_weight_2d.apply(self.bias, 0, self.summa_dim, ParallelMode.PARALLEL_2D_COL)
@@ -318,8 +316,6 @@ class Embedding2D(ParallelLayer):
 self.weight[self.padding_idx].fill_(0)
 def forward(self, input_: Tensor) -> Tensor:
-input_ = split_batch_2d(input_)
 weight = all_gather_weight_2d.apply(self.weight, -1, self.summa_dim, ParallelMode.PARALLEL_2D_COL)
 output = F.embedding(input_, weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)

View File

@@ -1,7 +1,7 @@
-from ._operation import reduce_by_batch_2p5d, split_batch_2p5d
+from ._operation import reduce_by_batch_2p5d, split_tensor_2p5d
 from .layers import Classifier2p5D, Embedding2p5D, LayerNorm2p5D, Linear2p5D, PatchEmbedding2p5D
 __all__ = [
-'split_batch_2p5d', 'reduce_by_batch_2p5d', 'Linear2p5D', 'LayerNorm2p5D', 'Classifier2p5D', 'PatchEmbedding2p5D',
+'split_tensor_2p5d', 'reduce_by_batch_2p5d', 'Linear2p5D', 'LayerNorm2p5D', 'Classifier2p5D', 'PatchEmbedding2p5D',
 'Embedding2p5D'
 ]

View File

@@ -22,7 +22,7 @@ def get_parallel_rank(parallel_mode: ParallelMode):
 return gpc.get_local_rank(parallel_mode)
-def split_batch_2p5d(input_: Tensor, dim: int = 0) -> Tensor:
+def split_tensor_2p5d(input_: Tensor, dim: int = 0) -> Tensor:
 return torch.chunk(input_, gpc.get_world_size(ParallelMode.PARALLEL_2P5D_COL),
 dim=dim)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)].contiguous()
@@ -120,30 +120,53 @@ class Matmul_AB_2p5D(torch.autograd.Function):
 ctx.save_for_backward(A, B)
 A_shape = A.shape
-A = A.reshape((-1, A_shape[-1])).contiguous()
+A = A.reshape((-1, A_shape[-1]))
 B_shape = B.shape
-B = B.reshape((-1, B_shape[-1])).contiguous()
+B = B.reshape((-1, B_shape[-1]))
 C_shape = (A.shape[0], B.shape[-1])
 C = torch.zeros(C_shape, dtype=A.dtype, device=get_current_device())
-A_list = [torch.empty_like(A) for _ in range(gpc.get_world_size(row_parallel_mode) - 1)]
-B_list = [torch.empty_like(B) for _ in range(gpc.get_world_size(col_parallel_mode) - 1)]
-A_list.insert(gpc.get_local_rank(row_parallel_mode), A)
-B_list.insert(gpc.get_local_rank(col_parallel_mode), B)
-op_a = dist.all_gather(A_list, A, group=gpc.get_group(row_parallel_mode), async_op=True)
-op_a.wait()
-op_b = dist.all_gather(B_list, B, group=gpc.get_group(col_parallel_mode), async_op=True)
-for op in [op_a, op_b]:
-op.wait()
+# use circular buffer to store the communication tensor
+# 2 is enough for all cases
+A_list = [torch.empty_like(A) for _ in range(2)]
+B_list = [torch.empty_like(B) for _ in range(2)]
+row_group = gpc.get_group(row_parallel_mode)
+col_group = gpc.get_group(col_parallel_mode)
+src_a = tesseract_dim * row_rank + tesseract_dim ** 2 * dep_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
+pipeline_parallel_rank * tensor_parallel_size
+src_b = col_rank + tesseract_dim ** 2 * dep_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
+pipeline_parallel_rank * tensor_parallel_size
+opa = [None] * 2
+opb = [None] * 2
+A_list[0].copy_(A)
+B_list[0].copy_(B)
+opa[0] = dist.broadcast(A_list[0], src=src_a, group=row_group, async_op=True)
+opb[0] = dist.broadcast(B_list[0], src=src_b, group=col_group, async_op=True)
+cur = 0
 for i in range(tesseract_dim):
-src_a = i + tesseract_dim * row_rank
-src_b = i + tesseract_dim * col_rank
-src_a = src_a % tesseract_dim
-src_b = src_b % tesseract_dim
-A_temp = A_list[src_a]
-B_temp = B_list[src_b]
-torch.addmm(C, A_temp, B_temp, out=C)
+if i != tesseract_dim - 1:
+A_list[1 - cur].copy_(A)
+opa[1 - cur] = dist.broadcast(A_list[1 - cur], src=src_a + 1, group=row_group, async_op=True)
+B_list[1 - cur].copy_(B)
+opb[1 - cur] = dist.broadcast(B_list[1 - cur],
+src=src_b + tesseract_dim,
+group=col_group,
+async_op=True)
+if opa[cur] is not None:
+opa[cur].wait()
+if opb[cur] is not None:
+opb[cur].wait()
+torch.addmm(C, A_list[cur], B_list[cur], out=C)
+cur = 1 - cur
+src_a += 1
+src_b += tesseract_dim
 out = C.reshape(out_shape)
 if ctx:
@@ -201,20 +224,55 @@ class Matmul_ABT_2p5D(torch.autograd.Function):
 C_shape = (A.shape[0], B.shape[0])
 C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device())
-for i in range(tesseract_dim):
-B_temp = B.clone()
-src_b = col_rank + i * tesseract_dim + dep_rank * (
-tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
-pipeline_parallel_rank * tensor_parallel_size
-dist.broadcast(B_temp, src=src_b, group=gpc.get_group(col_parallel_mode))
-C_temp = torch.matmul(A, B_temp.transpose(0, 1))
-src_c = i + row_rank * tesseract_dim + dep_rank * (
-tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
-pipeline_parallel_rank * tensor_parallel_size
-dist.reduce(C_temp, dst=src_c, group=gpc.get_group(row_parallel_mode))
-if i == col_rank:
-C = C_temp.clone()
+# use circular buffer to store the communication tensor
+# 2 is enough for all cases
+B_list = [torch.empty_like(B) for _ in range(2)]
+C_list = [torch.empty_like(C) for _ in range(2)]
+row_group = gpc.get_group(row_parallel_mode)
+col_group = gpc.get_group(col_parallel_mode)
+src_b = col_rank + tesseract_dim ** 2 * dep_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
+pipeline_parallel_rank * tensor_parallel_size
+src_c = tesseract_dim * row_rank + tesseract_dim ** 2 * dep_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
+pipeline_parallel_rank * tensor_parallel_size
+opb = [None] * 2
+opr = [None] * 2
+B_list[0].copy_(B)
+opb[0] = dist.broadcast(B_list[0], src=src_b, group=col_group, async_op=True)
+cur = 0
+for i in range(tesseract_dim):
+if i != tesseract_dim - 1:
+B_list[1 - cur].copy_(B)
+opb[1 - cur] = dist.broadcast(B_list[1 - cur],
+src=src_b + tesseract_dim,
+group=col_group,
+async_op=True)
+if opr[cur] is not None:
+opr[cur].wait()
+if i - 2 == col_rank:
+C.copy_(C_list[cur])
+if opb[cur] is not None:
+opb[cur].wait()
+torch.matmul(A, B_list[cur].transpose(0, 1), out=C_list[cur])
+opr[cur] = dist.reduce(C_list[cur], dst=src_c, group=row_group, async_op=True)
+cur = 1 - cur
+src_b += tesseract_dim
+src_c += 1
+for op in opr:
+op.wait()
+if tesseract_dim - 2 == col_rank:
+C.copy_(C_list[cur])
+if tesseract_dim - 1 == col_rank:
+C.copy_(C_list[1 - cur])
 out = C.reshape(out_shape)
 if ctx:
@@ -272,20 +330,52 @@ class Matmul_ATB_2p5D(torch.autograd.Function):
 C_shape = (A.shape[-1], B.shape[-1])
 C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device())
-for i in range(tesseract_dim):
-A_temp = A.clone()
-src_a = i + row_rank * tesseract_dim + dep_rank * (
-tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
-pipeline_parallel_rank * tensor_parallel_size
-dist.broadcast(A_temp, src=src_a, group=get_parallel_group(row_parallel_mode))
-C_temp = torch.matmul(A_temp.transpose(0, 1), B)
-src_c = col_rank + i * tesseract_dim + dep_rank * (
-tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
-pipeline_parallel_rank * tensor_parallel_size
-dist.reduce(C_temp, dst=src_c, group=get_parallel_group(col_parallel_mode))
-if i == row_rank:
-C = C_temp.clone()
+# use circular buffer to store the communication tensor
+# 2 is enough for all cases
+A_list = [torch.empty_like(A) for _ in range(2)]
+C_list = [torch.empty_like(C) for _ in range(2)]
+row_group = gpc.get_group(row_parallel_mode)
+col_group = gpc.get_group(col_parallel_mode)
+src_a = tesseract_dim * row_rank + tesseract_dim ** 2 * dep_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
+pipeline_parallel_rank * tensor_parallel_size
+src_c = col_rank + tesseract_dim ** 2 * dep_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
+pipeline_parallel_rank * tensor_parallel_size
+opa = [None] * 2
+opr = [None] * 2
+A_list[0].copy_(A)
+opa[0] = dist.broadcast(A_list[0], src=src_a, group=row_group, async_op=True)
+cur = 0
+for i in range(tesseract_dim):
+if i != tesseract_dim - 1:
+A_list[1 - cur].copy_(A)
+opa[1 - cur] = dist.broadcast(A_list[1 - cur], src=src_a + 1, group=row_group, async_op=True)
+if opr[cur] is not None:
+opr[cur].wait()
+if i - 2 == row_rank:
+C.copy_(C_list[cur])
+if opa[cur] is not None:
+opa[cur].wait()
+torch.matmul(A_list[cur].transpose(0, 1), B, out=C_list[cur])
+opr[cur] = dist.reduce(C_list[cur], dst=src_c, group=col_group, async_op=True)
+cur = 1 - cur
+src_a += 1
+src_c += tesseract_dim
+for op in opr:
+op.wait()
+if tesseract_dim - 2 == row_rank:
+C.copy_(C_list[cur])
+if tesseract_dim - 1 == row_rank:
+C.copy_(C_list[1 - cur])
 out = C.reshape(out_shape)
 if ctx:
@@ -333,8 +423,7 @@ class Add_Bias_2p5D(torch.autograd.Function):
 bias_temp = bias.clone()
 else:
 bias_temp = torch.zeros(output_size_per_partition, dtype=bias.dtype, device=get_current_device())
-src_rank = col_rank + dep_rank * (
-tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
+src_rank = col_rank + dep_rank * tesseract_dim ** 2 + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
 pipeline_parallel_rank * tensor_parallel_size
 dist.broadcast(bias_temp, src=src_rank, group=get_parallel_group(col_parallel_mode))
@@ -469,7 +558,9 @@ class SplitFirst(torch.autograd.Function):
 return grad, None, None
-def split_batch_2p5d(input_: Tensor, dim: int = 0) -> Tensor:
+def split_tensor_2p5d(input_: Tensor, dim: int = 0) -> Tensor:
+if input_.size(dim) <= 1:
+return input_
 return torch.chunk(input_, gpc.get_world_size(ParallelMode.PARALLEL_2P5D_COL),
 dim=dim)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)].contiguous()
@@ -477,17 +568,28 @@ def split_batch_2p5d(input_: Tensor, dim: int = 0) -> Tensor:
 class reduce_by_batch_2p5d(torch.autograd.Function):
 """All-reduce the input from the model parallel region."""
 @staticmethod
-def symbolic(graph, input_):
-dist.all_reduce(input_, group=gpc.get_group(ParallelMode.PARALLEL_2P5D_COL))
-return input_
+def symbolic(graph, input_, reduce_mean: bool = False):
+output = all_reduce(input_, ParallelMode.PARALLEL_2P5D_COL)
+if reduce_mean:
+reduce_size = gpc.get_world_size(ParallelMode.PARALLEL_2P5D_COL)
+return output / reduce_size
+return output
 @staticmethod
 @custom_fwd(cast_inputs=torch.float32)
-def forward(ctx, input_):
-dist.all_reduce(input_, group=gpc.get_group(ParallelMode.PARALLEL_2P5D_COL))
-return input_.clone()
+def forward(ctx, input_, reduce_mean: bool = False):
+output = all_reduce(input_, ParallelMode.PARALLEL_2P5D_COL)
+ctx.reduce_mean = reduce_mean
+if reduce_mean:
+reduce_size = gpc.get_world_size(ParallelMode.PARALLEL_2P5D_COL)
+ctx.reduce_size = reduce_size
+return output.clone() / reduce_size
+return output.clone()
 @staticmethod
 @custom_bwd
-def backward(ctx, grad_output):
-return grad_output
+def backward(ctx, output_grad):
+if ctx.reduce_mean:
+return output_grad / ctx.reduce_size, None
+else:
+return output_grad, None
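The 2.5D runtime fix above replaces the old clone-and-blocking-broadcast loop (and the world-size list of all-gather buffers) in Matmul_AB/ABT/ATB_2p5D with a two-slot circular buffer: while the current slot feeds addmm/matmul, the broadcast for the next step is already in flight. A stripped-down, stand-alone sketch of the pattern (hypothetical helper, not the library code):

import torch
import torch.distributed as dist

def pipelined_broadcast_matmul(A, B, C, group, srcs):
    # Accumulate sum_i broadcast_{srcs[i]}(A) @ B into C, overlapping communication
    # with computation by double-buffering the broadcast target.
    buf = [torch.empty_like(A) for _ in range(2)]   # two slots are enough
    ops = [None, None]
    buf[0].copy_(A)
    ops[0] = dist.broadcast(buf[0], src=srcs[0], group=group, async_op=True)
    cur = 0
    for i in range(len(srcs)):
        if i != len(srcs) - 1:                      # prefetch the next broadcast
            buf[1 - cur].copy_(A)
            ops[1 - cur] = dist.broadcast(buf[1 - cur], src=srcs[i + 1], group=group, async_op=True)
        ops[cur].wait()                             # only wait for the slot used now
        C.addmm_(buf[cur], B)
        cur = 1 - cur
    return C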

View File

@ -13,10 +13,9 @@ from colossalai.utils import get_current_device
from torch import Tensor, dtype from torch import Tensor, dtype
from torch.nn import Parameter from torch.nn import Parameter
from .._common_utils import (divide, set_tensor_parallel_attribute_by_partition, to_2tuple)
from ..base_layer import ParallelLayer from ..base_layer import ParallelLayer
from ._operation import (Add_Bias_2p5D, Matmul_AB_2p5D, all_gather_weight_2p5d, classifier_2p5d, layernorm_2p5d, from ..utils import (divide, set_tensor_parallel_attribute_by_partition, to_2tuple)
split_batch_2p5d) from ._operation import (Add_Bias_2p5D, Matmul_AB_2p5D, all_gather_weight_2p5d, classifier_2p5d, layernorm_2p5d)
from ._utils import (assert_tesseract_initialization, get_tesseract_dim_dep_from_env) from ._utils import (assert_tesseract_initialization, get_tesseract_dim_dep_from_env)
@ -231,7 +230,7 @@ class PatchEmbedding2p5D(ParallelLayer):
self.num_patches = self.grid_size[0] * self.grid_size[1] self.num_patches = self.grid_size[0] * self.grid_size[1]
self.flatten = flatten self.flatten = flatten
self.embed_size = embed_size self.embed_size = embed_size
self.embed_size_per_partition = embed_size // (self.tesseract_dep * self.tesseract_dim**2) self.embed_size_per_partition = embed_size // self.tesseract_dim**2
with seed(ParallelMode.TENSOR): with seed(ParallelMode.TENSOR):
self.weight = Parameter( self.weight = Parameter(
@ -251,10 +250,10 @@ class PatchEmbedding2p5D(ParallelLayer):
self._set_tensor_parallel_attribute() self._set_tensor_parallel_attribute()
def _set_tensor_parallel_attribute(self): def _set_tensor_parallel_attribute(self):
set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dep * self.tesseract_dim**2) set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dim**2)
-        set_tensor_parallel_attribute_by_partition(self.bias, self.tesseract_dep * self.tesseract_dim**2)
-        set_tensor_parallel_attribute_by_partition(self.cls_token, self.tesseract_dep * self.tesseract_dim**2)
-        set_tensor_parallel_attribute_by_partition(self.pos_embed, self.tesseract_dep * self.tesseract_dim**2)
+        set_tensor_parallel_attribute_by_partition(self.bias, self.tesseract_dim**2)
+        set_tensor_parallel_attribute_by_partition(self.cls_token, self.tesseract_dim**2)
+        set_tensor_parallel_attribute_by_partition(self.pos_embed, self.tesseract_dim**2)

     def reset_parameters(self, weight_initializer, bias_initializer, position_embed_initializer):
         with seed(ParallelMode.TENSOR):
@@ -269,8 +268,6 @@ class PatchEmbedding2p5D(ParallelLayer):
         assert H == self.img_size[0] and W == self.img_size[1], \
             f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."

-        input_ = split_batch_2p5d(input_)
         weight = all_gather_weight_2p5d.apply(self.weight, 0, self.tesseract_dim, ParallelMode.PARALLEL_2P5D_COL)
         bias = all_gather_weight_2p5d.apply(self.bias, 0, self.tesseract_dim, ParallelMode.PARALLEL_2P5D_COL)
@@ -303,7 +300,7 @@ class Embedding2p5D(ParallelLayer):
         self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env()
         self.num_embeddings = num_embeddings
         self.embed_dim = embedding_dim
-        embed_dim_per_partition = embedding_dim // (self.tesseract_dep * self.tesseract_dim**2)
+        embed_dim_per_partition = embedding_dim // self.tesseract_dim**2
         self.padding_idx = padding_idx
         self.embed_args = args
@@ -316,7 +313,7 @@ class Embedding2p5D(ParallelLayer):
         self._set_tensor_parallel_attributes()

     def _set_tensor_parallel_attributes(self):
-        set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dep * self.tesseract_dim**2)
+        set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dim**2)

     def reset_parameters(self, weight_initializer) -> None:
         with seed(ParallelMode.TENSOR):
@@ -330,8 +327,6 @@ class Embedding2p5D(ParallelLayer):
             self.weight[self.padding_idx].fill_(0)

     def forward(self, input_: Tensor) -> Tensor:
-        input_ = split_batch_2p5d(input_)
         weight = all_gather_weight_2p5d.apply(self.weight, -1, self.tesseract_dim, ParallelMode.PARALLEL_2P5D_COL)
         output = F.embedding(input_, weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
@@ -359,7 +354,7 @@ class Classifier2p5D(ParallelLayer):
         self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env()
         # partitioning dimension
-        self.input_size_per_partition = divide(self.in_features, self.tesseract_dep * self.tesseract_dim**2)
+        self.input_size_per_partition = divide(self.in_features, self.tesseract_dim**2)

         if weight is not None:
             self.weight = weight
@@ -378,7 +373,7 @@ class Classifier2p5D(ParallelLayer):
     def _set_tensor_parallel_attributes(self):
         if self.has_weight:
-            set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dep * self.tesseract_dim**2)
+            set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dim**2)

     def reset_parameters(self, weight_initializer, bias_initializer) -> None:
         with seed(ParallelMode.TENSOR):
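The 2.5D layers now shard parameters over `tesseract_dim ** 2` partitions instead of `tesseract_dep * tesseract_dim ** 2`, so each rank keeps a larger slice and the depth dimension no longer divides the embedding. A quick numeric check (the dimension and tesseract values below are illustrative, not taken from any config in this change):

```python
# Illustrative values only; real runs get these from the parallel config.
embedding_dim = 768
tesseract_dim, tesseract_dep = 2, 2

old_shard = embedding_dim // (tesseract_dep * tesseract_dim ** 2)  # 768 // 8 = 96
new_shard = embedding_dim // tesseract_dim ** 2                    # 768 // 4 = 192
print(old_shard, new_shard)  # 96 192
```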

View File

@@ -1,6 +1,6 @@
-from ._operation import reduce_by_batch_3d, split_batch_3d
+from ._operation import reduce_by_batch_3d, split_tensor_3d
 from .layers import Classifier3D, Embedding3D, LayerNorm3D, Linear3D, PatchEmbedding3D

 __all__ = [
-    'reduce_by_batch_3d', 'split_batch_3d', 'Linear3D', 'LayerNorm3D', 'PatchEmbedding3D', 'Classifier3D', 'Embedding3D'
+    'reduce_by_batch_3d', 'split_tensor_3d', 'Linear3D', 'LayerNorm3D', 'PatchEmbedding3D', 'Classifier3D', 'Embedding3D'
 ]

View File

@@ -175,10 +175,12 @@ class layernorm_3d(torch.autograd.Function):
         return input_grad, weight_grad, bias_grad, None, None, None, None, None


-def split_batch_3d(input_: Tensor,
-                   input_parallel_mode: ParallelMode,
-                   weight_parallel_mode: ParallelMode,
-                   dim: int = 0) -> Tensor:
+def split_tensor_3d(input_: Tensor,
+                    dim: int = 0,
+                    input_parallel_mode: ParallelMode = ParallelMode.PARALLEL_3D_INPUT,
+                    weight_parallel_mode: ParallelMode = ParallelMode.PARALLEL_3D_WEIGHT) -> Tensor:
+    if input_.size(dim) <= 1:
+        return input_
     output = torch.chunk(input_, gpc.get_world_size(weight_parallel_mode),
                          dim=dim)[gpc.get_local_rank(weight_parallel_mode)].contiguous()
     output = torch.chunk(output, gpc.get_world_size(input_parallel_mode),
@@ -189,15 +191,27 @@ def split_batch_3d(input_: Tensor,
 class reduce_by_batch_3d(torch.autograd.Function):
     @staticmethod
     @custom_fwd(cast_inputs=torch.float32)
-    def forward(ctx, input_: Tensor, input_parallel_mode: ParallelMode, weight_parallel_mode: ParallelMode) -> Tensor:
+    def forward(ctx,
+                input_: Tensor,
+                input_parallel_mode: ParallelMode,
+                weight_parallel_mode: ParallelMode,
+                reduce_mean: bool = False) -> Tensor:
         output = all_reduce(input_, input_parallel_mode)
         output = all_reduce(output, weight_parallel_mode)
+        ctx.reduce_mean = reduce_mean
+        if reduce_mean:
+            reduce_size = gpc.get_world_size(input_parallel_mode) * gpc.get_world_size(weight_parallel_mode)
+            ctx.reduce_size = reduce_size
+            return output.clone() / reduce_size
         return output.clone()

     @staticmethod
     @custom_bwd
     def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]:
-        return output_grad, None, None
+        if ctx.reduce_mean:
+            return output_grad / ctx.reduce_size, None, None, None
+        else:
+            return output_grad, None, None, None


 class broadcast_weight_3d_from_diagonal(torch.autograd.Function):
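`split_batch_3d` is renamed to `split_tensor_3d`, gains default parallel modes, and returns the input untouched when the split dimension has a single element; `reduce_by_batch_3d` picks up a `reduce_mean` flag so the all-reduced value is averaged over the participating ranks, with the gradient scaled accordingly in `backward`. A dependency-free sketch of the splitting logic, with ranks and world sizes passed explicitly instead of being looked up through `gpc` (the helper name below is ours):

```python
import torch

def split_tensor_3d_sketch(x: torch.Tensor, dim: int,
                           weight_world_size: int, weight_rank: int,
                           input_world_size: int, input_rank: int) -> torch.Tensor:
    # Same two-level chunking as split_tensor_3d, without the colossalai context.
    if x.size(dim) <= 1:
        return x
    x = torch.chunk(x, weight_world_size, dim=dim)[weight_rank].contiguous()
    x = torch.chunk(x, input_world_size, dim=dim)[input_rank].contiguous()
    return x

batch = torch.arange(8).reshape(8, 1)                # toy batch of 8 samples
print(split_tensor_3d_sketch(batch, 0, 2, 0, 2, 1))  # keeps rows 2 and 3 of the batch
```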

View File

@@ -17,9 +17,9 @@ from colossalai.utils import get_current_device
 from torch import Tensor, dtype
 from torch.nn import Parameter

-from .._common_utils import (divide, set_tensor_parallel_attribute_by_partition, to_2tuple)
+from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
 from ._operation import *
-from ._utils import (get_depth_from_env, get_last_group, get_parallel_mode_from_env, swap_in_out_group)
+from ._utils import get_depth_from_env, get_last_group, get_parallel_mode_from_env, swap_in_out_group


 @LAYERS.register_module
@@ -241,8 +241,6 @@ class PatchEmbedding3D(ParallelLayer):
             self.pos_embed.register_hook(self._sync_grad_hook)

     def forward(self, input_: Tensor) -> Tensor:
-        input_ = split_batch_3d(input_, self.input_parallel_mode, self.weight_parallel_mode)
         weight = broadcast_weight_3d_from_diagonal.apply(self.weight, self.input_parallel_mode,
                                                          self.weight_parallel_mode, self.output_parallel_mode)
         output = F.conv2d(input_, weight, self.bias, stride=self.patch_size)
@@ -302,8 +300,6 @@ class Embedding3D(ParallelLayer):
             self.weight[self.padding_idx].fill_(0)

     def forward(self, input_: Tensor) -> Tensor:
-        input_ = split_batch_3d(input_, self.input_parallel_mode, self.weight_parallel_mode)
         weight = broadcast_weight_3d_from_diagonal.apply(self.weight, self.input_parallel_mode,
                                                          self.weight_parallel_mode, self.output_parallel_mode)
         output = F.embedding(input_, weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)

View File

@ -0,0 +1,7 @@
from .common import (ACT2FN, CheckpointModule, _ntuple, divide, get_tensor_parallel_mode,
set_tensor_parallel_attribute_by_partition, set_tensor_parallel_attribute_by_size, to_2tuple)
__all__ = [
'CheckpointModule', 'divide', 'ACT2FN', 'set_tensor_parallel_attribute_by_size',
'set_tensor_parallel_attribute_by_partition', 'get_tensor_parallel_mode', '_ntuple', 'to_2tuple'
]
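The new `colossalai.nn.layer.utils` package re-exports the helpers that previously lived in `_common_utils`, so layer code can import them from one stable location. A small usage sketch, assuming `divide` is the usual checked integer division and `to_2tuple` the timm-style tuple helper:

```python
from colossalai.nn.layer.utils import divide, to_2tuple

print(divide(768, 4))  # 192; assumed to raise if 768 were not divisible by 4
print(to_2tuple(16))   # (16, 16)
```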

View File

@@ -2,11 +2,12 @@
 # -*- encoding: utf-8 -*-

 import collections.abc
+import os
 from itertools import repeat

 import numpy as np
 import torch
-from colossalai.constants import IS_TENSOR_PARALLEL, NUM_PARTITIONS
+from colossalai.constants import (IS_TENSOR_PARALLEL, NUM_PARTITIONS, TENSOR_PARALLEL_MODE)
 from colossalai.utils import checkpoint
 from torch import Tensor, nn
@@ -59,6 +60,10 @@ def set_tensor_parallel_attribute_by_partition(param, num_partitions):
     setattr(param, NUM_PARTITIONS, num_partitions)


+def get_tensor_parallel_mode():
+    return os.environ[TENSOR_PARALLEL_MODE]
+
+
 # From PyTorch internals
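`get_tensor_parallel_mode` reads the tensor-parallel mode from the process environment, which is why layers, losses, and metrics no longer need a `tensor_parallel` argument. In a normal run the initializer is expected to export the variable from the parallel config; setting it by hand is only useful for a standalone check, as in this sketch:

```python
import os

from colossalai.constants import TENSOR_PARALLEL_MODE
from colossalai.nn.layer.utils import get_tensor_parallel_mode

# For demonstration only: real training scripts rely on colossalai's initializer
# to export this variable from the configured tensor-parallel mode (assumption).
os.environ[TENSOR_PARALLEL_MODE] = '2d'
print(get_tensor_parallel_mode())  # '2d'
```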

View File

@@ -9,7 +9,7 @@ from colossalai.utils import get_current_device
 from torch import Tensor, dtype
 from torch import nn as nn

-from .._common_utils import to_2tuple
+from ..utils import to_2tuple


 def drop_path(x, drop_prob: float = 0., training: bool = False):

View File

@@ -2,6 +2,7 @@ from torch import nn
 from torch.nn.modules.loss import *
 from torch.nn.modules.loss import _Loss

+from colossalai.nn.layer.utils import get_tensor_parallel_mode
 from .loss_2d import CrossEntropyLoss2D
 from .loss_2p5d import CrossEntropyLoss2p5D
 from .loss_3d import CrossEntropyLoss3D
@@ -14,9 +15,10 @@ _parallel_cross_entropy = {

 class CrossEntropyLoss(_Loss):
-    def __init__(self, reduction: bool = True, tensor_parallel: str = None, *args, **kwargs):
+    def __init__(self, reduction: bool = True, *args, **kwargs):
         super().__init__()
-        if tensor_parallel in [None, '1d']:
+        tensor_parallel = get_tensor_parallel_mode()
+        if tensor_parallel in ['None', '1d']:
             reduction = 'mean' if reduction else 'none'
             self.loss = nn.CrossEntropyLoss(reduction=reduction, *args, **kwargs)
         else:
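Note that the mode retrieved from the environment is a string, so the no-tensor-parallel case compares against `'None'` rather than the `None` object; beyond that, construction is a plain dictionary dispatch. A simplified sketch of the pattern with stand-in classes (the real targets are the 2D/2.5D/3D losses shown below):

```python
class _VanillaLoss: ...
class _Loss2D: ...

_parallel_cross_entropy = {'2d': _Loss2D}

def build_loss(tensor_parallel: str):
    if tensor_parallel in ['None', '1d']:  # the string 'None', not the None object
        return _VanillaLoss()
    return _parallel_cross_entropy[tensor_parallel]()

print(type(build_loss('2d')).__name__)  # _Loss2D
```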

View File

@@ -1,4 +1,4 @@
-from colossalai.nn.layer.parallel_2d import reduce_by_batch_2d, split_batch_2d
+from colossalai.nn.layer.parallel_2d import reduce_by_batch_2d
 from colossalai.nn.layer.parallel_2d._utils import assert_summa_initialization
 from colossalai.registry import LOSSES
 from torch.nn.functional import cross_entropy
@@ -20,11 +20,8 @@ class CrossEntropyLoss2D(_Loss):
         self.loss_kwargs = kwargs

     def forward(self, logits, targets):
-        batch_size = targets.size(0)
-        targets = split_batch_2d(targets)
-        loss = cross_entropy(logits, targets, reduction='sum', *self.loss_args, **self.loss_kwargs)
+        loss = cross_entropy(logits, targets, reduction='none', *self.loss_args, **self.loss_kwargs)
         if self.reduction_mean:
-            loss = loss.sum()
-            loss = reduce_by_batch_2d.apply(loss)
-            loss /= batch_size
+            loss = loss.mean()
+            loss = reduce_by_batch_2d.apply(loss, True)
         return loss
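The parallel losses now compute per-sample losses (`reduction='none'`), take the local mean, and let `reduce_by_batch_2d.apply(loss, True)` average across the batch-parallel ranks, instead of summing and dividing by the batch size. When every rank holds the same number of samples the two formulations agree, as a small check shows:

```python
import torch

# Per-sample losses held by two hypothetical ranks with equal local batch sizes.
rank_losses = [torch.tensor([1.0, 3.0]), torch.tensor([2.0, 6.0])]

old_style = torch.cat(rank_losses).sum() / 4                          # sum over samples / batch size
new_style = torch.stack([l.mean() for l in rank_losses]).mean()       # local mean, then averaged reduce
print(old_style.item(), new_style.item())                             # 3.0 3.0
```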

View File

@@ -1,4 +1,4 @@
-from colossalai.nn.layer.parallel_2p5d import reduce_by_batch_2p5d, split_batch_2p5d
+from colossalai.nn.layer.parallel_2p5d import reduce_by_batch_2p5d
 from colossalai.nn.layer.parallel_2p5d._utils import assert_tesseract_initialization
 from colossalai.registry import LOSSES
 from torch.nn.functional import cross_entropy
@@ -19,11 +19,8 @@ class CrossEntropyLoss2p5D(_Loss):
         self.loss_kwargs = kwargs

     def forward(self, logits, targets):
-        batch_size = targets.size(0)
-        targets = split_batch_2p5d(targets)
-        loss = cross_entropy(logits, targets, reduction='sum', *self.loss_args, **self.loss_kwargs)
+        loss = cross_entropy(logits, targets, reduction='none', *self.loss_args, **self.loss_kwargs)
         if self.reduction_mean:
-            loss = loss.sum()
-            loss = reduce_by_batch_2p5d.apply(loss)
-            loss /= batch_size
+            loss = loss.mean()
+            loss = reduce_by_batch_2p5d.apply(loss, True)
         return loss

View File

@@ -1,11 +1,10 @@
 from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D
-from colossalai.nn.layer.parallel_3d import reduce_by_batch_3d, split_batch_3d
+from colossalai.nn.layer.parallel_3d import reduce_by_batch_3d
 from colossalai.nn.layer.parallel_3d._utils import get_parallel_mode_from_env
 from colossalai.registry import LOSSES
 from torch.nn.functional import cross_entropy
 from torch.nn.modules.loss import _Loss


 @LOSSES.register_module
 class CrossEntropyLoss3D(_Loss):
     """Cross entropy loss for 3D parallelism
@@ -28,11 +27,8 @@ class CrossEntropyLoss3D(_Loss):
         self.loss_kwargs = kwargs

     def forward(self, logits, targets):
-        batch_size = targets.size(0)
-        targets = split_batch_3d(targets, self.input_parallel_mode, self.weight_parallel_mode)
-        loss = cross_entropy(logits, targets, reduction='sum', *self.loss_args, **self.loss_kwargs)
+        loss = cross_entropy(logits, targets, reduction='none', *self.loss_args, **self.loss_kwargs)
         if self.reduction_mean:
-            loss = loss.sum()
-            loss = reduce_by_batch_3d.apply(loss, self.input_parallel_mode, self.weight_parallel_mode)
-            loss /= batch_size
+            loss = loss.mean()
+            loss = reduce_by_batch_3d.apply(loss, self.input_parallel_mode, self.weight_parallel_mode, True)
         return loss

View File

@@ -4,6 +4,7 @@ from ._utils import calc_acc
 from .accuracy_2d import Accuracy2D
 from .accuracy_2p5d import Accuracy2p5D
 from .accuracy_3d import Accuracy3D
+from colossalai.nn.layer.utils import get_tensor_parallel_mode

 _parallel_accuracy = {
     '2d': Accuracy2D,
@@ -13,9 +14,10 @@ _parallel_accuracy = {

 class Accuracy(nn.Module):
-    def __init__(self, tensor_parallel: str = None):
+    def __init__(self):
         super().__init__()
-        if tensor_parallel in [None, '1d']:
+        tensor_parallel = get_tensor_parallel_mode()
+        if tensor_parallel in ['None', '1d']:
             self.acc = calc_acc
         else:
             self.acc = _parallel_accuracy[tensor_parallel]()

View File

@@ -1,5 +1,5 @@
 import torch
-from colossalai.nn.layer.parallel_2d import reduce_by_batch_2d, split_batch_2d
+from colossalai.nn.layer.parallel_2d import reduce_by_batch_2d
 from torch import nn

 from ._utils import calc_acc
@@ -11,7 +11,6 @@ class Accuracy2D(nn.Module):
     def forward(self, logits, targets):
         with torch.no_grad():
-            targets = split_batch_2d(targets)
             correct = calc_acc(logits, targets)
             correct = reduce_by_batch_2d.apply(correct)
         return correct

View File

@@ -1,5 +1,5 @@
 import torch
-from colossalai.nn.layer.parallel_2p5d import reduce_by_batch_2p5d, split_batch_2p5d
+from colossalai.nn.layer.parallel_2p5d import reduce_by_batch_2p5d
 from torch import nn

 from ._utils import calc_acc
@@ -11,7 +11,6 @@ class Accuracy2p5D(nn.Module):
     def forward(self, logits, targets):
         with torch.no_grad():
-            targets = split_batch_2p5d(targets)
             correct = calc_acc(logits, targets)
             correct = reduce_by_batch_2p5d.apply(correct)
         return correct

View File

@@ -1,6 +1,6 @@
 import torch
 from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D
-from colossalai.nn.layer.parallel_3d import reduce_by_batch_3d, split_batch_3d
+from colossalai.nn.layer.parallel_3d import reduce_by_batch_3d
 from colossalai.nn.layer.parallel_3d._utils import get_parallel_mode_from_env
 from torch import nn
@@ -15,7 +15,6 @@ class Accuracy3D(nn.Module):
     def forward(self, logits, targets):
         with torch.no_grad():
-            targets = split_batch_3d(targets, self.input_parallel_mode, self.weight_parallel_mode)
             correct = calc_acc(logits, targets)
             correct = reduce_by_batch_3d.apply(correct, self.input_parallel_mode, self.weight_parallel_mode)
         return correct

View File

@@ -173,7 +173,7 @@ class AccuracyMetric(Metric):
         self.accumulated_sum.zero_()
         self.accumulated_correct.zero_()

-    def update(self, logits, targets) -> None:
+    def update(self, logits, targets, batch_size) -> None:
         """Updates last step accuracy and accumulated accuracy with current logits
         and labels. It expects the output has logits and labels.
@@ -187,7 +187,7 @@ class AccuracyMetric(Metric):
         # update
         correct = self.acc(logits, targets)
-        self.last_step_sum.fill_(targets.size(0))
+        self.last_step_sum.fill_(batch_size)
         self.last_step_correct.fill_(correct)
         self.accumulated_sum += self.last_step_sum
         self.accumulated_correct += self.last_step_correct
@@ -296,7 +296,8 @@ class AccuracyHook(MetricHook):
     def after_test_iter(self, trainer, logits, targets, *args):
         if self._is_stage_to_compute:
-            self.metric.update(logits, targets)
+            batch_size = trainer.schedule.batch_size
+            self.metric.update(logits, targets, batch_size)


 class ThroughputMetric(Metric):
@@ -313,10 +314,8 @@ class ThroughputMetric(Metric):
         self.last_step_num_samples.zero_()
         self.last_step_used_time.zero_()

-    def update(self, tensor, time) -> None:
-        if isinstance(tensor, (list, tuple)):
-            tensor = tensor[0]
-        self.last_step_num_samples.fill_(tensor.size(0))
+    def update(self, num_samples, time) -> None:
+        self.last_step_num_samples.fill_(num_samples)
         self.last_step_used_time.fill_(time)
         self.accumulated_num_samples += self.last_step_num_samples
         self.accumulated_used_time += self.last_step_used_time
@@ -354,11 +353,11 @@ class ThroughputHook(MetricHook):
     def before_train_epoch(self, trainer):
         self.metric.reset()

-    def after_train_iter(self, trainer, logits, targets, *args):
-        self.metric.update(targets, trainer._timer.get_timer('Train-step').get_elapsed_time())
+    def after_train_iter(self, trainer, *args):
+        self.metric.update(trainer.schedule.batch_size, trainer._timer.get_timer('Train-step').get_elapsed_time())

     def before_test(self, trainer):
         self.metric.reset()

-    def after_test_iter(self, trainer, logits, targets, *args):
-        self.metric.update(targets, trainer._timer.get_timer('Test-step').get_elapsed_time())
+    def after_test_iter(self, trainer, *args):
+        self.metric.update(trainer.schedule.batch_size, trainer._timer.get_timer('Test-step').get_elapsed_time())
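`ThroughputMetric.update` now takes an explicit sample count supplied by the hooks from `trainer.schedule.batch_size`, rather than peeking at the target tensor, so it keeps working when the schedule no longer hands logits and targets to the hook. The accumulation itself is straightforward; a standalone sketch with our own names:

```python
class ThroughputSketch:
    """Accumulates samples and wall time; throughput = samples / time."""
    def __init__(self):
        self.num_samples = 0
        self.used_time = 0.0

    def update(self, num_samples: int, time: float) -> None:
        self.num_samples += num_samples
        self.used_time += time

    def value(self) -> float:
        return self.num_samples / (self.used_time + 1e-12)

m = ThroughputSketch()
m.update(256, 0.5)
m.update(256, 0.5)
print(round(m.value()))  # 512 samples per second
```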

View File

@ -1,27 +1,19 @@
from .activation_checkpoint import checkpoint from .activation_checkpoint import checkpoint
from .common import (print_rank_0, sync_model_param_in_dp, is_dp_rank_0, from .common import (clip_grad_norm_fp32, conditional_context, copy_tensor_parallel_attributes, count_zeros_fp32,
is_tp_rank_0, is_no_pp_or_last_stage, is_using_ddp, free_port, is_dp_rank_0, is_model_parallel_parameter, is_no_pp_or_last_stage, is_tp_rank_0,
is_using_pp, conditional_context, is_model_parallel_parameter, is_using_ddp, is_using_pp, multi_tensor_applier, param_is_not_tensor_parallel_duplicate,
clip_grad_norm_fp32, count_zeros_fp32, copy_tensor_parallel_attributes, print_rank_0, switch_virtual_pipeline_parallel_rank, sync_model_param_in_dp)
param_is_not_tensor_parallel_duplicate, switch_virtual_pipeline_parallel_rank) from .cuda import empty_cache, get_current_device, set_to_cuda, synchronize
from .cuda import get_current_device, synchronize, empty_cache, set_to_cuda from .data_sampler import DataParallelSampler, get_dataloader
from .gradient_accumulation import accumulate_gradient
from .memory import report_memory_usage from .memory import report_memory_usage
from .timer import MultiTimer, Timer from .timer import MultiTimer, Timer
from .multi_tensor_apply import multi_tensor_applier
from .gradient_accumulation import accumulate_gradient
from .data_sampler import DataParallelSampler, get_dataloader
__all__ = ['checkpoint', __all__ = [
'print_rank_0', 'sync_model_param_in_dp', 'is_dp_rank_0', 'checkpoint', 'free_port', 'print_rank_0', 'sync_model_param_in_dp', 'is_dp_rank_0', 'is_tp_rank_0',
'is_tp_rank_0', 'is_no_pp_or_last_stage', 'is_using_ddp', 'is_no_pp_or_last_stage', 'is_using_ddp', 'is_using_pp', 'conditional_context', 'is_model_parallel_parameter',
'is_using_pp', 'conditional_context', 'is_model_parallel_parameter', 'clip_grad_norm_fp32', 'count_zeros_fp32', 'copy_tensor_parallel_attributes',
'clip_grad_norm_fp32', 'count_zeros_fp32', 'copy_tensor_parallel_attributes', 'param_is_not_tensor_parallel_duplicate', 'get_current_device', 'synchronize', 'empty_cache', 'set_to_cuda',
'param_is_not_tensor_parallel_duplicate', 'report_memory_usage', 'Timer', 'MultiTimer', 'multi_tensor_applier', 'accumulate_gradient', 'DataParallelSampler',
'get_current_device', 'synchronize', 'empty_cache', 'set_to_cuda', 'get_dataloader', 'switch_virtual_pipeline_parallel_rank'
'report_memory_usage', ]
'Timer', 'MultiTimer',
'multi_tensor_applier',
'accumulate_gradient',
'DataParallelSampler', 'get_dataloader',
'switch_virtual_pipeline_parallel_rank'
]

View File

@@ -1,5 +1,7 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-

+import random
+import socket
+
 import torch
 from torch._six import inf
@ -9,16 +11,15 @@ try:
except: except:
pass pass
import torch.distributed as dist
from contextlib import contextmanager from contextlib import contextmanager
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from .multi_tensor_apply import multi_tensor_applier
from colossalai.constants import IS_TENSOR_PARALLEL, TENSOR_PARALLEL_ATTRIBUTES, NUM_PARTITIONS
import torch.distributed as dist import torch.distributed as dist
from colossalai.constants import IS_TENSOR_PARALLEL, NUM_PARTITIONS, TENSOR_PARALLEL_ATTRIBUTES
from colossalai.context.parallel_mode import ParallelMode from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from .multi_tensor_apply import multi_tensor_applier
def print_rank_0(msg: str, logger=None): def print_rank_0(msg: str, logger=None):
'''Print messages and save logs(optional). This is executed only if you are the rank-0 gpu. '''Print messages and save logs(optional). This is executed only if you are the rank-0 gpu.
@@ -33,6 +34,18 @@ def print_rank_0(msg: str, logger=None):
         logger.info(msg)


+def free_port():
+    while True:
+        try:
+            sock = socket.socket()
+            port = random.randint(20000, 65000)
+            sock.bind(('localhost', port))
+            sock.close()
+            return port
+        except Exception:
+            continue
+
+
 def sync_model_param_in_dp(model):
     '''Make sure data parameters are consistent during Data Parallel Mode
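`free_port` binds a throwaway socket to random ports until a bind succeeds, which lets each test pick its own rendezvous port instead of hard-coding one. The tests below call it once in the parent process and pass the result to every spawned rank so all workers agree on the address. A usage sketch mirroring that pattern (the config values are placeholders; an NCCL backend still needs GPUs):

```python
from functools import partial

import torch.multiprocessing as mp

import colossalai
from colossalai.utils import free_port

CONFIG = dict(parallel=dict(pipeline=1, tensor=dict(size=1, mode=None)))  # placeholder config

def worker(rank, world_size, port):
    # every rank receives the same port chosen once by the parent process
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size,
                      host='localhost', port=port, backend='nccl')

if __name__ == '__main__':
    world_size = 4
    run = partial(worker, world_size=world_size, port=free_port())
    mp.spawn(run, nprocs=world_size)
```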

View File

@ -3,9 +3,8 @@ from typing import Callable
import torch import torch
from colossalai import nn as col_nn from colossalai import nn as col_nn
from colossalai.context import ParallelMode, seed from colossalai.nn.layer.utils import CheckpointModule
from colossalai.registry import LAYERS, MODELS from colossalai.registry import LAYERS, MODELS
from colossalai.utils import checkpoint
from torch import dtype, nn from torch import dtype, nn
__all__ = [ __all__ = [
@ -72,8 +71,7 @@ class ViTEmbedding(nn.Module):
dropout: float, dropout: float,
dtype: dtype = None, dtype: dtype = None,
flatten: bool = True, flatten: bool = True,
init_method: str = 'torch', init_method: str = 'torch'):
tensor_parallel: str = None):
super().__init__() super().__init__()
self.patch_embed = col_nn.PatchEmbedding(img_size, self.patch_embed = col_nn.PatchEmbedding(img_size,
patch_size, patch_size,
@ -81,19 +79,17 @@ class ViTEmbedding(nn.Module):
embedding_dim, embedding_dim,
dtype=dtype, dtype=dtype,
flatten=flatten, flatten=flatten,
tensor_parallel=tensor_parallel,
**_init_rules[init_method]['embed']) **_init_rules[init_method]['embed'])
self.dropout = nn.Dropout(dropout) self.dropout = col_nn.Dropout(dropout)
def forward(self, x): def forward(self, x):
x = self.patch_embed(x) x = self.patch_embed(x)
with seed(ParallelMode.TENSOR): x = self.dropout(x)
x = self.dropout(x)
return x return x
@LAYERS.register_module @LAYERS.register_module
class ViTSelfAttention(nn.Module): class ViTSelfAttention(CheckpointModule):
def __init__(self, def __init__(self,
dim: int, dim: int,
num_heads: int, num_heads: int,
@ -102,27 +98,17 @@ class ViTSelfAttention(nn.Module):
bias: bool = True, bias: bool = True,
dtype: dtype = None, dtype: dtype = None,
checkpoint: bool = False, checkpoint: bool = False,
init_method: str = 'torch', init_method: str = 'torch'):
tensor_parallel: str = None): super().__init__(checkpoint)
super().__init__()
self.attention_head_size = dim // num_heads self.attention_head_size = dim // num_heads
self.checkpoint = checkpoint
self.tensor_parallel = tensor_parallel
self.query_key_value = col_nn.Linear(dim, self.query_key_value = col_nn.Linear(dim,
3 * dim, 3 * dim,
dtype=dtype, dtype=dtype,
bias=bias, bias=bias,
tensor_parallel='1d_col' if tensor_parallel == '1d' else tensor_parallel,
**_init_rules[init_method]['transformer']) **_init_rules[init_method]['transformer'])
self.attention_dropout = nn.Dropout(attention_dropout) self.attention_dropout = col_nn.Dropout(attention_dropout)
self.dense = col_nn.Linear(dim, self.dense = col_nn.Linear(dim, dim, dtype=dtype, bias=True, **_init_rules[init_method]['transformer'])
dim, self.dropout = col_nn.Dropout(dropout)
dtype=dtype,
bias=True,
tensor_parallel='1d_row' if tensor_parallel == '1d' else tensor_parallel,
**_init_rules[init_method]['transformer'])
self.dropout = nn.Dropout(dropout)
self.softmax = nn.Softmax(dim=-1) self.softmax = nn.Softmax(dim=-1)
def _forward(self, x): def _forward(self, x):
@ -138,8 +124,7 @@ class ViTSelfAttention(nn.Module):
x = torch.matmul(q, k.transpose(-1, -2)) x = torch.matmul(q, k.transpose(-1, -2))
x = x / math.sqrt(self.attention_head_size) x = x / math.sqrt(self.attention_head_size)
x = self.softmax(x) x = self.softmax(x)
with seed(ParallelMode.TENSOR): x = self.attention_dropout(x)
x = self.attention_dropout(x)
x = torch.matmul(x, v) x = torch.matmul(x, v)
x = x.transpose(1, 2) x = x.transpose(1, 2)
@ -147,26 +132,13 @@ class ViTSelfAttention(nn.Module):
x = x.reshape(new_context_layer_shape) x = x.reshape(new_context_layer_shape)
x = self.dense(x) x = self.dense(x)
if self.tensor_parallel == '1d': x = self.dropout(x)
x = self.dropout(x)
else:
with seed(ParallelMode.TENSOR):
x = self.dropout(x)
return x return x
def _checkpoint_forward(self, x):
return checkpoint(self._forward, x)
def forward(self, x):
if self.checkpoint:
return self._checkpoint_forward(x)
else:
return self._forward(x)
@LAYERS.register_module @LAYERS.register_module
class ViTMLP(nn.Module): class ViTMLP(CheckpointModule):
def __init__(self, def __init__(self,
dim: int, dim: int,
mlp_ratio: int, mlp_ratio: int,
@ -175,50 +147,30 @@ class ViTMLP(nn.Module):
dtype: dtype = None, dtype: dtype = None,
bias: bool = True, bias: bool = True,
checkpoint: bool = False, checkpoint: bool = False,
init_method: str = 'torch', init_method: str = 'torch'):
tensor_parallel: str = None): super().__init__(checkpoint)
super().__init__()
self.checkpoint = checkpoint
self.tensor_parallel = tensor_parallel
self.dense_1 = col_nn.Linear(dim, self.dense_1 = col_nn.Linear(dim,
mlp_ratio * dim, mlp_ratio * dim,
dtype=dtype, dtype=dtype,
bias=bias, bias=bias,
tensor_parallel='1d_col' if tensor_parallel == '1d' else tensor_parallel,
**_init_rules[init_method]['transformer']) **_init_rules[init_method]['transformer'])
self.activation = activation self.activation = activation
self.dropout_1 = col_nn.Dropout(dropout)
self.dense_2 = col_nn.Linear(mlp_ratio * dim, self.dense_2 = col_nn.Linear(mlp_ratio * dim,
dim, dim,
dtype=dtype, dtype=dtype,
bias=bias, bias=bias,
tensor_parallel='1d_row' if tensor_parallel == '1d' else tensor_parallel,
**_init_rules[init_method]['transformer']) **_init_rules[init_method]['transformer'])
self.dropout = nn.Dropout(dropout) self.dropout_2 = col_nn.Dropout(dropout)
def _forward(self, x): def _forward(self, x):
x = self.dense_1(x) x = self.dense_1(x)
x = self.activation(x) x = self.activation(x)
with seed(ParallelMode.TENSOR): x = self.dropout_1(x)
x = self.dropout(x)
x = self.dense_2(x) x = self.dense_2(x)
if self.tensor_parallel == '1d': x = self.dropout_2(x)
x = self.dropout(x)
else:
with seed(ParallelMode.TENSOR):
x = self.dropout(x)
return x return x
def _checkpoint_forward(self, x):
return checkpoint(self._forward, x)
def forward(self, x):
if self.checkpoint:
return self._checkpoint_forward(x)
else:
return self._forward(x)
@LAYERS.register_module @LAYERS.register_module
class ViTHead(nn.Module): class ViTHead(nn.Module):
@ -228,19 +180,14 @@ class ViTHead(nn.Module):
representation_size: int = None, representation_size: int = None,
dtype: dtype = None, dtype: dtype = None,
bias: bool = True, bias: bool = True,
init_method: str = 'torch', init_method: str = 'torch'):
tensor_parallel: str = None):
super().__init__() super().__init__()
if representation_size: if representation_size:
tensor_parallel_kwargs = {'tensor_parallel': '1d_col' if tensor_parallel == '1d' else tensor_parallel}
if tensor_parallel == '1d':
tensor_parallel_kwargs['gather_output'] = True
self.representation = col_nn.Linear(dim, self.representation = col_nn.Linear(dim,
representation_size, representation_size,
bias=bias, bias=bias,
dtype=dtype, dtype=dtype,
**_init_rules[init_method]['head'], **_init_rules[init_method]['head'])
**tensor_parallel_kwargs)
else: else:
self.representation = None self.representation = None
representation_size = dim representation_size = dim
@ -249,7 +196,6 @@ class ViTHead(nn.Module):
num_classes, num_classes,
dtype=dtype, dtype=dtype,
bias=bias, bias=bias,
tensor_parallel=tensor_parallel,
**_init_rules[init_method]['head']) **_init_rules[init_method]['head'])
def forward(self, x): def forward(self, x):
@ -273,10 +219,9 @@ class ViTBlock(nn.Module):
dtype: dtype = None, dtype: dtype = None,
bias: bool = True, bias: bool = True,
checkpoint: bool = False, checkpoint: bool = False,
init_method: str = 'torch', init_method: str = 'torch'):
tensor_parallel: str = None):
super().__init__() super().__init__()
self.norm1 = col_nn.LayerNorm(normalized_shape=dim, eps=1e-6, dtype=dtype, tensor_parallel=tensor_parallel) self.norm1 = col_nn.LayerNorm(normalized_shape=dim, eps=1e-6, dtype=dtype)
self.attn = ViTSelfAttention(dim=dim, self.attn = ViTSelfAttention(dim=dim,
num_heads=num_heads, num_heads=num_heads,
attention_dropout=attention_dropout, attention_dropout=attention_dropout,
@ -284,10 +229,9 @@ class ViTBlock(nn.Module):
bias=bias, bias=bias,
dtype=dtype, dtype=dtype,
checkpoint=checkpoint, checkpoint=checkpoint,
init_method=init_method, init_method=init_method)
tensor_parallel=tensor_parallel)
self.drop_path = col_nn.DropPath(drop_path) if drop_path > 0. else nn.Identity() self.drop_path = col_nn.DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = col_nn.LayerNorm(normalized_shape=dim, eps=1e-6, dtype=dtype, tensor_parallel=tensor_parallel) self.norm2 = col_nn.LayerNorm(normalized_shape=dim, eps=1e-6, dtype=dtype)
self.mlp = ViTMLP(dim=dim, self.mlp = ViTMLP(dim=dim,
mlp_ratio=mlp_ratio, mlp_ratio=mlp_ratio,
activation=activation, activation=activation,
@ -295,8 +239,7 @@ class ViTBlock(nn.Module):
dtype=dtype, dtype=dtype,
bias=bias, bias=bias,
checkpoint=checkpoint, checkpoint=checkpoint,
init_method=init_method, init_method=init_method)
tensor_parallel=tensor_parallel)
def forward(self, x): def forward(self, x):
x = x + self.drop_path(self.attn(self.norm1(x))) x = x + self.drop_path(self.attn(self.norm1(x)))
@ -323,20 +266,16 @@ class VisionTransformer(nn.Module):
dtype: dtype = None, dtype: dtype = None,
bias: bool = True, bias: bool = True,
checkpoint: bool = False, checkpoint: bool = False,
init_method: str = 'torch', init_method: str = 'torch'):
tensor_parallel: str = None):
super().__init__() super().__init__()
embed = ViTEmbedding( embed = ViTEmbedding(img_size=img_size,
img_size=img_size, patch_size=patch_size,
patch_size=patch_size, in_chans=in_chans,
in_chans=in_chans, embedding_dim=dim,
embedding_dim=dim, dropout=dropout,
dropout=dropout, dtype=dtype,
dtype=dtype, init_method=init_method)
init_method=init_method,
tensor_parallel=tensor_parallel,
)
# stochastic depth decay rule # stochastic depth decay rule
dpr = [x.item() for x in torch.linspace(0, drop_path, depth)] dpr = [x.item() for x in torch.linspace(0, drop_path, depth)]
@ -353,26 +292,17 @@ class VisionTransformer(nn.Module):
bias=bias, bias=bias,
checkpoint=checkpoint, checkpoint=checkpoint,
init_method=init_method, init_method=init_method,
tensor_parallel=tensor_parallel,
) for i in range(depth) ) for i in range(depth)
] ]
norm = col_nn.LayerNorm( norm = col_nn.LayerNorm(normalized_shape=dim, eps=1e-6, dtype=dtype)
normalized_shape=dim,
eps=1e-6,
dtype=dtype,
tensor_parallel=tensor_parallel,
)
head = ViTHead( head = ViTHead(dim=dim,
dim=dim, num_classes=num_classes,
num_classes=num_classes, representation_size=representation_size,
representation_size=representation_size, dtype=dtype,
dtype=dtype, bias=bias,
bias=bias, init_method=init_method)
init_method=init_method,
tensor_parallel=tensor_parallel,
)
self.layers = nn.Sequential( self.layers = nn.Sequential(
embed, embed,
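In the ViT model zoo the per-layer `tensor_parallel` arguments disappear: linear, norm, and dropout layers are created through `col_nn` and configure themselves from the global mode, and `ViTSelfAttention`/`ViTMLP` now inherit from `CheckpointModule` instead of carrying their own `_checkpoint_forward`. The base class is not shown in this diff; roughly, it is expected to behave like the sketch below, which is an assumption and uses torch's generic checkpoint utility in place of `colossalai.utils.checkpoint`:

```python
import torch.nn as nn
from torch.utils.checkpoint import checkpoint as torch_checkpoint  # stand-in

class CheckpointModuleSketch(nn.Module):
    """Sketch only: subclasses implement _forward; checkpointing is applied here."""
    def __init__(self, checkpoint: bool = False):
        super().__init__()
        self.checkpoint = checkpoint

    def _forward(self, *args):
        raise NotImplementedError

    def forward(self, *args):
        if self.checkpoint:
            return torch_checkpoint(self._forward, *args)
        return self._forward(*args)
```

The dropout layers likewise switch from `nn.Dropout` wrapped in `with seed(ParallelMode.TENSOR)` to `col_nn.Dropout`, which suggests the tensor-parallel RNG handling now lives inside the layer itself.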

View File

@ -1,4 +1,3 @@
import time
from functools import partial from functools import partial
import pytest import pytest
@ -9,7 +8,7 @@ from colossalai.communication import all_gather, all_reduce, reduce_scatter
from colossalai.context import ParallelMode from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.initialize import launch from colossalai.initialize import launch
from colossalai.utils import get_current_device from colossalai.utils import free_port, get_current_device
CONFIG = dict(parallel=dict(data=8, pipeline=1, tensor=dict(mode=None, size=1))) CONFIG = dict(parallel=dict(data=8, pipeline=1, tensor=dict(mode=None, size=1)))
@ -49,8 +48,8 @@ def check_all_reduce():
torch.cuda.synchronize() torch.cuda.synchronize()
def check_layer(rank, world_size): def check_layer(rank, world_size, port):
launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=30010, backend='nccl') launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
assert dist.get_rank() == gpc.get_global_rank() assert dist.get_rank() == gpc.get_global_rank()
print('Rank {} / {}'.format(dist.get_rank(), dist.get_world_size())) print('Rank {} / {}'.format(dist.get_rank(), dist.get_world_size()))
@ -66,7 +65,7 @@ def check_layer(rank, world_size):
@pytest.mark.dist @pytest.mark.dist
def test_comm(): def test_comm():
world_size = 4 world_size = 4
run_func = partial(check_layer, world_size=world_size) run_func = partial(check_layer, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size) mp.spawn(run_func, nprocs=world_size)

View File

@ -1,15 +1,16 @@
#!/usr/bin/env python #!/usr/bin/env python
# -*- encoding: utf-8 -*- # -*- encoding: utf-8 -*-
from functools import partial
from pathlib import Path
import pytest import pytest
import torch import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
from colossalai import launch from colossalai import launch
from colossalai.context.parallel_mode import ParallelMode from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from functools import partial from colossalai.utils import free_port
from pathlib import Path
CONFIG_PATH = Path(__file__).parent.joinpath('configs/parallel_2d_init.py').absolute() CONFIG_PATH = Path(__file__).parent.joinpath('configs/parallel_2d_init.py').absolute()
@ -87,7 +88,7 @@ def test_2d_init():
test_fn = partial(init_2d, test_fn = partial(init_2d,
world_size=world_size, world_size=world_size,
backend='gloo', backend='gloo',
port='29900', port=free_port(),
host='localhost' host='localhost'
) )
mp.spawn(test_fn, nprocs=world_size) mp.spawn(test_fn, nprocs=world_size)

View File

@ -7,10 +7,10 @@ from pathlib import Path
import pytest import pytest
import torch import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
from colossalai.context.parallel_mode import ParallelMode from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.initialize import launch from colossalai.initialize import launch
from colossalai.utils import free_port
CONFIG_PATH = Path(__file__).parent.joinpath('configs/parallel_2p5d_init.py').absolute() CONFIG_PATH = Path(__file__).parent.joinpath('configs/parallel_2p5d_init.py').absolute()
@ -111,7 +111,7 @@ def test_2halfd_init():
test_fn = partial(init_2halfd, test_fn = partial(init_2halfd,
world_size=world_size, world_size=world_size,
backend='gloo', backend='gloo',
port='29901', port=free_port(),
host='localhost' host='localhost'
) )
mp.spawn(test_fn, nprocs=world_size) mp.spawn(test_fn, nprocs=world_size)

View File

@ -7,11 +7,10 @@ from pathlib import Path
import pytest import pytest
import torch import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
from colossalai.context.parallel_mode import ParallelMode from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.initialize import launch from colossalai.initialize import launch
from colossalai.utils import free_port
CONFIG_PATH = Path(__file__).parent.joinpath('configs/parallel_3d_init.py').absolute() CONFIG_PATH = Path(__file__).parent.joinpath('configs/parallel_3d_init.py').absolute()
@ -104,7 +103,7 @@ def test_3d_init():
test_fn = partial(init_3d, test_fn = partial(init_3d,
world_size=world_size, world_size=world_size,
backend='gloo', backend='gloo',
port='29902', port=free_port(),
host='localhost' host='localhost'
) )
mp.spawn(test_fn, nprocs=world_size) mp.spawn(test_fn, nprocs=world_size)

View File

@ -13,7 +13,7 @@ from colossalai.logging import get_dist_logger
from colossalai.nn import Accuracy, LinearWarmupLR from colossalai.nn import Accuracy, LinearWarmupLR
from colossalai.nn.loss import CrossEntropyLoss from colossalai.nn.loss import CrossEntropyLoss
from colossalai.trainer import Trainer, hooks from colossalai.trainer import Trainer, hooks
from colossalai.utils import MultiTimer, get_dataloader from colossalai.utils import MultiTimer, free_port, get_dataloader
from colossalai.utils.gradient_accumulation import GradAccumLrSchedulerByStep from colossalai.utils.gradient_accumulation import GradAccumLrSchedulerByStep
from model_zoo.vit import vit_tiny_patch4_32 from model_zoo.vit import vit_tiny_patch4_32
from torchvision import transforms from torchvision import transforms
@ -27,12 +27,12 @@ CONFIG = dict(parallel=dict(pipeline=2, tensor=dict(size=2, mode='1d')),
gradient_accumulation=2) gradient_accumulation=2)
def run_trainer(rank, world_size): def run_trainer(rank, world_size, port):
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=30000, backend='nccl') colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
logger = get_dist_logger() logger = get_dist_logger()
model = vit_tiny_patch4_32(tensor_parallel='1d') model = vit_tiny_patch4_32()
pipe_model = build_pipeline_model(model.layers, num_chunks=1) pipe_model = build_pipeline_model(model.layers, num_chunks=1)
# build dataloaders # build dataloaders
@ -54,7 +54,7 @@ def run_trainer(rank, world_size):
test_dataloader = get_dataloader(dataset=test_dataset, batch_size=BATCH_SIZE, pin_memory=True) test_dataloader = get_dataloader(dataset=test_dataset, batch_size=BATCH_SIZE, pin_memory=True)
# build criterion # build criterion
criterion = CrossEntropyLoss(tensor_parallel='1d') criterion = CrossEntropyLoss()
# optimizer # optimizer
optimizer = torch.optim.Adam(pipe_model.parameters(), lr=0.001, weight_decay=0) optimizer = torch.optim.Adam(pipe_model.parameters(), lr=0.001, weight_decay=0)
@ -78,7 +78,6 @@ def run_trainer(rank, world_size):
hook_list = [ hook_list = [
hooks.LossHook(), hooks.LossHook(),
hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=False), hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=False),
hooks.AccuracyHook(accuracy_func=Accuracy(tensor_parallel='1d')),
hooks.LogMetricByEpochHook(logger), hooks.LogMetricByEpochHook(logger),
] ]
@ -95,7 +94,7 @@ def run_trainer(rank, world_size):
# @pytest.mark.skip("This test requires more than 8 GPUs, you should invoke this test script using test.sh provided manually") # @pytest.mark.skip("This test requires more than 8 GPUs, you should invoke this test script using test.sh provided manually")
def test_hybrid_parallel(): def test_hybrid_parallel():
world_size = 8 world_size = 8
run_func = partial(run_trainer, world_size=world_size) run_func = partial(run_trainer, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size) mp.spawn(run_func, nprocs=world_size)

View File

@ -1,25 +1,23 @@
# !/usr/bin/env python # !/usr/bin/env python
# -*- encoding: utf-8 -*- # -*- encoding: utf-8 -*-
import colossalai
import os import os
from functools import partial
from pathlib import Path
import colossalai
import pytest import pytest
import torch import torch
import os.path as osp
from pathlib import Path
import torch.nn as nn
import torch.multiprocessing as mp import torch.multiprocessing as mp
import torch.nn as nn
from torchvision import transforms
from torch.optim import Adam
from colossalai.core import global_context as gpc
from colossalai.amp import AMP_TYPE from colossalai.amp import AMP_TYPE
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger from colossalai.logging import get_dist_logger
from colossalai.utils import report_memory_usage, get_dataloader from colossalai.utils import free_port, get_dataloader, report_memory_usage
from torchvision.models import resnet18 from torch.optim import Adam
from torchvision import transforms
from torchvision.datasets import CIFAR10 from torchvision.datasets import CIFAR10
from functools import partial from torchvision.models import resnet18
# Config # Config
BATCH_SIZE = 128 BATCH_SIZE = 128
@ -38,14 +36,14 @@ CONFIG = dict(
) )
def run_engine(rank, world_size): def run_engine(rank, world_size, port):
# init dist env # init dist env
colossalai.launch( colossalai.launch(
config=CONFIG, config=CONFIG,
rank=rank, rank=rank,
world_size=world_size, world_size=world_size,
host='localhost', host='localhost',
port=29910, port=port,
backend='nccl' backend='nccl'
) )
@ -104,7 +102,7 @@ def run_engine(rank, world_size):
@pytest.mark.dist @pytest.mark.dist
def test_engine(): def test_engine():
world_size = 4 world_size = 4
run_func = partial(run_engine, world_size=world_size) run_func = partial(run_engine, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size) mp.spawn(run_func, nprocs=world_size)

View File

@ -1,23 +1,20 @@
import colossalai
import os import os
from functools import partial
from pathlib import Path
import colossalai
import pytest import pytest
import torch import torch
import os.path as osp
from pathlib import Path
import torch.nn as nn
import torch.multiprocessing as mp import torch.multiprocessing as mp
import torch.nn as nn
from torchvision import transforms
from torch.optim import Adam
from colossalai.core import global_context as gpc
from colossalai.amp import AMP_TYPE from colossalai.amp import AMP_TYPE
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger from colossalai.logging import get_dist_logger
from colossalai.utils import report_memory_usage, get_dataloader from colossalai.utils import free_port, get_dataloader, report_memory_usage
from colossalai.initialize import get_default_parser from torch.optim import Adam
from torchvision.models import resnet18 from torchvision import transforms
from torchvision.datasets import CIFAR10 from torchvision.datasets import CIFAR10
from functools import partial from torchvision.models import resnet18
# Config # Config
BATCH_SIZE = 128 BATCH_SIZE = 128
@ -38,14 +35,14 @@ CONFIG = dict(
) )
def run_engine(rank, world_size): def run_engine(rank, world_size, port):
# init dist env # init dist env
colossalai.launch( colossalai.launch(
config=CONFIG, config=CONFIG,
rank=rank, rank=rank,
world_size=world_size, world_size=world_size,
host='localhost', host='localhost',
port=29911, port=port,
backend='nccl' backend='nccl'
) )
@ -104,7 +101,7 @@ def run_engine(rank, world_size):
@pytest.mark.dist @pytest.mark.dist
def test_engine(): def test_engine():
world_size = 4 world_size = 4
run_func = partial(run_engine, world_size=world_size) run_func = partial(run_engine, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size) mp.spawn(run_func, nprocs=world_size)

View File

@ -1,23 +1,19 @@
import colossalai
import os import os
from functools import partial
from pathlib import Path
import colossalai
import pytest import pytest
import torch import torch
import os.path as osp
from pathlib import Path
import torch.nn as nn
import torch.multiprocessing as mp import torch.multiprocessing as mp
import torch.nn as nn
from torchvision import transforms
from torch.optim import Adam
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.amp import AMP_TYPE
from colossalai.logging import get_dist_logger from colossalai.logging import get_dist_logger
from colossalai.utils import report_memory_usage, get_dataloader from colossalai.utils import free_port, get_dataloader, report_memory_usage
from colossalai.initialize import get_default_parser from torch.optim import Adam
from torchvision.models import resnet18 from torchvision import transforms
from torchvision.datasets import CIFAR10 from torchvision.datasets import CIFAR10
from functools import partial from torchvision.models import resnet18
# Config # Config
BATCH_SIZE = 128 BATCH_SIZE = 128
@ -35,14 +31,14 @@ CONFIG = dict(
) )
def run_engine(rank, world_size): def run_engine(rank, world_size, port):
# init dist env # init dist env
colossalai.launch( colossalai.launch(
config=CONFIG, config=CONFIG,
rank=rank, rank=rank,
world_size=world_size, world_size=world_size,
host='localhost', host='localhost',
port=29912, port=port,
backend='nccl' backend='nccl'
) )
@ -101,7 +97,7 @@ def run_engine(rank, world_size):
@pytest.mark.dist @pytest.mark.dist
def test_engine(): def test_engine():
world_size = 4 world_size = 4
run_func = partial(run_engine, world_size=world_size) run_func = partial(run_engine, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size) mp.spawn(run_func, nprocs=world_size)

View File

@ -1,23 +1,20 @@
import colossalai
import os import os
from functools import partial
from pathlib import Path
import colossalai
import pytest import pytest
import torch import torch
import os.path as osp
from pathlib import Path
import torch.nn as nn
import torch.multiprocessing as mp import torch.multiprocessing as mp
import torch.nn as nn
from torchvision import transforms
from torch.optim import Adam
from colossalai.core import global_context as gpc
from colossalai.amp import AMP_TYPE from colossalai.amp import AMP_TYPE
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger from colossalai.logging import get_dist_logger
from colossalai.utils import report_memory_usage, get_dataloader from colossalai.utils import free_port, get_dataloader, report_memory_usage
from colossalai.initialize import get_default_parser from torch.optim import Adam
from torchvision.models import resnet18 from torchvision import transforms
from torchvision.datasets import CIFAR10 from torchvision.datasets import CIFAR10
from functools import partial from torchvision.models import resnet18
# Config # Config
BATCH_SIZE = 128 BATCH_SIZE = 128
@ -36,14 +33,14 @@ CONFIG = dict(
) )
def run_engine(rank, world_size): def run_engine(rank, world_size, port):
# init dist env # init dist env
colossalai.launch( colossalai.launch(
config=CONFIG, config=CONFIG,
rank=rank, rank=rank,
world_size=world_size, world_size=world_size,
host='localhost', host='localhost',
port=29913, port=port,
backend='nccl' backend='nccl'
) )
@ -102,7 +99,7 @@ def run_engine(rank, world_size):
@pytest.mark.dist @pytest.mark.dist
def test_engine(): def test_engine():
world_size = 4 world_size = 4
run_func = partial(run_engine, world_size=world_size) run_func = partial(run_engine, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size) mp.spawn(run_func, nprocs=world_size)

View File

@ -1,13 +1,15 @@
#!/usr/bin/env python #!/usr/bin/env python
# -*- encoding: utf-8 -*- # -*- encoding: utf-8 -*-
from functools import partial
import pytest import pytest
import torch import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.initialize import launch from colossalai.initialize import launch
from functools import partial from colossalai.utils import free_port
from checks_1d.check_layer_1d import * from checks_1d.check_layer_1d import *
CONFIG = dict( CONFIG = dict(
@ -21,12 +23,12 @@ CONFIG = dict(
) )
def check_layer(rank, world_size): def check_layer(rank, world_size, port):
launch(config=CONFIG, launch(config=CONFIG,
rank=rank, rank=rank,
world_size=world_size, world_size=world_size,
host='localhost', host='localhost',
port=29920, port=port,
backend='nccl') backend='nccl')
check_linear_col() check_linear_col()
@ -39,7 +41,7 @@ def check_layer(rank, world_size):
@pytest.mark.dist @pytest.mark.dist
def test_1d(): def test_1d():
world_size = 4 world_size = 4
run_func = partial(check_layer, world_size=world_size) run_func = partial(check_layer, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size) mp.spawn(run_func, nprocs=world_size)

View File

@ -1,16 +1,17 @@
#!/usr/bin/env python #!/usr/bin/env python
# -*- encoding: utf-8 -*- # -*- encoding: utf-8 -*-
from functools import partial
import pytest import pytest
import torch import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.initialize import launch from colossalai.initialize import launch
from colossalai.utils import free_port
from checks_2d.check_layer_2d import * from checks_2d.check_layer_2d import *
from checks_2d.check_operation_2d import * from checks_2d.check_operation_2d import *
from functools import partial
CONFIG = dict( CONFIG = dict(
parallel=dict( parallel=dict(
@ -34,12 +35,12 @@ def check_layer():
check_layernorm() check_layernorm()
check_classifier() check_classifier()
def check_layer_and_operation(rank, world_size): def check_layer_and_operation(rank, world_size, port):
launch(config=CONFIG, launch(config=CONFIG,
rank=rank, rank=rank,
world_size=world_size, world_size=world_size,
host='localhost', host='localhost',
port=29921, port=port,
backend='nccl') backend='nccl')
# check_operations() # check_operations()
@ -51,7 +52,7 @@ def check_layer_and_operation(rank, world_size):
@pytest.mark.dist @pytest.mark.dist
def test_2d(): def test_2d():
world_size = 4 world_size = 4
run_func = partial(check_layer_and_operation, world_size=world_size) run_func = partial(check_layer_and_operation, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size) mp.spawn(run_func, nprocs=world_size)

View File

@ -1,13 +1,15 @@
from functools import partial
import pytest import pytest
import torch import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.initialize import launch from colossalai.initialize import launch
from checks_2p5d.check_layer_2p5d import check_linear, check_layernorm, check_classifier from colossalai.utils import free_port
from checks_2p5d.check_operation_2p5d import check_AB, check_ABT, check_ATB
from functools import partial
from checks_2p5d.check_layer_2p5d import (check_classifier, check_layernorm,
check_linear)
from checks_2p5d.check_operation_2p5d import check_AB, check_ABT, check_ATB
CONFIG = dict( CONFIG = dict(
parallel=dict( parallel=dict(
@ -29,12 +31,12 @@ def check_layer():
check_classifier() check_classifier()
def check_layer_and_operation(rank, world_size): def check_layer_and_operation(rank, world_size, port):
launch(config=CONFIG, launch(config=CONFIG,
rank=rank, rank=rank,
world_size=world_size, world_size=world_size,
host='localhost', host='localhost',
port=29922, port=port,
backend='nccl') backend='nccl')
check_operations() check_operations()
@ -46,7 +48,7 @@ def check_layer_and_operation(rank, world_size):
@pytest.mark.dist @pytest.mark.dist
def test_2p5d(): def test_2p5d():
world_size = 4 world_size = 4
run_func = partial(check_layer_and_operation, world_size=world_size) run_func = partial(check_layer_and_operation, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size) mp.spawn(run_func, nprocs=world_size)

View File

@ -7,6 +7,7 @@ import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.initialize import launch from colossalai.initialize import launch
from colossalai.utils import free_port
from checks_3d.check_layer_3d import * from checks_3d.check_layer_3d import *
@ -27,8 +28,8 @@ def check_layer():
# check_loss() # check_loss()
def check_layer_and_operation(rank, world_size): def check_layer_and_operation(rank, world_size, port):
launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=29923, backend='nccl') launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
check_layer() check_layer()
gpc.destroy() gpc.destroy()
torch.cuda.empty_cache() torch.cuda.empty_cache()
@ -37,7 +38,7 @@ def check_layer_and_operation(rank, world_size):
@pytest.mark.dist @pytest.mark.dist
def test_3d(): def test_3d():
world_size = 8 world_size = 8
run_func = partial(check_layer_and_operation, world_size=world_size) run_func = partial(check_layer_and_operation, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size) mp.spawn(run_func, nprocs=world_size)

View File

@@ -4,10 +4,11 @@
import pytest
import torch
import torch.multiprocessing as mp
-from colossalai.initialize import launch, get_default_parser
+from colossalai.initialize import launch
from colossalai.logging import get_dist_logger

from checks_seq.check_layer_seq import *
from functools import partial
+from colossalai.utils import free_port


CONFIG = dict(
@@ -22,13 +23,13 @@ def check_layer():
    check_selfattention()


-def run_check_sequence(rank, world_size):
+def run_check_sequence(rank, world_size, port):
    # init dist
    launch(config=CONFIG,
           rank=rank,
           world_size=world_size,
           host='localhost',
-          port=29924,
+          port=port,
           backend='nccl')
    logger = get_dist_logger()
    logger.info('Distributed environment is initialzied.', ranks=[0])
@@ -41,7 +42,7 @@ def run_check_sequence(rank, world_size):
@pytest.mark.dist
def test_sequence():
    world_size = 4
-    run_func = partial(run_check_sequence, world_size=world_size)
+    run_func = partial(run_check_sequence, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)

View File

@@ -1,4 +1,5 @@
import os
+import model
from pathlib import Path

BATCH_SIZE = 128

View File

@@ -1,11 +1,12 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

+from functools import partial

import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from colossalai.communication import (recv_backward, recv_forward,
                                      recv_tensor_meta, send_backward,
                                      send_backward_recv_forward, send_forward,
@@ -15,8 +16,7 @@ from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.initialize import launch
from colossalai.logging import get_dist_logger
-from colossalai.utils import get_current_device
-from functools import partial
+from colossalai.utils import free_port, get_current_device

BATCH_SIZE = 16
SEQ_LENGTH = 64
@@ -123,13 +123,13 @@ def check_comm(size, rank, prev_rank, next_rank, up_group, down_group, logger):
    check_forward_backward(tensor, grad, rank, logger)


-def run_check(rank, world_size):
+def run_check(rank, world_size, port):
    launch(
        config=CONFIG,
        rank=rank,
        world_size=world_size,
        host='localhost',
-        port=29932,
+        port=port,
        backend='nccl'
    )
    logger = get_dist_logger()
@@ -154,7 +154,7 @@ def run_check(rank, world_size):
@pytest.mark.dist
def test_p2p():
    world_size = 4
-    run_func = partial(run_check, world_size=world_size)
+    run_func = partial(run_check, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)

View File

@@ -3,25 +3,24 @@ import os.path as osp
import pytest
import torch
import torch.multiprocessing as mp
-from torch.utils.data import DataLoader

from colossalai.builder.pipeline import build_pipeline_model_from_cfg
from colossalai.core import global_context
from colossalai.initialize import launch
from colossalai.logging import get_dist_logger
from functools import partial
-import model
+from colossalai.utils import free_port

DIR_PATH = osp.dirname(osp.realpath(__file__))
CONFIG_PATH = osp.join(DIR_PATH, 'resnet_config.py')


-def run_partition(rank, world_size):
+def run_partition(rank, world_size, port):
    launch(config=CONFIG_PATH,
           rank=rank,
           world_size=world_size,
           host='localhost',
-          port=29933,
+          port=port,
           backend='nccl'
           )
    logger = get_dist_logger()
@@ -40,7 +39,7 @@ def run_partition(rank, world_size):
@pytest.mark.dist
def test_partition():
    world_size = 4
-    run_func = partial(run_partition, world_size=world_size)
+    run_func = partial(run_partition, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)

View File

@@ -1,26 +1,23 @@
# referenced from Megatron and used to testify communication
-import colossalai
import os
import os.path as osp
+from functools import partial
+from pathlib import Path
+import colossalai
import pytest
import torch
import torch.multiprocessing as mp
-import model
from colossalai.builder import build_pipeline_model_from_cfg
-from colossalai.communication import p2p as p2p_communication
-from colossalai.communication.utils import send_tensor_meta, recv_tensor_meta
-from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
-from colossalai.initialize import launch
-from colossalai.utils import print_rank_0, get_current_device, get_dataloader
from colossalai.engine.schedule import PipelineSchedule
-from torchvision.datasets import CIFAR10
+from colossalai.initialize import launch
+from colossalai.utils import free_port, get_dataloader, print_rank_0
from torchvision import transforms
-from pathlib import Path
-from functools import partial
+from torchvision.datasets import CIFAR10
+import model

BATCH_SIZE = 32
NUM_MICRO = 8
@@ -30,12 +27,12 @@ DIR_PATH = osp.dirname(osp.realpath(__file__))
CONFIG_PATH = osp.join(DIR_PATH, './resnet_config.py')


-def run_schedule(rank, world_size):
+def run_schedule(rank, world_size, port):
    launch(config=CONFIG_PATH,
           rank=rank,
           world_size=world_size,
           host='localhost',
-          port=29934,
+          port=port,
           backend='nccl')

    # build model
@@ -86,7 +83,7 @@ def run_schedule(rank, world_size):
@pytest.mark.dist
def test_pipeline_schedule():
    world_size = 4
-    run_func = partial(run_schedule, world_size=world_size)
+    run_func = partial(run_schedule, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)

View File

@@ -11,7 +11,7 @@ from colossalai.amp.amp_type import AMP_TYPE
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.trainer import Trainer
-from colossalai.utils import MultiTimer, get_dataloader
+from colossalai.utils import MultiTimer, free_port, get_dataloader
from torch.optim import Adam
from torchvision import transforms
from torchvision.datasets import CIFAR10
@@ -26,8 +26,8 @@ CONFIG = dict(
              fp16=dict(mode=AMP_TYPE.TORCH))


-def run_trainer_no_pipeline(rank, world_size):
-    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=29930, backend='nccl')
+def run_trainer_no_pipeline(rank, world_size, port):
+    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    # build model
    model = resnet18(num_classes=10)
@@ -88,7 +88,7 @@ def run_trainer_no_pipeline(rank, world_size):
@pytest.mark.dist
def test_trainer_no_pipeline():
    world_size = 4
-    run_func = partial(run_trainer_no_pipeline, world_size=world_size)
+    run_func = partial(run_trainer_no_pipeline, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)

View File

@@ -12,7 +12,7 @@ from colossalai.core import global_context as gpc
from colossalai.engine.schedule import PipelineSchedule
from colossalai.logging import get_dist_logger
from colossalai.trainer import Trainer
-from colossalai.utils import MultiTimer, get_dataloader
+from colossalai.utils import MultiTimer, free_port, get_dataloader
from torch.optim import Adam
from torchvision import transforms
from torchvision.datasets import CIFAR10
@@ -25,8 +25,8 @@ NUM_EPOCHS = 200
CONFIG = dict(parallel=dict(pipeline=2, ), )


-def run_trainer_with_pipeline(rank, world_size):
-    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=29931, backend='nccl')
+def run_trainer_with_pipeline(rank, world_size, port):
+    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    # build model
    model = resnet18(num_classes=10)
@@ -99,7 +99,7 @@ def run_trainer_with_pipeline(rank, world_size):
@pytest.mark.dist
def test_trainer_with_pipeline():
    world_size = 4
-    run_func = partial(run_trainer_with_pipeline, world_size=world_size)
+    run_func = partial(run_trainer_with_pipeline, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)

View File

@@ -1,21 +1,19 @@
-import colossalai
import os
+from functools import partial
+from pathlib import Path
+import colossalai
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
-from functools import partial
-from pathlib import Path
-from torchvision import transforms
-from torch.optim import Adam
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
-from colossalai.utils import report_memory_usage, get_dataloader
-from colossalai.initialize import get_default_parser
-from torchvision.models import resnet18
+from colossalai.utils import free_port, get_dataloader
+from torch.optim import Adam
+from torchvision import transforms
from torchvision.datasets import CIFAR10
+from torchvision.models import resnet18

# Config
BATCH_SIZE = 16
@@ -32,7 +30,7 @@ CONFIG = dict(
)


-def run_no_pipeline(rank, world_size):
+def run_no_pipeline(rank, world_size, port):

    # init dist env
    colossalai.launch(
@@ -40,7 +38,7 @@ def run_no_pipeline(rank, world_size):
        rank=rank,
        world_size=world_size,
        host='localhost',
-        port=29500,
+        port=port,
        backend='nccl'
    )
@@ -110,7 +108,7 @@ def run_no_pipeline(rank, world_size):
@pytest.mark.dist
def test_engine():
    world_size = 4
-    func = partial(run_no_pipeline, world_size=world_size)
+    func = partial(run_no_pipeline, world_size=world_size, port=free_port())
    mp.spawn(func, nprocs=world_size)

View File

@@ -2,18 +2,18 @@
# -*- encoding: utf-8 -*-
import os
-import pytest
-import torch
-import torch.multiprocessing as mp
+from functools import partial
from pathlib import Path

import colossalai
+import pytest
+import torch
+import torch.multiprocessing as mp
from colossalai.core import global_context as gpc
-from colossalai.utils import get_dataloader
+from colossalai.utils import free_port, get_dataloader
from torchvision import transforms
-from torchvision.models import resnet18
from torchvision.datasets import CIFAR10
-from functools import partial
+from torchvision.models import resnet18

BATCH_SIZE = 16
IMG_SIZE = 224
@@ -34,12 +34,12 @@ CONFIG = dict(
)


-def run_dist(rank, world_size):
+def run_dist(rank, world_size, port):
    colossalai.launch(config=CONFIG,
                      rank=rank,
                      world_size=world_size,
                      host='localhost',
-                      port=29940,
+                      port=port,
                      backend='nccl')

    # build model
@@ -94,7 +94,7 @@ def run_dist(rank, world_size):
@pytest.mark.dist
def test_zero_level_2():
    world_size = 4
-    run_func = partial(run_dist, world_size=world_size)
+    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)

View File

@@ -2,18 +2,18 @@
# -*- encoding: utf-8 -*-
import os
-import pytest
-import torch
-import torch.multiprocessing as mp
+from functools import partial
from pathlib import Path

import colossalai
+import pytest
+import torch
+import torch.multiprocessing as mp
from colossalai.core import global_context as gpc
-from colossalai.utils import get_dataloader
+from colossalai.utils import free_port, get_dataloader
from torchvision import transforms
-from torchvision.models import resnet18
from torchvision.datasets import CIFAR10
-from functools import partial
+from torchvision.models import resnet18

BATCH_SIZE = 16
IMG_SIZE = 224
@@ -46,12 +46,12 @@ CONFIG = dict(
)


-def run_dist(rank, world_size):
+def run_dist(rank, world_size, port):
    colossalai.launch(config=CONFIG,
                      rank=rank,
                      world_size=world_size,
                      host='localhost',
-                      port=29941,
+                      port=port,
                      backend='nccl')

    # build model
@@ -106,7 +106,7 @@ def run_dist(rank, world_size):
@pytest.mark.dist
def test_zero_level_3():
    world_size = 4
-    run_func = partial(run_dist, world_size=world_size)
+    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)

View File

@@ -13,7 +13,7 @@ import torch.multiprocessing as mp
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.nn import CrossEntropyLoss
-from colossalai.utils import get_dataloader
+from colossalai.utils import free_port, get_dataloader
from model_zoo.vit import vit_lite_depth7_patch4_32
from torchvision import transforms
from torchvision.datasets import CIFAR10
@@ -40,11 +40,11 @@ def train_epoch(engine, train_dataloader):
    return avg_loss


-def run_2d_parallel_vision_transformer_level_2(rank, world_size):
-    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=29950, backend='nccl')
+def run_2d_parallel_vision_transformer_level_2(rank, world_size, port):
+    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    # build model
-    model = vit_lite_depth7_patch4_32(tensor_parallel='2d')
+    model = vit_lite_depth7_patch4_32()

    # build dataloader# build dataloaders
    train_dataset = CIFAR10(root=Path(os.environ['DATA']),
@@ -62,7 +62,7 @@ def run_2d_parallel_vision_transformer_level_2(rank, world_size):
    # build optimizer and loss
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
-    criterion = CrossEntropyLoss(tensor_parallel='2d')
+    criterion = CrossEntropyLoss()

    engine, train_dataloader, *args = colossalai.initialize(model=model,
                                                            optimizer=optimizer,
@@ -90,7 +90,7 @@ def run_2d_parallel_vision_transformer_level_2(rank, world_size):
@pytest.mark.dist
def test_2d_vit_zero_level_2():
    world_size = 8
-    run_func = partial(run_2d_parallel_vision_transformer_level_2, world_size=world_size)
+    run_func = partial(run_2d_parallel_vision_transformer_level_2, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)

View File

@@ -13,7 +13,7 @@ import torch.multiprocessing as mp
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.nn import CrossEntropyLoss
-from colossalai.utils import get_dataloader
+from colossalai.utils import free_port, get_dataloader
from model_zoo.vit import vit_lite_depth7_patch4_32
from torchvision import transforms
from torchvision.datasets import CIFAR10
@@ -40,11 +40,11 @@ def train_epoch(engine, train_dataloader):
    return avg_loss


-def run_2d_parallel_vision_transformer_level_3(rank, world_size):
-    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=29951, backend='nccl')
+def run_2d_parallel_vision_transformer_level_3(rank, world_size, port):
+    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    # build model
-    model = vit_lite_depth7_patch4_32(tensor_parallel='2d')
+    model = vit_lite_depth7_patch4_32()

    # build dataloader# build dataloaders
    train_dataset = CIFAR10(root=Path(os.environ['DATA']),
@@ -62,7 +62,7 @@ def run_2d_parallel_vision_transformer_level_3(rank, world_size):
    # build optimizer and loss
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
-    criterion = CrossEntropyLoss(tensor_parallel='2d')
+    criterion = CrossEntropyLoss()

    engine, train_dataloader, *args = colossalai.initialize(model=model,
                                                            optimizer=optimizer,
@@ -91,7 +91,7 @@ def run_2d_parallel_vision_transformer_level_3(rank, world_size):
@pytest.mark.skip("Level 3 has unknown bug so skip this test for now")
def test_3d_vit_zero_level_3():
    world_size = 8
-    run_func = partial(run_2d_parallel_vision_transformer_level_3, world_size=world_size)
+    run_func = partial(run_2d_parallel_vision_transformer_level_3, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)
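The other change that recurs in the ViT/ZeRO tests is the layer-API cleanup from this PR: `vit_lite_depth7_patch4_32()` and `CrossEntropyLoss()` no longer take a `tensor_parallel=...` argument, because the tensor-parallel mode is read from the parallel settings handed to `colossalai.launch`. A rough sketch of a call site after the change, with an illustrative config (the exact CONFIG contents of these tests are not visible in this hunk), looks like:

```python
# Illustrative only: the parallel mode lives in the launch config, not in the
# layer/loss constructors. The CONFIG values below are assumptions, not taken
# from this diff.
import torch
from colossalai.nn import CrossEntropyLoss
from model_zoo.vit import vit_lite_depth7_patch4_32

CONFIG = dict(parallel=dict(pipeline=1, tensor=dict(size=4, mode='2d')))


def build_training_objects():
    # No tensor_parallel='2d' arguments any more; the parallel context created
    # by colossalai.launch(config=CONFIG, ...) decides how modules are split.
    model = vit_lite_depth7_patch4_32()
    criterion = CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    return model, criterion, optimizer
```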