mirror of https://github.com/hpcaitech/ColossalAI
Hotfix/Colossalai layers (#92)
* optimized 1d layer apis; reorganized nn.layer modules; fixed tests * fixed 2.5d runtime issue * reworked split batch, now called in trainer.schedule.load_batch Co-authored-by: BoxiangW <45734921+BoxiangW@users.noreply.github.com>pull/95/head
parent
0fedef4f3c
commit
01a80cd86d
|
@ -2,7 +2,7 @@ BATCH_SIZE = 512
|
|||
LEARNING_RATE = 2e-3
|
||||
WEIGHT_DECAY = 3e-2
|
||||
|
||||
TENSOR_PARALLEL_SIZE = 4
|
||||
TENSOR_PARALLEL_SIZE = 2
|
||||
TENSOR_PARALLEL_MODE = '1d'
|
||||
|
||||
NUM_EPOCHS = 200
|
||||
|
|
|
@ -72,13 +72,11 @@ def train_cifar():
|
|||
os.mkdir(log_path)
|
||||
logger.log_to_file(log_path)
|
||||
|
||||
tp = gpc.config.parallel.tensor.mode
|
||||
|
||||
model = vit_lite_depth7_patch4_32(tensor_parallel=tp)
|
||||
model = vit_lite_depth7_patch4_32()
|
||||
|
||||
train_dataloader, test_dataloader = build_cifar(gpc.config.BATCH_SIZE // gpc.data_parallel_size)
|
||||
|
||||
criterion = CrossEntropyLoss(label_smoothing=0.1, tensor_parallel=tp)
|
||||
criterion = CrossEntropyLoss(label_smoothing=0.1)
|
||||
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY)
|
||||
|
||||
|
@ -107,7 +105,7 @@ def train_cifar():
|
|||
LogMetricByStepHook(),
|
||||
# LogTimingByEpochHook(timer=timer, logger=logger),
|
||||
# LogMemoryByEpochHook(logger=logger),
|
||||
AccuracyHook(accuracy_func=Accuracy(tensor_parallel=tp)),
|
||||
AccuracyHook(accuracy_func=Accuracy()),
|
||||
LossHook(),
|
||||
ThroughputHook(),
|
||||
LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=False)
|
||||
|
|
|
@ -4,7 +4,7 @@ TOTAL_BATCH_SIZE = 4096
|
|||
LEARNING_RATE = 3e-3
|
||||
WEIGHT_DECAY = 0.3
|
||||
|
||||
TENSOR_PARALLEL_SIZE = 4
|
||||
TENSOR_PARALLEL_SIZE = 2
|
||||
TENSOR_PARALLEL_MODE = '1d'
|
||||
|
||||
NUM_EPOCHS = 300
|
||||
|
|
|
@ -159,14 +159,12 @@ def train_imagenet():
|
|||
os.mkdir(log_path)
|
||||
logger.log_to_file(log_path)
|
||||
|
||||
tp = gpc.config.parallel.tensor.mode
|
||||
|
||||
model = vit_small_patch16_224(tensor_parallel=tp, num_classes=100, init_method='jax')
|
||||
model = vit_small_patch16_224(num_classes=100, init_method='jax')
|
||||
|
||||
train_dataloader = build_dali_train(gpc.config.BATCH_SIZE // gpc.data_parallel_size)
|
||||
test_dataloader = build_dali_test(gpc.config.BATCH_SIZE // gpc.data_parallel_size)
|
||||
|
||||
criterion = CrossEntropyLoss(label_smoothing=0.1, tensor_parallel=tp)
|
||||
criterion = CrossEntropyLoss(label_smoothing=0.1)
|
||||
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY)
|
||||
|
||||
|
@ -192,7 +190,7 @@ def train_imagenet():
|
|||
LogMetricByStepHook(),
|
||||
# LogTimingByEpochHook(timer=timer, logger=logger),
|
||||
# LogMemoryByEpochHook(logger=logger),
|
||||
AccuracyHook(accuracy_func=Accuracy(tensor_parallel=tp)),
|
||||
AccuracyHook(accuracy_func=Accuracy()),
|
||||
LossHook(),
|
||||
ThroughputHook(),
|
||||
LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=True)
|
||||
|
|
|
@ -4,7 +4,7 @@ TOTAL_BATCH_SIZE = 4096
|
|||
LEARNING_RATE = 3e-3
|
||||
WEIGHT_DECAY = 0.3
|
||||
|
||||
TENSOR_PARALLEL_SIZE = 4
|
||||
TENSOR_PARALLEL_SIZE = 2
|
||||
TENSOR_PARALLEL_MODE = '1d'
|
||||
|
||||
NUM_EPOCHS = 300
|
||||
|
|
|
@ -159,14 +159,12 @@ def train_imagenet():
|
|||
os.mkdir(log_path)
|
||||
logger.log_to_file(log_path)
|
||||
|
||||
tp = gpc.config.parallel.tensor.mode
|
||||
|
||||
model = vit_small_patch16_224(tensor_parallel=tp, num_classes=1000, init_method='jax')
|
||||
model = vit_small_patch16_224(num_classes=1000, init_method='jax')
|
||||
|
||||
train_dataloader = build_dali_train(gpc.config.BATCH_SIZE // gpc.data_parallel_size)
|
||||
test_dataloader = build_dali_test(gpc.config.BATCH_SIZE // gpc.data_parallel_size)
|
||||
|
||||
criterion = CrossEntropyLoss(label_smoothing=0.1, tensor_parallel=tp)
|
||||
criterion = CrossEntropyLoss(label_smoothing=0.1)
|
||||
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY)
|
||||
|
||||
|
@ -192,7 +190,7 @@ def train_imagenet():
|
|||
LogMetricByStepHook(),
|
||||
# LogTimingByEpochHook(timer=timer, logger=logger),
|
||||
# LogMemoryByEpochHook(logger=logger),
|
||||
AccuracyHook(accuracy_func=Accuracy(tensor_parallel=tp)),
|
||||
AccuracyHook(accuracy_func=Accuracy()),
|
||||
LossHook(),
|
||||
ThroughputHook(),
|
||||
LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=True)
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
# -*- encoding: utf-8 -*-
|
||||
|
||||
ALLOWED_MODES = [None, '1d', '2d', '2.5d', '3d', 'sequence']
|
||||
TENSOR_PARALLEL_MODE = 'tensor_parallel_mode'
|
||||
|
||||
# intializer
|
||||
INITIALIZER_MAPPING = {
|
||||
|
@ -16,6 +17,9 @@ INITIALIZER_MAPPING = {
|
|||
'sequence': 'Initializer_Sequence'
|
||||
}
|
||||
|
||||
# 1D parallel
|
||||
PARALLEL_INPUT_1D = 'parallel_input_1d'
|
||||
|
||||
# 2D paralllel
|
||||
SUMMA_DIM = 'SUMMA_DIM'
|
||||
|
||||
|
|
|
@ -1,17 +1,18 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import os
|
||||
import random
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from colossalai.constants import ALLOWED_MODES, INITIALIZER_MAPPING
|
||||
from colossalai.constants import ALLOWED_MODES, INITIALIZER_MAPPING, TENSOR_PARALLEL_MODE
|
||||
from colossalai.context.config import Config
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.registry import DIST_GROUP_INITIALIZER
|
||||
|
||||
from .parallel_mode import ParallelMode
|
||||
from .random import add_seed, get_seeds, set_mode
|
||||
|
||||
|
@ -386,6 +387,7 @@ class ParallelContext:
|
|||
if parallel_config is not None and 'tensor' in parallel_config and 'mode' in parallel_config['tensor']:
|
||||
tensor_parallel_mode = parallel_config['tensor']['mode']
|
||||
assert tensor_parallel_mode in ALLOWED_MODES, f"mode in the parallel config must be set to one of {ALLOWED_MODES}"
|
||||
os.environ[TENSOR_PARALLEL_MODE] = str(tensor_parallel_mode)
|
||||
self.check_sanity()
|
||||
|
||||
pg_init = []
|
||||
|
|
|
@ -1,12 +1,13 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import os
|
||||
import torch.distributed as dist
|
||||
|
||||
from colossalai.context import Config
|
||||
from colossalai.registry import DIST_GROUP_INITIALIZER
|
||||
from .process_group_initializer import ProcessGroupInitializer
|
||||
from ..parallel_mode import ParallelMode
|
||||
from colossalai.constants import PARALLEL_INPUT_1D
|
||||
|
||||
|
||||
@DIST_GROUP_INITIALIZER.register_module
|
||||
|
@ -29,6 +30,7 @@ class Initializer_1D(ProcessGroupInitializer):
|
|||
process_group = None
|
||||
group_world_size = None
|
||||
mode = ParallelMode.PARALLEL_1D
|
||||
os.environ[PARALLEL_INPUT_1D] = ''
|
||||
|
||||
for i in range(self.num_group):
|
||||
ranks = [i * self.tensor_parallel_size + j for j in range(self.tensor_parallel_size)]
|
||||
|
|
|
@ -10,7 +10,7 @@ from typing import Iterable, Union, List, Callable
|
|||
from .._base_engine import Engine
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
from colossalai.nn.layer import split_batch
|
||||
|
||||
class BaseSchedule(ABC):
|
||||
"""A basic helper class to control the process of training or evaluation.
|
||||
|
@ -59,7 +59,11 @@ class BaseSchedule(ABC):
|
|||
else:
|
||||
data, label = batch_data
|
||||
|
||||
data, label = self._to_list(data), self._to_list(label)
|
||||
if isinstance(label, (tuple, list)):
|
||||
self.batch_size = label[0].size(0)
|
||||
else:
|
||||
self.batch_size = label.size(0)
|
||||
data, label = self._to_list(split_batch(data)), self._to_list(split_batch(label))
|
||||
return self._move_to_device(data), self._move_to_device(label)
|
||||
|
||||
def pre_processing(self, engine: Engine):
|
||||
|
|
|
@ -1,3 +1,9 @@
|
|||
from .colossalai_layer import *
|
||||
from .fused_bias_gelu import bias_gelu_impl
|
||||
from .parallel_1d import *
|
||||
from .parallel_2d import *
|
||||
from .parallel_2p5d import *
|
||||
from .parallel_3d import *
|
||||
from .parallel_sequence import *
|
||||
from .utils import *
|
||||
from .vanilla import *
|
||||
from .wrapper import *
|
||||
|
|
|
@ -1,231 +0,0 @@
|
|||
import math
|
||||
from typing import Callable, Optional
|
||||
|
||||
from colossalai.utils import get_current_device
|
||||
from torch import dtype, nn
|
||||
from torch.nn.modules.activation import *
|
||||
from torch.nn.modules.adaptive import *
|
||||
from torch.nn.modules.batchnorm import *
|
||||
from torch.nn.modules.channelshuffle import *
|
||||
from torch.nn.modules.conv import *
|
||||
from torch.nn.modules.distance import *
|
||||
from torch.nn.modules.dropout import *
|
||||
from torch.nn.modules.flatten import *
|
||||
from torch.nn.modules.fold import *
|
||||
from torch.nn.modules.instancenorm import *
|
||||
from torch.nn.modules.linear import *
|
||||
from torch.nn.modules.normalization import *
|
||||
from torch.nn.modules.padding import *
|
||||
from torch.nn.modules.pixelshuffle import *
|
||||
from torch.nn.modules.pooling import *
|
||||
from torch.nn.modules.rnn import *
|
||||
from torch.nn.modules.sparse import *
|
||||
from torch.nn.modules.transformer import *
|
||||
from torch.nn.modules.upsampling import *
|
||||
|
||||
from .. import init as init
|
||||
|
||||
from .vanilla import *
|
||||
from .parallel_1d import *
|
||||
from .parallel_2d import *
|
||||
from .parallel_2p5d import *
|
||||
from .parallel_3d import *
|
||||
from .parallel_sequence import *
|
||||
|
||||
_parallel_linear = {'1d_col': Linear1D_Col, '1d_row': Linear1D_Row, '2d': Linear2D, '2.5d': Linear2p5D, '3d': Linear3D}
|
||||
|
||||
_parallel_classifier = {
|
||||
None: VanillaClassifier,
|
||||
'1d': VanillaClassifier,
|
||||
'2d': Classifier2D,
|
||||
'2.5d': Classifier2p5D,
|
||||
'3d': Classifier3D
|
||||
}
|
||||
|
||||
_parallel_layernorm = {'2d': LayerNorm2D, '2.5d': LayerNorm2p5D, '3d': LayerNorm3D}
|
||||
|
||||
_parallel_embedding = {'3d': Embedding3D}
|
||||
|
||||
_parallel_patchembedding = {
|
||||
None: VanillaPatchEmbedding,
|
||||
'1d': VanillaPatchEmbedding,
|
||||
'2d': PatchEmbedding2D,
|
||||
'2.5d': PatchEmbedding2p5D,
|
||||
'3d': PatchEmbedding3D
|
||||
}
|
||||
|
||||
|
||||
class Linear(nn.Module):
|
||||
def __init__(self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
bias: bool = True,
|
||||
dtype: dtype = None,
|
||||
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
|
||||
tensor_parallel: Optional[str] = None,
|
||||
**kwargs) -> None:
|
||||
super().__init__()
|
||||
if tensor_parallel is None:
|
||||
self.layer = nn.Linear(in_features, out_features, bias=bias, device=get_current_device(), dtype=dtype)
|
||||
weight_initializer(self.layer.weight, fan_in=in_features, fan_out=out_features)
|
||||
if bias:
|
||||
bias_initializer(self.layer.bias, fan_in=in_features)
|
||||
else:
|
||||
self.layer = _parallel_linear[tensor_parallel](
|
||||
in_features,
|
||||
out_features,
|
||||
bias=bias,
|
||||
dtype=dtype,
|
||||
weight_initializer=weight_initializer,
|
||||
bias_initializer=bias_initializer,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@property
|
||||
def weight(self):
|
||||
return self.layer.weight
|
||||
|
||||
@property
|
||||
def bias(self):
|
||||
return self.layer.bias
|
||||
|
||||
def forward(self, *args):
|
||||
return self.layer(*args)
|
||||
|
||||
|
||||
class LayerNorm(nn.Module):
|
||||
def __init__(self, normalized_shape: int, eps=1e-05, dtype=None, tensor_parallel: Optional[str] = None) -> None:
|
||||
super().__init__()
|
||||
if tensor_parallel in [None, '1d']:
|
||||
self.norm = nn.LayerNorm(normalized_shape, eps=eps, device=get_current_device(), dtype=dtype)
|
||||
else:
|
||||
self.norm = _parallel_layernorm[tensor_parallel](normalized_shape, eps=eps, dtype=dtype)
|
||||
|
||||
@property
|
||||
def weight(self):
|
||||
return self.norm.weight
|
||||
|
||||
@property
|
||||
def bias(self):
|
||||
return self.norm.bias
|
||||
|
||||
def forward(self, *args):
|
||||
return self.norm(*args)
|
||||
|
||||
|
||||
class Embedding(nn.Module):
|
||||
def __init__(self,
|
||||
num_embeddings: int,
|
||||
embedding_dim: int,
|
||||
padding_idx: int = None,
|
||||
dtype: dtype = None,
|
||||
weight_initializer: Callable = init.normal_(),
|
||||
tensor_parallel: Optional[str] = None,
|
||||
*args,
|
||||
**kwargs) -> None:
|
||||
super().__init__()
|
||||
if tensor_parallel in [None, '1d']:
|
||||
self.embed = nn.Embedding(num_embeddings,
|
||||
embedding_dim,
|
||||
padding_idx=padding_idx,
|
||||
device=get_current_device(),
|
||||
dtype=dtype,
|
||||
*args,
|
||||
**kwargs)
|
||||
weight_initializer(self.embed.weight, fan_in=num_embeddings, fan_out=embedding_dim)
|
||||
else:
|
||||
self.embed = _parallel_embedding[tensor_parallel](
|
||||
num_embeddings,
|
||||
embedding_dim,
|
||||
padding_idx=padding_idx,
|
||||
dtype=dtype,
|
||||
weight_initializer=weight_initializer,
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@property
|
||||
def weight(self):
|
||||
return self.embed.weight
|
||||
|
||||
def forward(self, *args):
|
||||
return self.embed(*args)
|
||||
|
||||
|
||||
class PatchEmbedding(nn.Module):
|
||||
def __init__(self,
|
||||
img_size: int,
|
||||
patch_size: int,
|
||||
in_chans: int,
|
||||
embed_size: int,
|
||||
dtype: dtype = None,
|
||||
flatten: bool = True,
|
||||
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
|
||||
position_embed_initializer: Callable = init.zeros_(),
|
||||
tensor_parallel: Optional[str] = None) -> None:
|
||||
super().__init__()
|
||||
self.embed = _parallel_patchembedding[tensor_parallel](
|
||||
img_size,
|
||||
patch_size,
|
||||
in_chans,
|
||||
embed_size,
|
||||
dtype=dtype,
|
||||
flatten=flatten,
|
||||
weight_initializer=weight_initializer,
|
||||
bias_initializer=bias_initializer,
|
||||
position_embed_initializer=position_embed_initializer,
|
||||
)
|
||||
|
||||
@property
|
||||
def weight(self):
|
||||
return self.embed.weight
|
||||
|
||||
@property
|
||||
def bias(self):
|
||||
return self.embed.bias
|
||||
|
||||
@property
|
||||
def pos_embed(self):
|
||||
return self.embed.pos_embed
|
||||
|
||||
@property
|
||||
def cls_token(self):
|
||||
return self.embed.cls_token
|
||||
|
||||
def forward(self, *args):
|
||||
return self.embed(*args)
|
||||
|
||||
|
||||
class Classifier(nn.Module):
|
||||
def __init__(self,
|
||||
in_features: int,
|
||||
num_classes: int,
|
||||
weight: nn.Parameter = None,
|
||||
bias: bool = True,
|
||||
dtype: dtype = None,
|
||||
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
|
||||
tensor_parallel: Optional[str] = None) -> None:
|
||||
super().__init__()
|
||||
self.layer = _parallel_classifier[tensor_parallel](
|
||||
in_features,
|
||||
num_classes,
|
||||
weight=weight,
|
||||
bias=bias,
|
||||
dtype=dtype,
|
||||
weight_initializer=weight_initializer,
|
||||
bias_initializer=bias_initializer,
|
||||
)
|
||||
|
||||
@property
|
||||
def weight(self):
|
||||
return self.layer.weight
|
||||
|
||||
@property
|
||||
def bias(self):
|
||||
return self.layer.bias
|
||||
|
||||
def forward(self, *args):
|
||||
return self.layer(*args)
|
|
@ -0,0 +1,7 @@
|
|||
from ._utils import split_batch
|
||||
from .dropout import Dropout
|
||||
from .embedding import Embedding, PatchEmbedding
|
||||
from .linear import Classifier, Linear
|
||||
from .normalization import LayerNorm
|
||||
|
||||
__all__ = ['Linear', 'Classifier', 'Embedding', 'PatchEmbedding', 'LayerNorm', 'Dropout', 'split_batch']
|
|
@ -0,0 +1,19 @@
|
|||
from torch import Tensor
|
||||
|
||||
from ..parallel_2d._operation import split_tensor_2d
|
||||
from ..parallel_2p5d._operation import split_tensor_2p5d
|
||||
from ..parallel_3d._operation import split_tensor_3d
|
||||
from ..utils import get_tensor_parallel_mode
|
||||
|
||||
_parallel_split_batch = {'2d': split_tensor_2d, '2.5d': split_tensor_2p5d, '3d': split_tensor_3d}
|
||||
|
||||
|
||||
def split_batch(input_) -> Tensor:
|
||||
tensor_parallel_mode = get_tensor_parallel_mode()
|
||||
if tensor_parallel_mode in _parallel_split_batch:
|
||||
if isinstance(input_, (tuple, list)):
|
||||
return tuple(map(_parallel_split_batch[tensor_parallel_mode], input_))
|
||||
else:
|
||||
return _parallel_split_batch[tensor_parallel_mode](input_)
|
||||
else:
|
||||
return input_
|
|
@ -0,0 +1,23 @@
|
|||
from contextlib import nullcontext
|
||||
|
||||
import torch.nn as nn
|
||||
from colossalai.context import ParallelMode, seed
|
||||
from colossalai.utils import conditional_context
|
||||
|
||||
from ..parallel_1d import *
|
||||
from ..utils import get_tensor_parallel_mode
|
||||
|
||||
|
||||
class Dropout(nn.Module):
|
||||
def __init__(self, p: float = 0.5, inplace: bool = False) -> None:
|
||||
super().__init__()
|
||||
self.tensor_parallel = get_tensor_parallel_mode()
|
||||
if self.tensor_parallel == '1d':
|
||||
self.drop = Dropout1D(p, inplace)
|
||||
else:
|
||||
self.drop = nn.Dropout(p, inplace)
|
||||
|
||||
def forward(self, *args):
|
||||
cm = nullcontext() if self.tensor_parallel in ['None', '1d'] else seed(ParallelMode.TENSOR)
|
||||
with cm:
|
||||
return self.drop(*args)
|
|
@ -0,0 +1,107 @@
|
|||
import math
|
||||
from typing import Callable, Optional
|
||||
|
||||
from colossalai.utils import get_current_device
|
||||
from torch import dtype, nn
|
||||
|
||||
from ... import init as init
|
||||
from ..parallel_1d import *
|
||||
from ..parallel_2d import *
|
||||
from ..parallel_2p5d import *
|
||||
from ..parallel_3d import *
|
||||
from ..utils import get_tensor_parallel_mode
|
||||
from ..vanilla import *
|
||||
|
||||
_parallel_embedding = {'1d': Embedding1D, '2d': Embedding2D, '2.5d': Embedding2p5D, '3d': Embedding3D}
|
||||
|
||||
_parallel_patchembedding = {
|
||||
'None': VanillaPatchEmbedding,
|
||||
'1d': VanillaPatchEmbedding,
|
||||
'2d': PatchEmbedding2D,
|
||||
'2.5d': PatchEmbedding2p5D,
|
||||
'3d': PatchEmbedding3D
|
||||
}
|
||||
|
||||
|
||||
class Embedding(nn.Module):
|
||||
def __init__(self,
|
||||
num_embeddings: int,
|
||||
embedding_dim: int,
|
||||
padding_idx: int = None,
|
||||
dtype: dtype = None,
|
||||
weight_initializer: Callable = init.normal_(),
|
||||
*args,
|
||||
**kwargs) -> None:
|
||||
super().__init__()
|
||||
tensor_parallel = get_tensor_parallel_mode()
|
||||
if tensor_parallel == 'None':
|
||||
self.embed = nn.Embedding(num_embeddings,
|
||||
embedding_dim,
|
||||
padding_idx=padding_idx,
|
||||
device=get_current_device(),
|
||||
dtype=dtype,
|
||||
*args,
|
||||
**kwargs)
|
||||
weight_initializer(self.embed.weight, fan_in=num_embeddings, fan_out=embedding_dim)
|
||||
else:
|
||||
self.embed = _parallel_embedding[tensor_parallel](
|
||||
num_embeddings,
|
||||
embedding_dim,
|
||||
padding_idx=padding_idx,
|
||||
dtype=dtype,
|
||||
weight_initializer=weight_initializer,
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@property
|
||||
def weight(self):
|
||||
return self.embed.weight
|
||||
|
||||
def forward(self, *args):
|
||||
return self.embed(*args)
|
||||
|
||||
|
||||
class PatchEmbedding(nn.Module):
|
||||
def __init__(self,
|
||||
img_size: int,
|
||||
patch_size: int,
|
||||
in_chans: int,
|
||||
embed_size: int,
|
||||
dtype: dtype = None,
|
||||
flatten: bool = True,
|
||||
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
|
||||
position_embed_initializer: Callable = init.zeros_()) -> None:
|
||||
super().__init__()
|
||||
tensor_parallel = get_tensor_parallel_mode()
|
||||
self.embed = _parallel_patchembedding[tensor_parallel](
|
||||
img_size,
|
||||
patch_size,
|
||||
in_chans,
|
||||
embed_size,
|
||||
dtype=dtype,
|
||||
flatten=flatten,
|
||||
weight_initializer=weight_initializer,
|
||||
bias_initializer=bias_initializer,
|
||||
position_embed_initializer=position_embed_initializer,
|
||||
)
|
||||
|
||||
@property
|
||||
def weight(self):
|
||||
return self.embed.weight
|
||||
|
||||
@property
|
||||
def bias(self):
|
||||
return self.embed.bias
|
||||
|
||||
@property
|
||||
def pos_embed(self):
|
||||
return self.embed.pos_embed
|
||||
|
||||
@property
|
||||
def cls_token(self):
|
||||
return self.embed.cls_token
|
||||
|
||||
def forward(self, *args):
|
||||
return self.embed(*args)
|
|
@ -0,0 +1,97 @@
|
|||
import math
|
||||
from typing import Callable, Optional
|
||||
|
||||
from colossalai.nn.layer.parallel_1d.layers import Classifier1D
|
||||
from colossalai.utils import get_current_device
|
||||
from torch import dtype, nn
|
||||
|
||||
from ... import init as init
|
||||
from ..parallel_1d import *
|
||||
from ..parallel_2d import *
|
||||
from ..parallel_2p5d import *
|
||||
from ..parallel_3d import *
|
||||
from ..utils import get_tensor_parallel_mode
|
||||
from ..vanilla import *
|
||||
|
||||
_parallel_linear = {'1d': Linear1D, '2d': Linear2D, '2.5d': Linear2p5D, '3d': Linear3D}
|
||||
|
||||
_parallel_classifier = {
|
||||
'None': VanillaClassifier,
|
||||
'1d': Classifier1D,
|
||||
'2d': Classifier2D,
|
||||
'2.5d': Classifier2p5D,
|
||||
'3d': Classifier3D
|
||||
}
|
||||
|
||||
|
||||
class Linear(nn.Module):
|
||||
def __init__(self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
bias: bool = True,
|
||||
dtype: dtype = None,
|
||||
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
|
||||
**kwargs) -> None:
|
||||
super().__init__()
|
||||
tensor_parallel = get_tensor_parallel_mode()
|
||||
if tensor_parallel == 'None':
|
||||
self.layer = nn.Linear(in_features, out_features, bias=bias, device=get_current_device(), dtype=dtype)
|
||||
weight_initializer(self.layer.weight, fan_in=in_features, fan_out=out_features)
|
||||
if bias:
|
||||
bias_initializer(self.layer.bias, fan_in=in_features)
|
||||
else:
|
||||
self.layer = _parallel_linear[tensor_parallel](
|
||||
in_features,
|
||||
out_features,
|
||||
bias=bias,
|
||||
dtype=dtype,
|
||||
weight_initializer=weight_initializer,
|
||||
bias_initializer=bias_initializer,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@property
|
||||
def weight(self):
|
||||
return self.layer.weight
|
||||
|
||||
@property
|
||||
def bias(self):
|
||||
return self.layer.bias
|
||||
|
||||
def forward(self, *args):
|
||||
return self.layer(*args)
|
||||
|
||||
|
||||
class Classifier(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
num_classes: int,
|
||||
weight: nn.Parameter = None,
|
||||
bias: bool = True,
|
||||
dtype: dtype = None,
|
||||
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.layer = _parallel_classifier[get_tensor_parallel_mode()](
|
||||
in_features,
|
||||
num_classes,
|
||||
weight=weight,
|
||||
bias=bias,
|
||||
dtype=dtype,
|
||||
weight_initializer=weight_initializer,
|
||||
bias_initializer=bias_initializer,
|
||||
)
|
||||
|
||||
@property
|
||||
def weight(self):
|
||||
return self.layer.weight
|
||||
|
||||
@property
|
||||
def bias(self):
|
||||
return self.layer.bias
|
||||
|
||||
def forward(self, *args):
|
||||
return self.layer(*args)
|
|
@ -0,0 +1,35 @@
|
|||
from typing import Optional
|
||||
|
||||
from colossalai.utils import get_current_device
|
||||
from torch import nn
|
||||
|
||||
from ... import init as init
|
||||
from ..parallel_1d import *
|
||||
from ..parallel_2d import *
|
||||
from ..parallel_2p5d import *
|
||||
from ..parallel_3d import *
|
||||
from ..utils import get_tensor_parallel_mode
|
||||
from ..vanilla import *
|
||||
|
||||
_parallel_layernorm = {'2d': LayerNorm2D, '2.5d': LayerNorm2p5D, '3d': LayerNorm3D}
|
||||
|
||||
|
||||
class LayerNorm(nn.Module):
|
||||
def __init__(self, normalized_shape: int, eps=1e-05, dtype=None) -> None:
|
||||
super().__init__()
|
||||
tensor_parallel = get_tensor_parallel_mode()
|
||||
if tensor_parallel in ['None', '1d']:
|
||||
self.norm = nn.LayerNorm(normalized_shape, eps=eps, device=get_current_device(), dtype=dtype)
|
||||
else:
|
||||
self.norm = _parallel_layernorm[tensor_parallel](normalized_shape, eps=eps, dtype=dtype)
|
||||
|
||||
@property
|
||||
def weight(self):
|
||||
return self.norm.weight
|
||||
|
||||
@property
|
||||
def bias(self):
|
||||
return self.norm.bias
|
||||
|
||||
def forward(self, *args):
|
||||
return self.norm(*args)
|
|
@ -1,35 +0,0 @@
|
|||
# adapted from Megatron-LM
|
||||
# https://github.com/NVIDIA/Megatron-LM/blob/b31e1296354e979722627a6c4dedafe19b51fa97/megatron/model/fused_bias_gelu.py
|
||||
|
||||
import torch
|
||||
|
||||
@torch.jit.script
|
||||
def bias_gelu(bias, y):
|
||||
x = bias + y
|
||||
return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
|
||||
|
||||
# gradient of tanh approximation of gelu
|
||||
# gradient of actual gelu is:
|
||||
# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
|
||||
@torch.jit.script
|
||||
def bias_gelu_back(g, bias, y):
|
||||
x = bias + y
|
||||
tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
|
||||
# sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
|
||||
ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
|
||||
return ff*g
|
||||
|
||||
class GeLUFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
# bias is an optional argument
|
||||
def forward(ctx, input, bias):
|
||||
ctx.save_for_backward(input, bias)
|
||||
return bias_gelu(bias, input)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
input, bias = ctx.saved_tensors
|
||||
tmp = bias_gelu_back(grad_output, bias, input)
|
||||
return tmp, tmp
|
||||
|
||||
bias_gelu_impl = GeLUFunction.apply
|
|
@ -1,4 +1,4 @@
|
|||
from .layers import Linear1D_Col, Linear1D_Row
|
||||
from .layers import Dropout1D, Embedding1D, Linear1D, Linear1D_Col, Linear1D_Row
|
||||
from .layers import MixedFusedLayerNorm1D as LayerNorm1D
|
||||
|
||||
__all__ = ['Linear1D_Col', 'Linear1D_Row', 'LayerNorm1D']
|
||||
__all__ = ['Linear1D', 'Linear1D_Col', 'Linear1D_Row', 'LayerNorm1D', 'Embedding1D', 'Dropout1D']
|
||||
|
|
|
@ -1,12 +1,21 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
import os
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from colossalai.constants import PARALLEL_INPUT_1D
|
||||
from colossalai.core import global_context as gpc
|
||||
|
||||
from .._common_utils import divide
|
||||
from ..utils import divide
|
||||
|
||||
|
||||
def set_parallel_input(input_parallel: bool):
|
||||
os.environ[PARALLEL_INPUT_1D] = 'true' if input_parallel else ''
|
||||
|
||||
|
||||
def get_parallel_input():
|
||||
return bool(os.environ[PARALLEL_INPUT_1D])
|
||||
|
||||
|
||||
def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank):
|
||||
|
|
|
@ -3,10 +3,10 @@
|
|||
|
||||
import math
|
||||
import numbers
|
||||
from contextlib import nullcontext
|
||||
from typing import Callable, Tuple
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn.functional as F
|
||||
from colossalai.communication import broadcast
|
||||
from colossalai.context import ParallelMode, seed
|
||||
|
@ -14,13 +14,122 @@ from colossalai.core import global_context as gpc
|
|||
from colossalai.nn import init as init
|
||||
from colossalai.registry import LAYERS
|
||||
from colossalai.utils import get_current_device
|
||||
from torch import Tensor
|
||||
from torch import Tensor, dtype
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from .._common_utils import divide, set_tensor_parallel_attribute_by_partition
|
||||
from ..base_layer import ParallelLayer
|
||||
from ..utils import divide, set_tensor_parallel_attribute_by_partition
|
||||
from ._operation import FusedLayerNormAffineFunction1D
|
||||
from ._utils import (gather_forward_split_backward, reduce_grad, reduce_input, split_forward_gather_backward)
|
||||
from ._utils import (gather_forward_split_backward, get_parallel_input, reduce_grad, reduce_input, set_parallel_input,
|
||||
split_forward_gather_backward)
|
||||
|
||||
|
||||
@LAYERS.register_module
|
||||
class Linear1D(torch.nn.Module):
|
||||
def __init__(self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
bias: bool = True,
|
||||
dtype: torch.dtype = None,
|
||||
gather_output: bool = False,
|
||||
skip_bias_add: bool = False,
|
||||
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
|
||||
super().__init__()
|
||||
parallel_input = get_parallel_input()
|
||||
if not parallel_input:
|
||||
self.layer = Linear1D_Col(in_features,
|
||||
out_features,
|
||||
bias=bias,
|
||||
dtype=dtype,
|
||||
gather_output=gather_output,
|
||||
skip_bias_add=skip_bias_add,
|
||||
weight_initializer=weight_initializer,
|
||||
bias_initializer=bias_initializer)
|
||||
else:
|
||||
self.layer = Linear1D_Row(in_features,
|
||||
out_features,
|
||||
bias=bias,
|
||||
dtype=dtype,
|
||||
parallel_input=parallel_input,
|
||||
skip_bias_add=skip_bias_add,
|
||||
weight_initializer=weight_initializer,
|
||||
bias_initializer=bias_initializer)
|
||||
|
||||
@property
|
||||
def weight(self):
|
||||
return self.layer.weight
|
||||
|
||||
@property
|
||||
def bias(self):
|
||||
return self.layer.bias
|
||||
|
||||
def forward(self, input_: Tensor) -> Tensor:
|
||||
return self.layer(input_)
|
||||
|
||||
|
||||
@LAYERS.register_module
|
||||
class Classifier1D(ParallelLayer):
|
||||
"""RowLinear with given weight"""
|
||||
def __init__(self,
|
||||
in_features: int,
|
||||
num_classes: int,
|
||||
weight: Parameter = None,
|
||||
bias: bool = True,
|
||||
dtype: dtype = None,
|
||||
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
|
||||
super().__init__()
|
||||
self.in_features = in_features
|
||||
self.num_classes = num_classes
|
||||
self.parallel_input = get_parallel_input()
|
||||
|
||||
# Divide the weight matrix along the last dimension.
|
||||
self.input_size_per_partition = divide(in_features, gpc.tensor_parallel_size)
|
||||
|
||||
# Parameters.
|
||||
# Initialize weight.
|
||||
factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
|
||||
if weight is not None:
|
||||
self.weight = weight
|
||||
self.has_weight = False
|
||||
else:
|
||||
self.weight = Parameter(torch.empty(self.num_classes, self.input_size_per_partition, **factory_kwargs))
|
||||
self.has_weight = True
|
||||
if bias:
|
||||
self.bias = Parameter(torch.empty(self.num_classes, **factory_kwargs))
|
||||
else:
|
||||
self.bias = None
|
||||
with seed(ParallelMode.TENSOR):
|
||||
self.reset_parameters(weight_initializer, bias_initializer)
|
||||
self._set_tensor_parallel_attributes()
|
||||
set_parallel_input(False)
|
||||
|
||||
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
|
||||
fan_in, fan_out = self.in_features, self.num_classes
|
||||
if self.has_weight:
|
||||
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
|
||||
if self.bias is not None:
|
||||
bias_initializer(self.bias, fan_in=fan_in)
|
||||
broadcast(self.bias, gpc.get_ranks_in_group(ParallelMode.PARALLEL_1D)[0], ParallelMode.PARALLEL_1D)
|
||||
|
||||
def _set_tensor_parallel_attributes(self):
|
||||
if self.has_weight:
|
||||
num_partition = gpc.get_world_size(ParallelMode.TENSOR)
|
||||
set_tensor_parallel_attribute_by_partition(self.weight, num_partition)
|
||||
|
||||
def forward(self, input_: Tensor) -> Tensor:
|
||||
# Set up backprop all-reduce.
|
||||
if self.parallel_input:
|
||||
input_ = input_
|
||||
else:
|
||||
input_ = split_forward_gather_backward(input_, ParallelMode.PARALLEL_1D, dim=-1)
|
||||
|
||||
output_parallel = F.linear(input_, self.weight)
|
||||
output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D)
|
||||
|
||||
output = output + self.bias
|
||||
return output
|
||||
|
||||
|
||||
@LAYERS.register_module
|
||||
|
@ -77,6 +186,7 @@ class Linear1D_Col(ParallelLayer):
|
|||
with seed(ParallelMode.TENSOR):
|
||||
self.reset_parameters(weight_initializer, bias_initializer)
|
||||
self._set_tensor_parallel_attributes()
|
||||
set_parallel_input(True)
|
||||
|
||||
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
|
||||
fan_in, fan_out = self.in_features, self.out_features
|
||||
|
@ -158,6 +268,7 @@ class Linear1D_Row(ParallelLayer):
|
|||
with seed(ParallelMode.TENSOR):
|
||||
self.reset_parameters(weight_initializer, bias_initializer)
|
||||
self._set_tensor_parallel_attributes()
|
||||
set_parallel_input(False)
|
||||
|
||||
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
|
||||
fan_in, fan_out = self.in_features, self.out_features
|
||||
|
@ -208,3 +319,68 @@ class MixedFusedLayerNorm1D(torch.nn.Module):
|
|||
|
||||
def forward(self, input):
|
||||
return FusedLayerNormAffineFunction1D.apply(input, self.weight, self.bias, self.normalized_shape, self.eps)
|
||||
|
||||
|
||||
@LAYERS.register_module
|
||||
class Embedding1D(ParallelLayer):
|
||||
def __init__(self,
|
||||
num_embeddings: int,
|
||||
embedding_dim: int,
|
||||
padding_idx: int = None,
|
||||
dtype: dtype = None,
|
||||
weight_initializer: Callable = init.normal_(),
|
||||
*args,
|
||||
**kwargs):
|
||||
super().__init__()
|
||||
|
||||
self.num_embeddings = num_embeddings
|
||||
self.embed_dim = embedding_dim
|
||||
embed_dim_per_partition = divide(embedding_dim, gpc.tensor_parallel_size)
|
||||
|
||||
self.padding_idx = padding_idx
|
||||
self.embed_args = args
|
||||
self.embed_kwargs = kwargs
|
||||
|
||||
self.weight = Parameter(
|
||||
torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype))
|
||||
|
||||
self.reset_parameters(weight_initializer)
|
||||
self._set_tensor_parallel_attributes()
|
||||
set_parallel_input(False)
|
||||
|
||||
def _set_tensor_parallel_attributes(self):
|
||||
set_tensor_parallel_attribute_by_partition(self.weight, gpc.tensor_parallel_size)
|
||||
|
||||
def reset_parameters(self, weight_initializer) -> None:
|
||||
with seed(ParallelMode.TENSOR):
|
||||
fan_in, fan_out = self.num_embeddings, self.embed_dim
|
||||
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
|
||||
self._fill_padding_idx_with_zero()
|
||||
|
||||
def _fill_padding_idx_with_zero(self) -> None:
|
||||
if self.padding_idx is not None:
|
||||
with torch.no_grad():
|
||||
self.weight[self.padding_idx].fill_(0)
|
||||
|
||||
def forward(self, input_: Tensor) -> Tensor:
|
||||
|
||||
output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
|
||||
|
||||
output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
@LAYERS.register_module
|
||||
class Dropout1D(ParallelLayer):
|
||||
def __init__(self, p: float = 0.5, inplace: bool = False):
|
||||
super().__init__()
|
||||
self.parallel_input = get_parallel_input()
|
||||
self.p = p
|
||||
self.inplace = inplace
|
||||
|
||||
def forward(self, input_: Tensor) -> Tensor:
|
||||
cm = nullcontext() if not self.parallel_input else seed(ParallelMode.TENSOR)
|
||||
with cm:
|
||||
output = F.dropout(input_, self.p, self.training, self.inplace)
|
||||
return output
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
from ._operation import reduce_by_batch_2d, split_batch_2d
|
||||
from ._operation import reduce_by_batch_2d, split_tensor_2d
|
||||
from .layers import Classifier2D, Embedding2D, LayerNorm2D, Linear2D, PatchEmbedding2D
|
||||
|
||||
__all__ = [
|
||||
'split_batch_2d', 'reduce_by_batch_2d', 'Linear2D', 'LayerNorm2D', 'Classifier2D', 'PatchEmbedding2D', 'Embedding2D'
|
||||
'split_tensor_2d', 'reduce_by_batch_2d', 'Linear2D', 'LayerNorm2D', 'Classifier2D', 'PatchEmbedding2D', 'Embedding2D'
|
||||
]
|
||||
|
|
|
@ -2,7 +2,7 @@ from typing import Any, Optional, Tuple
|
|||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from colossalai.communication.collective import (all_gather, all_reduce, reduce_scatter)
|
||||
from colossalai.communication.collective import (all_gather, all_reduce, reduce, reduce_scatter)
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.utils import get_current_device
|
||||
|
@ -595,7 +595,9 @@ class SplitFirst(torch.autograd.Function):
|
|||
return grad, None, None
|
||||
|
||||
|
||||
def split_batch_2d(input_: Tensor, dim: int = 0) -> Tensor:
|
||||
def split_tensor_2d(input_: Tensor, dim: int = 0) -> Tensor:
|
||||
if input_.size(dim) <= 1:
|
||||
return input_
|
||||
return torch.chunk(input_, gpc.get_world_size(ParallelMode.PARALLEL_2D_COL),
|
||||
dim=dim)[gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)].contiguous()
|
||||
|
||||
|
@ -603,17 +605,28 @@ def split_batch_2d(input_: Tensor, dim: int = 0) -> Tensor:
|
|||
class reduce_by_batch_2d(torch.autograd.Function):
|
||||
"""All-reduce the input from the model parallel region."""
|
||||
@staticmethod
|
||||
def symbolic(graph, input_):
|
||||
dist.all_reduce(input_, group=gpc.get_group(ParallelMode.PARALLEL_2D_COL))
|
||||
return input_
|
||||
def symbolic(graph, input_, reduce_mean: bool = False):
|
||||
output = all_reduce(input_, ParallelMode.PARALLEL_2D_COL)
|
||||
if reduce_mean:
|
||||
reduce_size = gpc.get_world_size(ParallelMode.PARALLEL_2D_COL)
|
||||
return output / reduce_size
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float32)
|
||||
def forward(ctx, input_):
|
||||
dist.all_reduce(input_, group=gpc.get_group(ParallelMode.PARALLEL_2D_COL))
|
||||
return input_.clone()
|
||||
def forward(ctx, input_, reduce_mean: bool = False):
|
||||
output = all_reduce(input_, ParallelMode.PARALLEL_2D_COL)
|
||||
ctx.reduce_mean = reduce_mean
|
||||
if reduce_mean:
|
||||
reduce_size = gpc.get_world_size(ParallelMode.PARALLEL_2D_COL)
|
||||
ctx.reduce_size = reduce_size
|
||||
return output.clone() / reduce_size
|
||||
return output.clone()
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(ctx, grad_output):
|
||||
return grad_output
|
||||
def backward(ctx, output_grad):
|
||||
if ctx.reduce_mean:
|
||||
return output_grad / ctx.reduce_size, None
|
||||
else:
|
||||
return output_grad, None
|
||||
|
|
|
@ -13,9 +13,9 @@ from colossalai.utils import get_current_device
|
|||
from torch import Tensor, dtype
|
||||
from torch.nn import Parameter
|
||||
|
||||
from .._common_utils import (divide, set_tensor_parallel_attribute_by_partition, to_2tuple)
|
||||
from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
|
||||
from ..base_layer import ParallelLayer
|
||||
from ._operation import (Matmul_AB_2D, add_bias_2d, all_gather_weight_2d, classifier_2d, layernorm_2d, split_batch_2d)
|
||||
from ._operation import Matmul_AB_2D, add_bias_2d, all_gather_weight_2d, classifier_2d, layernorm_2d
|
||||
from ._utils import assert_summa_initialization, get_summa_dim_from_env
|
||||
|
||||
|
||||
|
@ -257,8 +257,6 @@ class PatchEmbedding2D(ParallelLayer):
|
|||
assert H == self.img_size[0] and W == self.img_size[1], \
|
||||
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
|
||||
|
||||
input_ = split_batch_2d(input_)
|
||||
|
||||
weight = all_gather_weight_2d.apply(self.weight, 0, self.summa_dim, ParallelMode.PARALLEL_2D_COL)
|
||||
bias = all_gather_weight_2d.apply(self.bias, 0, self.summa_dim, ParallelMode.PARALLEL_2D_COL)
|
||||
|
||||
|
@ -318,8 +316,6 @@ class Embedding2D(ParallelLayer):
|
|||
self.weight[self.padding_idx].fill_(0)
|
||||
|
||||
def forward(self, input_: Tensor) -> Tensor:
|
||||
input_ = split_batch_2d(input_)
|
||||
|
||||
weight = all_gather_weight_2d.apply(self.weight, -1, self.summa_dim, ParallelMode.PARALLEL_2D_COL)
|
||||
|
||||
output = F.embedding(input_, weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
from ._operation import reduce_by_batch_2p5d, split_batch_2p5d
|
||||
from ._operation import reduce_by_batch_2p5d, split_tensor_2p5d
|
||||
from .layers import Classifier2p5D, Embedding2p5D, LayerNorm2p5D, Linear2p5D, PatchEmbedding2p5D
|
||||
|
||||
__all__ = [
|
||||
'split_batch_2p5d', 'reduce_by_batch_2p5d', 'Linear2p5D', 'LayerNorm2p5D', 'Classifier2p5D', 'PatchEmbedding2p5D',
|
||||
'split_tensor_2p5d', 'reduce_by_batch_2p5d', 'Linear2p5D', 'LayerNorm2p5D', 'Classifier2p5D', 'PatchEmbedding2p5D',
|
||||
'Embedding2p5D'
|
||||
]
|
||||
|
|
|
@ -22,7 +22,7 @@ def get_parallel_rank(parallel_mode: ParallelMode):
|
|||
return gpc.get_local_rank(parallel_mode)
|
||||
|
||||
|
||||
def split_batch_2p5d(input_: Tensor, dim: int = 0) -> Tensor:
|
||||
def split_tensor_2p5d(input_: Tensor, dim: int = 0) -> Tensor:
|
||||
return torch.chunk(input_, gpc.get_world_size(ParallelMode.PARALLEL_2P5D_COL),
|
||||
dim=dim)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)].contiguous()
|
||||
|
||||
|
@ -120,30 +120,53 @@ class Matmul_AB_2p5D(torch.autograd.Function):
|
|||
ctx.save_for_backward(A, B)
|
||||
|
||||
A_shape = A.shape
|
||||
A = A.reshape((-1, A_shape[-1])).contiguous()
|
||||
A = A.reshape((-1, A_shape[-1]))
|
||||
B_shape = B.shape
|
||||
B = B.reshape((-1, B_shape[-1])).contiguous()
|
||||
B = B.reshape((-1, B_shape[-1]))
|
||||
C_shape = (A.shape[0], B.shape[-1])
|
||||
C = torch.zeros(C_shape, dtype=A.dtype, device=get_current_device())
|
||||
|
||||
A_list = [torch.empty_like(A) for _ in range(gpc.get_world_size(row_parallel_mode) - 1)]
|
||||
B_list = [torch.empty_like(B) for _ in range(gpc.get_world_size(col_parallel_mode) - 1)]
|
||||
A_list.insert(gpc.get_local_rank(row_parallel_mode), A)
|
||||
B_list.insert(gpc.get_local_rank(col_parallel_mode), B)
|
||||
op_a = dist.all_gather(A_list, A, group=gpc.get_group(row_parallel_mode), async_op=True)
|
||||
op_a.wait()
|
||||
op_b = dist.all_gather(B_list, B, group=gpc.get_group(col_parallel_mode), async_op=True)
|
||||
for op in [op_a, op_b]:
|
||||
op.wait()
|
||||
# use circular buffer to store the communication tensor
|
||||
# 2 is enough for all cases
|
||||
A_list = [torch.empty_like(A) for _ in range(2)]
|
||||
B_list = [torch.empty_like(B) for _ in range(2)]
|
||||
|
||||
row_group = gpc.get_group(row_parallel_mode)
|
||||
col_group = gpc.get_group(col_parallel_mode)
|
||||
|
||||
src_a = tesseract_dim * row_rank + tesseract_dim ** 2 * dep_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
|
||||
pipeline_parallel_rank * tensor_parallel_size
|
||||
src_b = col_rank + tesseract_dim ** 2 * dep_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
|
||||
pipeline_parallel_rank * tensor_parallel_size
|
||||
|
||||
opa = [None] * 2
|
||||
opb = [None] * 2
|
||||
|
||||
A_list[0].copy_(A)
|
||||
B_list[0].copy_(B)
|
||||
opa[0] = dist.broadcast(A_list[0], src=src_a, group=row_group, async_op=True)
|
||||
opb[0] = dist.broadcast(B_list[0], src=src_b, group=col_group, async_op=True)
|
||||
cur = 0
|
||||
|
||||
for i in range(tesseract_dim):
|
||||
src_a = i + tesseract_dim * row_rank
|
||||
src_b = i + tesseract_dim * col_rank
|
||||
src_a = src_a % tesseract_dim
|
||||
src_b = src_b % tesseract_dim
|
||||
A_temp = A_list[src_a]
|
||||
B_temp = B_list[src_b]
|
||||
torch.addmm(C, A_temp, B_temp, out=C)
|
||||
if i != tesseract_dim - 1:
|
||||
A_list[1 - cur].copy_(A)
|
||||
opa[1 - cur] = dist.broadcast(A_list[1 - cur], src=src_a + 1, group=row_group, async_op=True)
|
||||
B_list[1 - cur].copy_(B)
|
||||
opb[1 - cur] = dist.broadcast(B_list[1 - cur],
|
||||
src=src_b + tesseract_dim,
|
||||
group=col_group,
|
||||
async_op=True)
|
||||
|
||||
if opa[cur] is not None:
|
||||
opa[cur].wait()
|
||||
if opb[cur] is not None:
|
||||
opb[cur].wait()
|
||||
|
||||
torch.addmm(C, A_list[cur], B_list[cur], out=C)
|
||||
cur = 1 - cur
|
||||
src_a += 1
|
||||
src_b += tesseract_dim
|
||||
out = C.reshape(out_shape)
|
||||
|
||||
if ctx:
|
||||
|
@ -201,20 +224,55 @@ class Matmul_ABT_2p5D(torch.autograd.Function):
|
|||
C_shape = (A.shape[0], B.shape[0])
|
||||
C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device())
|
||||
|
||||
for i in range(tesseract_dim):
|
||||
B_temp = B.clone()
|
||||
src_b = col_rank + i * tesseract_dim + dep_rank * (
|
||||
tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
|
||||
pipeline_parallel_rank * tensor_parallel_size
|
||||
dist.broadcast(B_temp, src=src_b, group=gpc.get_group(col_parallel_mode))
|
||||
C_temp = torch.matmul(A, B_temp.transpose(0, 1))
|
||||
src_c = i + row_rank * tesseract_dim + dep_rank * (
|
||||
tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
|
||||
pipeline_parallel_rank * tensor_parallel_size
|
||||
dist.reduce(C_temp, dst=src_c, group=gpc.get_group(row_parallel_mode))
|
||||
if i == col_rank:
|
||||
C = C_temp.clone()
|
||||
# use circular buffer to store the communication tensor
|
||||
# 2 is enough for all cases
|
||||
B_list = [torch.empty_like(B) for _ in range(2)]
|
||||
C_list = [torch.empty_like(C) for _ in range(2)]
|
||||
|
||||
row_group = gpc.get_group(row_parallel_mode)
|
||||
col_group = gpc.get_group(col_parallel_mode)
|
||||
|
||||
src_b = col_rank + tesseract_dim ** 2 * dep_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
|
||||
pipeline_parallel_rank * tensor_parallel_size
|
||||
src_c = tesseract_dim * row_rank + tesseract_dim ** 2 * dep_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
|
||||
pipeline_parallel_rank * tensor_parallel_size
|
||||
|
||||
opb = [None] * 2
|
||||
opr = [None] * 2
|
||||
|
||||
B_list[0].copy_(B)
|
||||
opb[0] = dist.broadcast(B_list[0], src=src_b, group=col_group, async_op=True)
|
||||
cur = 0
|
||||
|
||||
for i in range(tesseract_dim):
|
||||
if i != tesseract_dim - 1:
|
||||
B_list[1 - cur].copy_(B)
|
||||
opb[1 - cur] = dist.broadcast(B_list[1 - cur],
|
||||
src=src_b + tesseract_dim,
|
||||
group=col_group,
|
||||
async_op=True)
|
||||
|
||||
if opr[cur] is not None:
|
||||
opr[cur].wait()
|
||||
if i - 2 == col_rank:
|
||||
C.copy_(C_list[cur])
|
||||
|
||||
if opb[cur] is not None:
|
||||
opb[cur].wait()
|
||||
|
||||
torch.matmul(A, B_list[cur].transpose(0, 1), out=C_list[cur])
|
||||
opr[cur] = dist.reduce(C_list[cur], dst=src_c, group=row_group, async_op=True)
|
||||
cur = 1 - cur
|
||||
src_b += tesseract_dim
|
||||
src_c += 1
|
||||
|
||||
for op in opr:
|
||||
op.wait()
|
||||
|
||||
if tesseract_dim - 2 == col_rank:
|
||||
C.copy_(C_list[cur])
|
||||
if tesseract_dim - 1 == col_rank:
|
||||
C.copy_(C_list[1 - cur])
|
||||
out = C.reshape(out_shape)
|
||||
|
||||
if ctx:
|
||||
|
@ -272,20 +330,52 @@ class Matmul_ATB_2p5D(torch.autograd.Function):
|
|||
C_shape = (A.shape[-1], B.shape[-1])
|
||||
C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device())
|
||||
|
||||
for i in range(tesseract_dim):
|
||||
A_temp = A.clone()
|
||||
src_a = i + row_rank * tesseract_dim + dep_rank * (
|
||||
tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
|
||||
pipeline_parallel_rank * tensor_parallel_size
|
||||
dist.broadcast(A_temp, src=src_a, group=get_parallel_group(row_parallel_mode))
|
||||
C_temp = torch.matmul(A_temp.transpose(0, 1), B)
|
||||
src_c = col_rank + i * tesseract_dim + dep_rank * (
|
||||
tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
|
||||
pipeline_parallel_rank * tensor_parallel_size
|
||||
dist.reduce(C_temp, dst=src_c, group=get_parallel_group(col_parallel_mode))
|
||||
if i == row_rank:
|
||||
C = C_temp.clone()
|
||||
# use circular buffer to store the communication tensor
|
||||
# 2 is enough for all cases
|
||||
A_list = [torch.empty_like(A) for _ in range(2)]
|
||||
C_list = [torch.empty_like(C) for _ in range(2)]
|
||||
|
||||
row_group = gpc.get_group(row_parallel_mode)
|
||||
col_group = gpc.get_group(col_parallel_mode)
|
||||
|
||||
src_a = tesseract_dim * row_rank + tesseract_dim ** 2 * dep_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
|
||||
pipeline_parallel_rank * tensor_parallel_size
|
||||
src_c = col_rank + tesseract_dim ** 2 * dep_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
|
||||
pipeline_parallel_rank * tensor_parallel_size
|
||||
|
||||
opa = [None] * 2
|
||||
opr = [None] * 2
|
||||
|
||||
A_list[0].copy_(A)
|
||||
opa[0] = dist.broadcast(A_list[0], src=src_a, group=row_group, async_op=True)
|
||||
cur = 0
|
||||
|
||||
for i in range(tesseract_dim):
|
||||
if i != tesseract_dim - 1:
|
||||
A_list[1 - cur].copy_(A)
|
||||
opa[1 - cur] = dist.broadcast(A_list[1 - cur], src=src_a + 1, group=row_group, async_op=True)
|
||||
|
||||
if opr[cur] is not None:
|
||||
opr[cur].wait()
|
||||
if i - 2 == row_rank:
|
||||
C.copy_(C_list[cur])
|
||||
|
||||
if opa[cur] is not None:
|
||||
opa[cur].wait()
|
||||
|
||||
torch.matmul(A_list[cur].transpose(0, 1), B, out=C_list[cur])
|
||||
opr[cur] = dist.reduce(C_list[cur], dst=src_c, group=col_group, async_op=True)
|
||||
cur = 1 - cur
|
||||
src_a += 1
|
||||
src_c += tesseract_dim
|
||||
|
||||
for op in opr:
|
||||
op.wait()
|
||||
|
||||
if tesseract_dim - 2 == row_rank:
|
||||
C.copy_(C_list[cur])
|
||||
if tesseract_dim - 1 == row_rank:
|
||||
C.copy_(C_list[1 - cur])
|
||||
out = C.reshape(out_shape)
|
||||
|
||||
if ctx:
|
||||
|
@ -333,8 +423,7 @@ class Add_Bias_2p5D(torch.autograd.Function):
|
|||
bias_temp = bias.clone()
|
||||
else:
|
||||
bias_temp = torch.zeros(output_size_per_partition, dtype=bias.dtype, device=get_current_device())
|
||||
src_rank = col_rank + dep_rank * (
|
||||
tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
|
||||
src_rank = col_rank + dep_rank * tesseract_dim ** 2 + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
|
||||
pipeline_parallel_rank * tensor_parallel_size
|
||||
dist.broadcast(bias_temp, src=src_rank, group=get_parallel_group(col_parallel_mode))
|
||||
|
||||
|
@ -469,7 +558,9 @@ class SplitFirst(torch.autograd.Function):
|
|||
return grad, None, None
|
||||
|
||||
|
||||
def split_batch_2p5d(input_: Tensor, dim: int = 0) -> Tensor:
|
||||
def split_tensor_2p5d(input_: Tensor, dim: int = 0) -> Tensor:
|
||||
if input_.size(dim) <= 1:
|
||||
return input_
|
||||
return torch.chunk(input_, gpc.get_world_size(ParallelMode.PARALLEL_2P5D_COL),
|
||||
dim=dim)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)].contiguous()
|
||||
|
||||
|
@ -477,17 +568,28 @@ def split_batch_2p5d(input_: Tensor, dim: int = 0) -> Tensor:
|
|||
class reduce_by_batch_2p5d(torch.autograd.Function):
|
||||
"""All-reduce the input from the model parallel region."""
|
||||
@staticmethod
|
||||
def symbolic(graph, input_):
|
||||
dist.all_reduce(input_, group=gpc.get_group(ParallelMode.PARALLEL_2P5D_COL))
|
||||
return input_
|
||||
def symbolic(graph, input_, reduce_mean: bool = False):
|
||||
output = all_reduce(input_, ParallelMode.PARALLEL_2P5D_COL)
|
||||
if reduce_mean:
|
||||
reduce_size = gpc.get_world_size(ParallelMode.PARALLEL_2P5D_COL)
|
||||
return output / reduce_size
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float32)
|
||||
def forward(ctx, input_):
|
||||
dist.all_reduce(input_, group=gpc.get_group(ParallelMode.PARALLEL_2P5D_COL))
|
||||
return input_.clone()
|
||||
def forward(ctx, input_, reduce_mean: bool = False):
|
||||
output = all_reduce(input_, ParallelMode.PARALLEL_2P5D_COL)
|
||||
ctx.reduce_mean = reduce_mean
|
||||
if reduce_mean:
|
||||
reduce_size = gpc.get_world_size(ParallelMode.PARALLEL_2P5D_COL)
|
||||
ctx.reduce_size = reduce_size
|
||||
return output.clone() / reduce_size
|
||||
return output.clone()
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(ctx, grad_output):
|
||||
return grad_output
|
||||
def backward(ctx, output_grad):
|
||||
if ctx.reduce_mean:
|
||||
return output_grad / ctx.reduce_size, None
|
||||
else:
|
||||
return output_grad, None
|
||||
|
|
|
@ -13,10 +13,9 @@ from colossalai.utils import get_current_device
|
|||
from torch import Tensor, dtype
|
||||
from torch.nn import Parameter
|
||||
|
||||
from .._common_utils import (divide, set_tensor_parallel_attribute_by_partition, to_2tuple)
|
||||
from ..base_layer import ParallelLayer
|
||||
from ._operation import (Add_Bias_2p5D, Matmul_AB_2p5D, all_gather_weight_2p5d, classifier_2p5d, layernorm_2p5d,
|
||||
split_batch_2p5d)
|
||||
from ..utils import (divide, set_tensor_parallel_attribute_by_partition, to_2tuple)
|
||||
from ._operation import (Add_Bias_2p5D, Matmul_AB_2p5D, all_gather_weight_2p5d, classifier_2p5d, layernorm_2p5d)
|
||||
from ._utils import (assert_tesseract_initialization, get_tesseract_dim_dep_from_env)
|
||||
|
||||
|
||||
|
@ -231,7 +230,7 @@ class PatchEmbedding2p5D(ParallelLayer):
|
|||
self.num_patches = self.grid_size[0] * self.grid_size[1]
|
||||
self.flatten = flatten
|
||||
self.embed_size = embed_size
|
||||
self.embed_size_per_partition = embed_size // (self.tesseract_dep * self.tesseract_dim**2)
|
||||
self.embed_size_per_partition = embed_size // self.tesseract_dim**2
|
||||
|
||||
with seed(ParallelMode.TENSOR):
|
||||
self.weight = Parameter(
|
||||
|
@ -251,10 +250,10 @@ class PatchEmbedding2p5D(ParallelLayer):
|
|||
self._set_tensor_parallel_attribute()
|
||||
|
||||
def _set_tensor_parallel_attribute(self):
|
||||
set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dep * self.tesseract_dim**2)
|
||||
set_tensor_parallel_attribute_by_partition(self.bias, self.tesseract_dep * self.tesseract_dim**2)
|
||||
set_tensor_parallel_attribute_by_partition(self.cls_token, self.tesseract_dep * self.tesseract_dim**2)
|
||||
set_tensor_parallel_attribute_by_partition(self.pos_embed, self.tesseract_dep * self.tesseract_dim**2)
|
||||
set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dim**2)
|
||||
set_tensor_parallel_attribute_by_partition(self.bias, self.tesseract_dim**2)
|
||||
set_tensor_parallel_attribute_by_partition(self.cls_token, self.tesseract_dim**2)
|
||||
set_tensor_parallel_attribute_by_partition(self.pos_embed, self.tesseract_dim**2)
|
||||
|
||||
def reset_parameters(self, weight_initializer, bias_initializer, position_embed_initializer):
|
||||
with seed(ParallelMode.TENSOR):
|
||||
|
@ -269,8 +268,6 @@ class PatchEmbedding2p5D(ParallelLayer):
|
|||
assert H == self.img_size[0] and W == self.img_size[1], \
|
||||
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
|
||||
|
||||
input_ = split_batch_2p5d(input_)
|
||||
|
||||
weight = all_gather_weight_2p5d.apply(self.weight, 0, self.tesseract_dim, ParallelMode.PARALLEL_2P5D_COL)
|
||||
bias = all_gather_weight_2p5d.apply(self.bias, 0, self.tesseract_dim, ParallelMode.PARALLEL_2P5D_COL)
|
||||
|
||||
|
@ -303,7 +300,7 @@ class Embedding2p5D(ParallelLayer):
|
|||
self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env()
|
||||
self.num_embeddings = num_embeddings
|
||||
self.embed_dim = embedding_dim
|
||||
embed_dim_per_partition = embedding_dim // (self.tesseract_dep * self.tesseract_dim**2)
|
||||
embed_dim_per_partition = embedding_dim // self.tesseract_dim**2
|
||||
|
||||
self.padding_idx = padding_idx
|
||||
self.embed_args = args
|
||||
|
@ -316,7 +313,7 @@ class Embedding2p5D(ParallelLayer):
|
|||
self._set_tensor_parallel_attributes()
|
||||
|
||||
def _set_tensor_parallel_attributes(self):
|
||||
set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dep * self.tesseract_dim**2)
|
||||
set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dim**2)
|
||||
|
||||
def reset_parameters(self, weight_initializer) -> None:
|
||||
with seed(ParallelMode.TENSOR):
|
||||
|
@ -330,8 +327,6 @@ class Embedding2p5D(ParallelLayer):
|
|||
self.weight[self.padding_idx].fill_(0)
|
||||
|
||||
def forward(self, input_: Tensor) -> Tensor:
|
||||
input_ = split_batch_2p5d(input_)
|
||||
|
||||
weight = all_gather_weight_2p5d.apply(self.weight, -1, self.tesseract_dim, ParallelMode.PARALLEL_2P5D_COL)
|
||||
|
||||
output = F.embedding(input_, weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
|
||||
|
@ -359,7 +354,7 @@ class Classifier2p5D(ParallelLayer):
|
|||
self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env()
|
||||
|
||||
# partitioning dimension
|
||||
self.input_size_per_partition = divide(self.in_features, self.tesseract_dep * self.tesseract_dim**2)
|
||||
self.input_size_per_partition = divide(self.in_features, self.tesseract_dim**2)
|
||||
|
||||
if weight is not None:
|
||||
self.weight = weight
|
||||
|
@ -378,7 +373,7 @@ class Classifier2p5D(ParallelLayer):
|
|||
|
||||
def _set_tensor_parallel_attributes(self):
|
||||
if self.has_weight:
|
||||
set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dep * self.tesseract_dim**2)
|
||||
set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dim**2)
|
||||
|
||||
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
|
||||
with seed(ParallelMode.TENSOR):
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
from ._operation import reduce_by_batch_3d, split_batch_3d
|
||||
from ._operation import reduce_by_batch_3d, split_tensor_3d
|
||||
from .layers import Classifier3D, Embedding3D, LayerNorm3D, Linear3D, PatchEmbedding3D
|
||||
|
||||
__all__ = [
|
||||
'reduce_by_batch_3d', 'split_batch_3d', 'Linear3D', 'LayerNorm3D', 'PatchEmbedding3D', 'Classifier3D', 'Embedding3D'
|
||||
'reduce_by_batch_3d', 'split_tensor_3d', 'Linear3D', 'LayerNorm3D', 'PatchEmbedding3D', 'Classifier3D', 'Embedding3D'
|
||||
]
|
||||
|
|
|
@ -175,10 +175,12 @@ class layernorm_3d(torch.autograd.Function):
|
|||
return input_grad, weight_grad, bias_grad, None, None, None, None, None
|
||||
|
||||
|
||||
def split_batch_3d(input_: Tensor,
|
||||
input_parallel_mode: ParallelMode,
|
||||
weight_parallel_mode: ParallelMode,
|
||||
dim: int = 0) -> Tensor:
|
||||
def split_tensor_3d(input_: Tensor,
|
||||
dim: int = 0,
|
||||
input_parallel_mode: ParallelMode = ParallelMode.PARALLEL_3D_INPUT,
|
||||
weight_parallel_mode: ParallelMode = ParallelMode.PARALLEL_3D_WEIGHT) -> Tensor:
|
||||
if input_.size(dim) <= 1:
|
||||
return input_
|
||||
output = torch.chunk(input_, gpc.get_world_size(weight_parallel_mode),
|
||||
dim=dim)[gpc.get_local_rank(weight_parallel_mode)].contiguous()
|
||||
output = torch.chunk(output, gpc.get_world_size(input_parallel_mode),
|
||||
|
@ -189,15 +191,27 @@ def split_batch_3d(input_: Tensor,
|
|||
class reduce_by_batch_3d(torch.autograd.Function):
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float32)
|
||||
def forward(ctx, input_: Tensor, input_parallel_mode: ParallelMode, weight_parallel_mode: ParallelMode) -> Tensor:
|
||||
def forward(ctx,
|
||||
input_: Tensor,
|
||||
input_parallel_mode: ParallelMode,
|
||||
weight_parallel_mode: ParallelMode,
|
||||
reduce_mean: bool = False) -> Tensor:
|
||||
output = all_reduce(input_, input_parallel_mode)
|
||||
output = all_reduce(output, weight_parallel_mode)
|
||||
ctx.reduce_mean = reduce_mean
|
||||
if reduce_mean:
|
||||
reduce_size = gpc.get_world_size(input_parallel_mode) * gpc.get_world_size(weight_parallel_mode)
|
||||
ctx.reduce_size = reduce_size
|
||||
return output.clone() / reduce_size
|
||||
return output.clone()
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]:
|
||||
return output_grad, None, None
|
||||
if ctx.reduce_mean:
|
||||
return output_grad / ctx.reduce_size, None, None, None
|
||||
else:
|
||||
return output_grad, None, None, None
|
||||
|
||||
|
||||
class broadcast_weight_3d_from_diagonal(torch.autograd.Function):
|
||||
|
|
|
@ -17,9 +17,9 @@ from colossalai.utils import get_current_device
|
|||
from torch import Tensor, dtype
|
||||
from torch.nn import Parameter
|
||||
|
||||
from .._common_utils import (divide, set_tensor_parallel_attribute_by_partition, to_2tuple)
|
||||
from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
|
||||
from ._operation import *
|
||||
from ._utils import (get_depth_from_env, get_last_group, get_parallel_mode_from_env, swap_in_out_group)
|
||||
from ._utils import get_depth_from_env, get_last_group, get_parallel_mode_from_env, swap_in_out_group
|
||||
|
||||
|
||||
@LAYERS.register_module
|
||||
|
@ -241,8 +241,6 @@ class PatchEmbedding3D(ParallelLayer):
|
|||
self.pos_embed.register_hook(self._sync_grad_hook)
|
||||
|
||||
def forward(self, input_: Tensor) -> Tensor:
|
||||
input_ = split_batch_3d(input_, self.input_parallel_mode, self.weight_parallel_mode)
|
||||
|
||||
weight = broadcast_weight_3d_from_diagonal.apply(self.weight, self.input_parallel_mode,
|
||||
self.weight_parallel_mode, self.output_parallel_mode)
|
||||
output = F.conv2d(input_, weight, self.bias, stride=self.patch_size)
|
||||
|
@ -302,8 +300,6 @@ class Embedding3D(ParallelLayer):
|
|||
self.weight[self.padding_idx].fill_(0)
|
||||
|
||||
def forward(self, input_: Tensor) -> Tensor:
|
||||
input_ = split_batch_3d(input_, self.input_parallel_mode, self.weight_parallel_mode)
|
||||
|
||||
weight = broadcast_weight_3d_from_diagonal.apply(self.weight, self.input_parallel_mode,
|
||||
self.weight_parallel_mode, self.output_parallel_mode)
|
||||
output = F.embedding(input_, weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
|
||||
|
|
|
@ -0,0 +1,7 @@
|
|||
from .common import (ACT2FN, CheckpointModule, _ntuple, divide, get_tensor_parallel_mode,
|
||||
set_tensor_parallel_attribute_by_partition, set_tensor_parallel_attribute_by_size, to_2tuple)
|
||||
|
||||
__all__ = [
|
||||
'CheckpointModule', 'divide', 'ACT2FN', 'set_tensor_parallel_attribute_by_size',
|
||||
'set_tensor_parallel_attribute_by_partition', 'get_tensor_parallel_mode', '_ntuple', 'to_2tuple'
|
||||
]
|
|
@ -2,11 +2,12 @@
|
|||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import collections.abc
|
||||
import os
|
||||
from itertools import repeat
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from colossalai.constants import IS_TENSOR_PARALLEL, NUM_PARTITIONS
|
||||
from colossalai.constants import (IS_TENSOR_PARALLEL, NUM_PARTITIONS, TENSOR_PARALLEL_MODE)
|
||||
from colossalai.utils import checkpoint
|
||||
from torch import Tensor, nn
|
||||
|
||||
|
@ -59,6 +60,10 @@ def set_tensor_parallel_attribute_by_partition(param, num_partitions):
|
|||
setattr(param, NUM_PARTITIONS, num_partitions)
|
||||
|
||||
|
||||
def get_tensor_parallel_mode():
|
||||
return os.environ[TENSOR_PARALLEL_MODE]
|
||||
|
||||
|
||||
# From PyTorch internals
|
||||
|
||||
|
|
@ -9,7 +9,7 @@ from colossalai.utils import get_current_device
|
|||
from torch import Tensor, dtype
|
||||
from torch import nn as nn
|
||||
|
||||
from .._common_utils import to_2tuple
|
||||
from ..utils import to_2tuple
|
||||
|
||||
|
||||
def drop_path(x, drop_prob: float = 0., training: bool = False):
|
||||
|
|
|
@ -2,6 +2,7 @@ from torch import nn
|
|||
from torch.nn.modules.loss import *
|
||||
from torch.nn.modules.loss import _Loss
|
||||
|
||||
from colossalai.nn.layer.utils import get_tensor_parallel_mode
|
||||
from .loss_2d import CrossEntropyLoss2D
|
||||
from .loss_2p5d import CrossEntropyLoss2p5D
|
||||
from .loss_3d import CrossEntropyLoss3D
|
||||
|
@ -14,9 +15,10 @@ _parallel_cross_entropy = {
|
|||
|
||||
|
||||
class CrossEntropyLoss(_Loss):
|
||||
def __init__(self, reduction: bool = True, tensor_parallel: str = None, *args, **kwargs):
|
||||
def __init__(self, reduction: bool = True, *args, **kwargs):
|
||||
super().__init__()
|
||||
if tensor_parallel in [None, '1d']:
|
||||
tensor_parallel = get_tensor_parallel_mode()
|
||||
if tensor_parallel in ['None', '1d']:
|
||||
reduction = 'mean' if reduction else 'none'
|
||||
self.loss = nn.CrossEntropyLoss(reduction=reduction, *args, **kwargs)
|
||||
else:
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from colossalai.nn.layer.parallel_2d import reduce_by_batch_2d, split_batch_2d
|
||||
from colossalai.nn.layer.parallel_2d import reduce_by_batch_2d
|
||||
from colossalai.nn.layer.parallel_2d._utils import assert_summa_initialization
|
||||
from colossalai.registry import LOSSES
|
||||
from torch.nn.functional import cross_entropy
|
||||
|
@ -20,11 +20,8 @@ class CrossEntropyLoss2D(_Loss):
|
|||
self.loss_kwargs = kwargs
|
||||
|
||||
def forward(self, logits, targets):
|
||||
batch_size = targets.size(0)
|
||||
targets = split_batch_2d(targets)
|
||||
loss = cross_entropy(logits, targets, reduction='sum', *self.loss_args, **self.loss_kwargs)
|
||||
loss = cross_entropy(logits, targets, reduction='none', *self.loss_args, **self.loss_kwargs)
|
||||
if self.reduction_mean:
|
||||
loss = loss.sum()
|
||||
loss = reduce_by_batch_2d.apply(loss)
|
||||
loss /= batch_size
|
||||
loss = loss.mean()
|
||||
loss = reduce_by_batch_2d.apply(loss, True)
|
||||
return loss
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from colossalai.nn.layer.parallel_2p5d import reduce_by_batch_2p5d, split_batch_2p5d
|
||||
from colossalai.nn.layer.parallel_2p5d import reduce_by_batch_2p5d
|
||||
from colossalai.nn.layer.parallel_2p5d._utils import assert_tesseract_initialization
|
||||
from colossalai.registry import LOSSES
|
||||
from torch.nn.functional import cross_entropy
|
||||
|
@ -19,11 +19,8 @@ class CrossEntropyLoss2p5D(_Loss):
|
|||
self.loss_kwargs = kwargs
|
||||
|
||||
def forward(self, logits, targets):
|
||||
batch_size = targets.size(0)
|
||||
targets = split_batch_2p5d(targets)
|
||||
loss = cross_entropy(logits, targets, reduction='sum', *self.loss_args, **self.loss_kwargs)
|
||||
loss = cross_entropy(logits, targets, reduction='none', *self.loss_args, **self.loss_kwargs)
|
||||
if self.reduction_mean:
|
||||
loss = loss.sum()
|
||||
loss = reduce_by_batch_2p5d.apply(loss)
|
||||
loss /= batch_size
|
||||
loss = loss.mean()
|
||||
loss = reduce_by_batch_2p5d.apply(loss, True)
|
||||
return loss
|
||||
|
|
|
@ -1,11 +1,10 @@
|
|||
from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D
|
||||
from colossalai.nn.layer.parallel_3d import reduce_by_batch_3d, split_batch_3d
|
||||
from colossalai.nn.layer.parallel_3d import reduce_by_batch_3d
|
||||
from colossalai.nn.layer.parallel_3d._utils import get_parallel_mode_from_env
|
||||
from colossalai.registry import LOSSES
|
||||
from torch.nn.functional import cross_entropy
|
||||
from torch.nn.modules.loss import _Loss
|
||||
|
||||
|
||||
@LOSSES.register_module
|
||||
class CrossEntropyLoss3D(_Loss):
|
||||
"""Cross entropy loss for 3D parallelism
|
||||
|
@ -28,11 +27,8 @@ class CrossEntropyLoss3D(_Loss):
|
|||
self.loss_kwargs = kwargs
|
||||
|
||||
def forward(self, logits, targets):
|
||||
batch_size = targets.size(0)
|
||||
targets = split_batch_3d(targets, self.input_parallel_mode, self.weight_parallel_mode)
|
||||
loss = cross_entropy(logits, targets, reduction='sum', *self.loss_args, **self.loss_kwargs)
|
||||
loss = cross_entropy(logits, targets, reduction='none', *self.loss_args, **self.loss_kwargs)
|
||||
if self.reduction_mean:
|
||||
loss = loss.sum()
|
||||
loss = reduce_by_batch_3d.apply(loss, self.input_parallel_mode, self.weight_parallel_mode)
|
||||
loss /= batch_size
|
||||
loss = loss.mean()
|
||||
loss = reduce_by_batch_3d.apply(loss, self.input_parallel_mode, self.weight_parallel_mode, True)
|
||||
return loss
|
||||
|
|
|
@ -4,6 +4,7 @@ from ._utils import calc_acc
|
|||
from .accuracy_2d import Accuracy2D
|
||||
from .accuracy_2p5d import Accuracy2p5D
|
||||
from .accuracy_3d import Accuracy3D
|
||||
from colossalai.nn.layer.utils import get_tensor_parallel_mode
|
||||
|
||||
_parallel_accuracy = {
|
||||
'2d': Accuracy2D,
|
||||
|
@ -13,9 +14,10 @@ _parallel_accuracy = {
|
|||
|
||||
|
||||
class Accuracy(nn.Module):
|
||||
def __init__(self, tensor_parallel: str = None):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
if tensor_parallel in [None, '1d']:
|
||||
tensor_parallel = get_tensor_parallel_mode()
|
||||
if tensor_parallel in ['None', '1d']:
|
||||
self.acc = calc_acc
|
||||
else:
|
||||
self.acc = _parallel_accuracy[tensor_parallel]()
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
import torch
|
||||
from colossalai.nn.layer.parallel_2d import reduce_by_batch_2d, split_batch_2d
|
||||
from colossalai.nn.layer.parallel_2d import reduce_by_batch_2d
|
||||
from torch import nn
|
||||
|
||||
from ._utils import calc_acc
|
||||
|
@ -11,7 +11,6 @@ class Accuracy2D(nn.Module):
|
|||
|
||||
def forward(self, logits, targets):
|
||||
with torch.no_grad():
|
||||
targets = split_batch_2d(targets)
|
||||
correct = calc_acc(logits, targets)
|
||||
correct = reduce_by_batch_2d.apply(correct)
|
||||
return correct
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
import torch
|
||||
from colossalai.nn.layer.parallel_2p5d import reduce_by_batch_2p5d, split_batch_2p5d
|
||||
from colossalai.nn.layer.parallel_2p5d import reduce_by_batch_2p5d
|
||||
from torch import nn
|
||||
|
||||
from ._utils import calc_acc
|
||||
|
@ -11,7 +11,6 @@ class Accuracy2p5D(nn.Module):
|
|||
|
||||
def forward(self, logits, targets):
|
||||
with torch.no_grad():
|
||||
targets = split_batch_2p5d(targets)
|
||||
correct = calc_acc(logits, targets)
|
||||
correct = reduce_by_batch_2p5d.apply(correct)
|
||||
return correct
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import torch
|
||||
from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D
|
||||
from colossalai.nn.layer.parallel_3d import reduce_by_batch_3d, split_batch_3d
|
||||
from colossalai.nn.layer.parallel_3d import reduce_by_batch_3d
|
||||
from colossalai.nn.layer.parallel_3d._utils import get_parallel_mode_from_env
|
||||
from torch import nn
|
||||
|
||||
|
@ -15,7 +15,6 @@ class Accuracy3D(nn.Module):
|
|||
|
||||
def forward(self, logits, targets):
|
||||
with torch.no_grad():
|
||||
targets = split_batch_3d(targets, self.input_parallel_mode, self.weight_parallel_mode)
|
||||
correct = calc_acc(logits, targets)
|
||||
correct = reduce_by_batch_3d.apply(correct, self.input_parallel_mode, self.weight_parallel_mode)
|
||||
return correct
|
||||
|
|
|
@ -173,7 +173,7 @@ class AccuracyMetric(Metric):
|
|||
self.accumulated_sum.zero_()
|
||||
self.accumulated_correct.zero_()
|
||||
|
||||
def update(self, logits, targets) -> None:
|
||||
def update(self, logits, targets, batch_size) -> None:
|
||||
"""Updates last step accuracy and accumulated accuracy with current logits
|
||||
and labels. It expects the output has logits and labels.
|
||||
|
||||
|
@ -187,7 +187,7 @@ class AccuracyMetric(Metric):
|
|||
# update
|
||||
correct = self.acc(logits, targets)
|
||||
|
||||
self.last_step_sum.fill_(targets.size(0))
|
||||
self.last_step_sum.fill_(batch_size)
|
||||
self.last_step_correct.fill_(correct)
|
||||
self.accumulated_sum += self.last_step_sum
|
||||
self.accumulated_correct += self.last_step_correct
|
||||
|
@ -296,7 +296,8 @@ class AccuracyHook(MetricHook):
|
|||
|
||||
def after_test_iter(self, trainer, logits, targets, *args):
|
||||
if self._is_stage_to_compute:
|
||||
self.metric.update(logits, targets)
|
||||
batch_size = trainer.schedule.batch_size
|
||||
self.metric.update(logits, targets, batch_size)
|
||||
|
||||
|
||||
class ThroughputMetric(Metric):
|
||||
|
@ -313,10 +314,8 @@ class ThroughputMetric(Metric):
|
|||
self.last_step_num_samples.zero_()
|
||||
self.last_step_used_time.zero_()
|
||||
|
||||
def update(self, tensor, time) -> None:
|
||||
if isinstance(tensor, (list, tuple)):
|
||||
tensor = tensor[0]
|
||||
self.last_step_num_samples.fill_(tensor.size(0))
|
||||
def update(self, num_samples, time) -> None:
|
||||
self.last_step_num_samples.fill_(num_samples)
|
||||
self.last_step_used_time.fill_(time)
|
||||
self.accumulated_num_samples += self.last_step_num_samples
|
||||
self.accumulated_used_time += self.last_step_used_time
|
||||
|
@ -354,11 +353,11 @@ class ThroughputHook(MetricHook):
|
|||
def before_train_epoch(self, trainer):
|
||||
self.metric.reset()
|
||||
|
||||
def after_train_iter(self, trainer, logits, targets, *args):
|
||||
self.metric.update(targets, trainer._timer.get_timer('Train-step').get_elapsed_time())
|
||||
def after_train_iter(self, trainer, *args):
|
||||
self.metric.update(trainer.schedule.batch_size, trainer._timer.get_timer('Train-step').get_elapsed_time())
|
||||
|
||||
def before_test(self, trainer):
|
||||
self.metric.reset()
|
||||
|
||||
def after_test_iter(self, trainer, logits, targets, *args):
|
||||
self.metric.update(targets, trainer._timer.get_timer('Test-step').get_elapsed_time())
|
||||
def after_test_iter(self, trainer, *args):
|
||||
self.metric.update(trainer.schedule.batch_size, trainer._timer.get_timer('Test-step').get_elapsed_time())
|
||||
|
|
|
@ -1,27 +1,19 @@
|
|||
from .activation_checkpoint import checkpoint
|
||||
from .common import (print_rank_0, sync_model_param_in_dp, is_dp_rank_0,
|
||||
is_tp_rank_0, is_no_pp_or_last_stage, is_using_ddp,
|
||||
is_using_pp, conditional_context, is_model_parallel_parameter,
|
||||
clip_grad_norm_fp32, count_zeros_fp32, copy_tensor_parallel_attributes,
|
||||
param_is_not_tensor_parallel_duplicate, switch_virtual_pipeline_parallel_rank)
|
||||
from .cuda import get_current_device, synchronize, empty_cache, set_to_cuda
|
||||
from .common import (clip_grad_norm_fp32, conditional_context, copy_tensor_parallel_attributes, count_zeros_fp32,
|
||||
free_port, is_dp_rank_0, is_model_parallel_parameter, is_no_pp_or_last_stage, is_tp_rank_0,
|
||||
is_using_ddp, is_using_pp, multi_tensor_applier, param_is_not_tensor_parallel_duplicate,
|
||||
print_rank_0, switch_virtual_pipeline_parallel_rank, sync_model_param_in_dp)
|
||||
from .cuda import empty_cache, get_current_device, set_to_cuda, synchronize
|
||||
from .data_sampler import DataParallelSampler, get_dataloader
|
||||
from .gradient_accumulation import accumulate_gradient
|
||||
from .memory import report_memory_usage
|
||||
from .timer import MultiTimer, Timer
|
||||
from .multi_tensor_apply import multi_tensor_applier
|
||||
from .gradient_accumulation import accumulate_gradient
|
||||
from .data_sampler import DataParallelSampler, get_dataloader
|
||||
|
||||
__all__ = ['checkpoint',
|
||||
'print_rank_0', 'sync_model_param_in_dp', 'is_dp_rank_0',
|
||||
'is_tp_rank_0', 'is_no_pp_or_last_stage', 'is_using_ddp',
|
||||
'is_using_pp', 'conditional_context', 'is_model_parallel_parameter',
|
||||
__all__ = [
|
||||
'checkpoint', 'free_port', 'print_rank_0', 'sync_model_param_in_dp', 'is_dp_rank_0', 'is_tp_rank_0',
|
||||
'is_no_pp_or_last_stage', 'is_using_ddp', 'is_using_pp', 'conditional_context', 'is_model_parallel_parameter',
|
||||
'clip_grad_norm_fp32', 'count_zeros_fp32', 'copy_tensor_parallel_attributes',
|
||||
'param_is_not_tensor_parallel_duplicate',
|
||||
'get_current_device', 'synchronize', 'empty_cache', 'set_to_cuda',
|
||||
'report_memory_usage',
|
||||
'Timer', 'MultiTimer',
|
||||
'multi_tensor_applier',
|
||||
'accumulate_gradient',
|
||||
'DataParallelSampler', 'get_dataloader',
|
||||
'switch_virtual_pipeline_parallel_rank'
|
||||
]
|
||||
'param_is_not_tensor_parallel_duplicate', 'get_current_device', 'synchronize', 'empty_cache', 'set_to_cuda',
|
||||
'report_memory_usage', 'Timer', 'MultiTimer', 'multi_tensor_applier', 'accumulate_gradient', 'DataParallelSampler',
|
||||
'get_dataloader', 'switch_virtual_pipeline_parallel_rank'
|
||||
]
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
import random
|
||||
import socket
|
||||
|
||||
import torch
|
||||
from torch._six import inf
|
||||
|
@ -9,16 +11,15 @@ try:
|
|||
except:
|
||||
pass
|
||||
|
||||
import torch.distributed as dist
|
||||
from contextlib import contextmanager
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from .multi_tensor_apply import multi_tensor_applier
|
||||
from colossalai.constants import IS_TENSOR_PARALLEL, TENSOR_PARALLEL_ATTRIBUTES, NUM_PARTITIONS
|
||||
|
||||
import torch.distributed as dist
|
||||
from colossalai.constants import IS_TENSOR_PARALLEL, NUM_PARTITIONS, TENSOR_PARALLEL_ATTRIBUTES
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
|
||||
from .multi_tensor_apply import multi_tensor_applier
|
||||
|
||||
|
||||
def print_rank_0(msg: str, logger=None):
|
||||
'''Print messages and save logs(optional). This is executed only if you are the rank-0 gpu.
|
||||
|
@ -33,6 +34,18 @@ def print_rank_0(msg: str, logger=None):
|
|||
logger.info(msg)
|
||||
|
||||
|
||||
def free_port():
|
||||
while True:
|
||||
try:
|
||||
sock = socket.socket()
|
||||
port = random.randint(20000, 65000)
|
||||
sock.bind(('localhost', port))
|
||||
sock.close()
|
||||
return port
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
|
||||
def sync_model_param_in_dp(model):
|
||||
'''Make sure data parameters are consistent during Data Parallel Mode
|
||||
|
||||
|
|
|
@ -3,9 +3,8 @@ from typing import Callable
|
|||
|
||||
import torch
|
||||
from colossalai import nn as col_nn
|
||||
from colossalai.context import ParallelMode, seed
|
||||
from colossalai.nn.layer.utils import CheckpointModule
|
||||
from colossalai.registry import LAYERS, MODELS
|
||||
from colossalai.utils import checkpoint
|
||||
from torch import dtype, nn
|
||||
|
||||
__all__ = [
|
||||
|
@ -72,8 +71,7 @@ class ViTEmbedding(nn.Module):
|
|||
dropout: float,
|
||||
dtype: dtype = None,
|
||||
flatten: bool = True,
|
||||
init_method: str = 'torch',
|
||||
tensor_parallel: str = None):
|
||||
init_method: str = 'torch'):
|
||||
super().__init__()
|
||||
self.patch_embed = col_nn.PatchEmbedding(img_size,
|
||||
patch_size,
|
||||
|
@ -81,19 +79,17 @@ class ViTEmbedding(nn.Module):
|
|||
embedding_dim,
|
||||
dtype=dtype,
|
||||
flatten=flatten,
|
||||
tensor_parallel=tensor_parallel,
|
||||
**_init_rules[init_method]['embed'])
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.dropout = col_nn.Dropout(dropout)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.patch_embed(x)
|
||||
with seed(ParallelMode.TENSOR):
|
||||
x = self.dropout(x)
|
||||
return x
|
||||
|
||||
|
||||
@LAYERS.register_module
|
||||
class ViTSelfAttention(nn.Module):
|
||||
class ViTSelfAttention(CheckpointModule):
|
||||
def __init__(self,
|
||||
dim: int,
|
||||
num_heads: int,
|
||||
|
@ -102,27 +98,17 @@ class ViTSelfAttention(nn.Module):
|
|||
bias: bool = True,
|
||||
dtype: dtype = None,
|
||||
checkpoint: bool = False,
|
||||
init_method: str = 'torch',
|
||||
tensor_parallel: str = None):
|
||||
super().__init__()
|
||||
init_method: str = 'torch'):
|
||||
super().__init__(checkpoint)
|
||||
self.attention_head_size = dim // num_heads
|
||||
self.checkpoint = checkpoint
|
||||
self.tensor_parallel = tensor_parallel
|
||||
|
||||
self.query_key_value = col_nn.Linear(dim,
|
||||
3 * dim,
|
||||
dtype=dtype,
|
||||
bias=bias,
|
||||
tensor_parallel='1d_col' if tensor_parallel == '1d' else tensor_parallel,
|
||||
**_init_rules[init_method]['transformer'])
|
||||
self.attention_dropout = nn.Dropout(attention_dropout)
|
||||
self.dense = col_nn.Linear(dim,
|
||||
dim,
|
||||
dtype=dtype,
|
||||
bias=True,
|
||||
tensor_parallel='1d_row' if tensor_parallel == '1d' else tensor_parallel,
|
||||
**_init_rules[init_method]['transformer'])
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.attention_dropout = col_nn.Dropout(attention_dropout)
|
||||
self.dense = col_nn.Linear(dim, dim, dtype=dtype, bias=True, **_init_rules[init_method]['transformer'])
|
||||
self.dropout = col_nn.Dropout(dropout)
|
||||
self.softmax = nn.Softmax(dim=-1)
|
||||
|
||||
def _forward(self, x):
|
||||
|
@ -138,7 +124,6 @@ class ViTSelfAttention(nn.Module):
|
|||
x = torch.matmul(q, k.transpose(-1, -2))
|
||||
x = x / math.sqrt(self.attention_head_size)
|
||||
x = self.softmax(x)
|
||||
with seed(ParallelMode.TENSOR):
|
||||
x = self.attention_dropout(x)
|
||||
|
||||
x = torch.matmul(x, v)
|
||||
|
@ -147,26 +132,13 @@ class ViTSelfAttention(nn.Module):
|
|||
x = x.reshape(new_context_layer_shape)
|
||||
|
||||
x = self.dense(x)
|
||||
if self.tensor_parallel == '1d':
|
||||
x = self.dropout(x)
|
||||
else:
|
||||
with seed(ParallelMode.TENSOR):
|
||||
x = self.dropout(x)
|
||||
|
||||
return x
|
||||
|
||||
def _checkpoint_forward(self, x):
|
||||
return checkpoint(self._forward, x)
|
||||
|
||||
def forward(self, x):
|
||||
if self.checkpoint:
|
||||
return self._checkpoint_forward(x)
|
||||
else:
|
||||
return self._forward(x)
|
||||
|
||||
|
||||
@LAYERS.register_module
|
||||
class ViTMLP(nn.Module):
|
||||
class ViTMLP(CheckpointModule):
|
||||
def __init__(self,
|
||||
dim: int,
|
||||
mlp_ratio: int,
|
||||
|
@ -175,50 +147,30 @@ class ViTMLP(nn.Module):
|
|||
dtype: dtype = None,
|
||||
bias: bool = True,
|
||||
checkpoint: bool = False,
|
||||
init_method: str = 'torch',
|
||||
tensor_parallel: str = None):
|
||||
super().__init__()
|
||||
self.checkpoint = checkpoint
|
||||
self.tensor_parallel = tensor_parallel
|
||||
|
||||
init_method: str = 'torch'):
|
||||
super().__init__(checkpoint)
|
||||
self.dense_1 = col_nn.Linear(dim,
|
||||
mlp_ratio * dim,
|
||||
dtype=dtype,
|
||||
bias=bias,
|
||||
tensor_parallel='1d_col' if tensor_parallel == '1d' else tensor_parallel,
|
||||
**_init_rules[init_method]['transformer'])
|
||||
self.activation = activation
|
||||
self.dropout_1 = col_nn.Dropout(dropout)
|
||||
self.dense_2 = col_nn.Linear(mlp_ratio * dim,
|
||||
dim,
|
||||
dtype=dtype,
|
||||
bias=bias,
|
||||
tensor_parallel='1d_row' if tensor_parallel == '1d' else tensor_parallel,
|
||||
**_init_rules[init_method]['transformer'])
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.dropout_2 = col_nn.Dropout(dropout)
|
||||
|
||||
def _forward(self, x):
|
||||
x = self.dense_1(x)
|
||||
x = self.activation(x)
|
||||
with seed(ParallelMode.TENSOR):
|
||||
x = self.dropout(x)
|
||||
x = self.dropout_1(x)
|
||||
x = self.dense_2(x)
|
||||
if self.tensor_parallel == '1d':
|
||||
x = self.dropout(x)
|
||||
else:
|
||||
with seed(ParallelMode.TENSOR):
|
||||
x = self.dropout(x)
|
||||
|
||||
x = self.dropout_2(x)
|
||||
return x
|
||||
|
||||
def _checkpoint_forward(self, x):
|
||||
return checkpoint(self._forward, x)
|
||||
|
||||
def forward(self, x):
|
||||
if self.checkpoint:
|
||||
return self._checkpoint_forward(x)
|
||||
else:
|
||||
return self._forward(x)
|
||||
|
||||
|
||||
@LAYERS.register_module
|
||||
class ViTHead(nn.Module):
|
||||
|
@ -228,19 +180,14 @@ class ViTHead(nn.Module):
|
|||
representation_size: int = None,
|
||||
dtype: dtype = None,
|
||||
bias: bool = True,
|
||||
init_method: str = 'torch',
|
||||
tensor_parallel: str = None):
|
||||
init_method: str = 'torch'):
|
||||
super().__init__()
|
||||
if representation_size:
|
||||
tensor_parallel_kwargs = {'tensor_parallel': '1d_col' if tensor_parallel == '1d' else tensor_parallel}
|
||||
if tensor_parallel == '1d':
|
||||
tensor_parallel_kwargs['gather_output'] = True
|
||||
self.representation = col_nn.Linear(dim,
|
||||
representation_size,
|
||||
bias=bias,
|
||||
dtype=dtype,
|
||||
**_init_rules[init_method]['head'],
|
||||
**tensor_parallel_kwargs)
|
||||
**_init_rules[init_method]['head'])
|
||||
else:
|
||||
self.representation = None
|
||||
representation_size = dim
|
||||
|
@ -249,7 +196,6 @@ class ViTHead(nn.Module):
|
|||
num_classes,
|
||||
dtype=dtype,
|
||||
bias=bias,
|
||||
tensor_parallel=tensor_parallel,
|
||||
**_init_rules[init_method]['head'])
|
||||
|
||||
def forward(self, x):
|
||||
|
@ -273,10 +219,9 @@ class ViTBlock(nn.Module):
|
|||
dtype: dtype = None,
|
||||
bias: bool = True,
|
||||
checkpoint: bool = False,
|
||||
init_method: str = 'torch',
|
||||
tensor_parallel: str = None):
|
||||
init_method: str = 'torch'):
|
||||
super().__init__()
|
||||
self.norm1 = col_nn.LayerNorm(normalized_shape=dim, eps=1e-6, dtype=dtype, tensor_parallel=tensor_parallel)
|
||||
self.norm1 = col_nn.LayerNorm(normalized_shape=dim, eps=1e-6, dtype=dtype)
|
||||
self.attn = ViTSelfAttention(dim=dim,
|
||||
num_heads=num_heads,
|
||||
attention_dropout=attention_dropout,
|
||||
|
@ -284,10 +229,9 @@ class ViTBlock(nn.Module):
|
|||
bias=bias,
|
||||
dtype=dtype,
|
||||
checkpoint=checkpoint,
|
||||
init_method=init_method,
|
||||
tensor_parallel=tensor_parallel)
|
||||
init_method=init_method)
|
||||
self.drop_path = col_nn.DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||
self.norm2 = col_nn.LayerNorm(normalized_shape=dim, eps=1e-6, dtype=dtype, tensor_parallel=tensor_parallel)
|
||||
self.norm2 = col_nn.LayerNorm(normalized_shape=dim, eps=1e-6, dtype=dtype)
|
||||
self.mlp = ViTMLP(dim=dim,
|
||||
mlp_ratio=mlp_ratio,
|
||||
activation=activation,
|
||||
|
@ -295,8 +239,7 @@ class ViTBlock(nn.Module):
|
|||
dtype=dtype,
|
||||
bias=bias,
|
||||
checkpoint=checkpoint,
|
||||
init_method=init_method,
|
||||
tensor_parallel=tensor_parallel)
|
||||
init_method=init_method)
|
||||
|
||||
def forward(self, x):
|
||||
x = x + self.drop_path(self.attn(self.norm1(x)))
|
||||
|
@ -323,20 +266,16 @@ class VisionTransformer(nn.Module):
|
|||
dtype: dtype = None,
|
||||
bias: bool = True,
|
||||
checkpoint: bool = False,
|
||||
init_method: str = 'torch',
|
||||
tensor_parallel: str = None):
|
||||
init_method: str = 'torch'):
|
||||
super().__init__()
|
||||
|
||||
embed = ViTEmbedding(
|
||||
img_size=img_size,
|
||||
embed = ViTEmbedding(img_size=img_size,
|
||||
patch_size=patch_size,
|
||||
in_chans=in_chans,
|
||||
embedding_dim=dim,
|
||||
dropout=dropout,
|
||||
dtype=dtype,
|
||||
init_method=init_method,
|
||||
tensor_parallel=tensor_parallel,
|
||||
)
|
||||
init_method=init_method)
|
||||
|
||||
# stochastic depth decay rule
|
||||
dpr = [x.item() for x in torch.linspace(0, drop_path, depth)]
|
||||
|
@ -353,26 +292,17 @@ class VisionTransformer(nn.Module):
|
|||
bias=bias,
|
||||
checkpoint=checkpoint,
|
||||
init_method=init_method,
|
||||
tensor_parallel=tensor_parallel,
|
||||
) for i in range(depth)
|
||||
]
|
||||
|
||||
norm = col_nn.LayerNorm(
|
||||
normalized_shape=dim,
|
||||
eps=1e-6,
|
||||
dtype=dtype,
|
||||
tensor_parallel=tensor_parallel,
|
||||
)
|
||||
norm = col_nn.LayerNorm(normalized_shape=dim, eps=1e-6, dtype=dtype)
|
||||
|
||||
head = ViTHead(
|
||||
dim=dim,
|
||||
head = ViTHead(dim=dim,
|
||||
num_classes=num_classes,
|
||||
representation_size=representation_size,
|
||||
dtype=dtype,
|
||||
bias=bias,
|
||||
init_method=init_method,
|
||||
tensor_parallel=tensor_parallel,
|
||||
)
|
||||
init_method=init_method)
|
||||
|
||||
self.layers = nn.Sequential(
|
||||
embed,
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
import time
|
||||
from functools import partial
|
||||
|
||||
import pytest
|
||||
|
@ -9,7 +8,7 @@ from colossalai.communication import all_gather, all_reduce, reduce_scatter
|
|||
from colossalai.context import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.utils import free_port, get_current_device
|
||||
|
||||
CONFIG = dict(parallel=dict(data=8, pipeline=1, tensor=dict(mode=None, size=1)))
|
||||
|
||||
|
@ -49,8 +48,8 @@ def check_all_reduce():
|
|||
torch.cuda.synchronize()
|
||||
|
||||
|
||||
def check_layer(rank, world_size):
|
||||
launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=30010, backend='nccl')
|
||||
def check_layer(rank, world_size, port):
|
||||
launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
|
||||
assert dist.get_rank() == gpc.get_global_rank()
|
||||
print('Rank {} / {}'.format(dist.get_rank(), dist.get_world_size()))
|
||||
|
@ -66,7 +65,7 @@ def check_layer(rank, world_size):
|
|||
@pytest.mark.dist
|
||||
def test_comm():
|
||||
world_size = 4
|
||||
run_func = partial(check_layer, world_size=world_size)
|
||||
run_func = partial(check_layer, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
|
|
|
@ -1,15 +1,16 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
from colossalai import launch
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from colossalai.utils import free_port
|
||||
|
||||
CONFIG_PATH = Path(__file__).parent.joinpath('configs/parallel_2d_init.py').absolute()
|
||||
|
||||
|
@ -87,7 +88,7 @@ def test_2d_init():
|
|||
test_fn = partial(init_2d,
|
||||
world_size=world_size,
|
||||
backend='gloo',
|
||||
port='29900',
|
||||
port=free_port(),
|
||||
host='localhost'
|
||||
)
|
||||
mp.spawn(test_fn, nprocs=world_size)
|
||||
|
|
|
@ -7,10 +7,10 @@ from pathlib import Path
|
|||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.utils import free_port
|
||||
|
||||
CONFIG_PATH = Path(__file__).parent.joinpath('configs/parallel_2p5d_init.py').absolute()
|
||||
|
||||
|
@ -111,7 +111,7 @@ def test_2halfd_init():
|
|||
test_fn = partial(init_2halfd,
|
||||
world_size=world_size,
|
||||
backend='gloo',
|
||||
port='29901',
|
||||
port=free_port(),
|
||||
host='localhost'
|
||||
)
|
||||
mp.spawn(test_fn, nprocs=world_size)
|
||||
|
|
|
@ -7,11 +7,10 @@ from pathlib import Path
|
|||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.utils import free_port
|
||||
|
||||
CONFIG_PATH = Path(__file__).parent.joinpath('configs/parallel_3d_init.py').absolute()
|
||||
|
||||
|
@ -104,7 +103,7 @@ def test_3d_init():
|
|||
test_fn = partial(init_3d,
|
||||
world_size=world_size,
|
||||
backend='gloo',
|
||||
port='29902',
|
||||
port=free_port(),
|
||||
host='localhost'
|
||||
)
|
||||
mp.spawn(test_fn, nprocs=world_size)
|
||||
|
|
|
@ -13,7 +13,7 @@ from colossalai.logging import get_dist_logger
|
|||
from colossalai.nn import Accuracy, LinearWarmupLR
|
||||
from colossalai.nn.loss import CrossEntropyLoss
|
||||
from colossalai.trainer import Trainer, hooks
|
||||
from colossalai.utils import MultiTimer, get_dataloader
|
||||
from colossalai.utils import MultiTimer, free_port, get_dataloader
|
||||
from colossalai.utils.gradient_accumulation import GradAccumLrSchedulerByStep
|
||||
from model_zoo.vit import vit_tiny_patch4_32
|
||||
from torchvision import transforms
|
||||
|
@ -27,12 +27,12 @@ CONFIG = dict(parallel=dict(pipeline=2, tensor=dict(size=2, mode='1d')),
|
|||
gradient_accumulation=2)
|
||||
|
||||
|
||||
def run_trainer(rank, world_size):
|
||||
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=30000, backend='nccl')
|
||||
def run_trainer(rank, world_size, port):
|
||||
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
|
||||
logger = get_dist_logger()
|
||||
|
||||
model = vit_tiny_patch4_32(tensor_parallel='1d')
|
||||
model = vit_tiny_patch4_32()
|
||||
pipe_model = build_pipeline_model(model.layers, num_chunks=1)
|
||||
|
||||
# build dataloaders
|
||||
|
@ -54,7 +54,7 @@ def run_trainer(rank, world_size):
|
|||
test_dataloader = get_dataloader(dataset=test_dataset, batch_size=BATCH_SIZE, pin_memory=True)
|
||||
|
||||
# build criterion
|
||||
criterion = CrossEntropyLoss(tensor_parallel='1d')
|
||||
criterion = CrossEntropyLoss()
|
||||
|
||||
# optimizer
|
||||
optimizer = torch.optim.Adam(pipe_model.parameters(), lr=0.001, weight_decay=0)
|
||||
|
@ -78,7 +78,6 @@ def run_trainer(rank, world_size):
|
|||
hook_list = [
|
||||
hooks.LossHook(),
|
||||
hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=False),
|
||||
hooks.AccuracyHook(accuracy_func=Accuracy(tensor_parallel='1d')),
|
||||
hooks.LogMetricByEpochHook(logger),
|
||||
]
|
||||
|
||||
|
@ -95,7 +94,7 @@ def run_trainer(rank, world_size):
|
|||
# @pytest.mark.skip("This test requires more than 8 GPUs, you should invoke this test script using test.sh provided manually")
|
||||
def test_hybrid_parallel():
|
||||
world_size = 8
|
||||
run_func = partial(run_trainer, world_size=world_size)
|
||||
run_func = partial(run_trainer, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
|
|
|
@ -1,25 +1,23 @@
|
|||
# !/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import colossalai
|
||||
import os
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
|
||||
import colossalai
|
||||
import pytest
|
||||
import torch
|
||||
import os.path as osp
|
||||
from pathlib import Path
|
||||
import torch.nn as nn
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
from torchvision import transforms
|
||||
from torch.optim import Adam
|
||||
from colossalai.core import global_context as gpc
|
||||
import torch.nn as nn
|
||||
from colossalai.amp import AMP_TYPE
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.utils import report_memory_usage, get_dataloader
|
||||
from torchvision.models import resnet18
|
||||
from colossalai.utils import free_port, get_dataloader, report_memory_usage
|
||||
from torch.optim import Adam
|
||||
from torchvision import transforms
|
||||
from torchvision.datasets import CIFAR10
|
||||
from functools import partial
|
||||
|
||||
from torchvision.models import resnet18
|
||||
|
||||
# Config
|
||||
BATCH_SIZE = 128
|
||||
|
@ -38,14 +36,14 @@ CONFIG = dict(
|
|||
)
|
||||
|
||||
|
||||
def run_engine(rank, world_size):
|
||||
def run_engine(rank, world_size, port):
|
||||
# init dist env
|
||||
colossalai.launch(
|
||||
config=CONFIG,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
host='localhost',
|
||||
port=29910,
|
||||
port=port,
|
||||
backend='nccl'
|
||||
)
|
||||
|
||||
|
@ -104,7 +102,7 @@ def run_engine(rank, world_size):
|
|||
@pytest.mark.dist
|
||||
def test_engine():
|
||||
world_size = 4
|
||||
run_func = partial(run_engine, world_size=world_size)
|
||||
run_func = partial(run_engine, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
|
|
|
@ -1,23 +1,20 @@
|
|||
import colossalai
|
||||
import os
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
|
||||
import colossalai
|
||||
import pytest
|
||||
import torch
|
||||
import os.path as osp
|
||||
from pathlib import Path
|
||||
import torch.nn as nn
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
from torchvision import transforms
|
||||
from torch.optim import Adam
|
||||
from colossalai.core import global_context as gpc
|
||||
import torch.nn as nn
|
||||
from colossalai.amp import AMP_TYPE
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.utils import report_memory_usage, get_dataloader
|
||||
from colossalai.initialize import get_default_parser
|
||||
from torchvision.models import resnet18
|
||||
from colossalai.utils import free_port, get_dataloader, report_memory_usage
|
||||
from torch.optim import Adam
|
||||
from torchvision import transforms
|
||||
from torchvision.datasets import CIFAR10
|
||||
from functools import partial
|
||||
|
||||
from torchvision.models import resnet18
|
||||
|
||||
# Config
|
||||
BATCH_SIZE = 128
|
||||
|
@ -38,14 +35,14 @@ CONFIG = dict(
|
|||
)
|
||||
|
||||
|
||||
def run_engine(rank, world_size):
|
||||
def run_engine(rank, world_size, port):
|
||||
# init dist env
|
||||
colossalai.launch(
|
||||
config=CONFIG,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
host='localhost',
|
||||
port=29911,
|
||||
port=port,
|
||||
backend='nccl'
|
||||
)
|
||||
|
||||
|
@ -104,7 +101,7 @@ def run_engine(rank, world_size):
|
|||
@pytest.mark.dist
|
||||
def test_engine():
|
||||
world_size = 4
|
||||
run_func = partial(run_engine, world_size=world_size)
|
||||
run_func = partial(run_engine, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
|
|
|
@ -1,23 +1,19 @@
|
|||
import colossalai
|
||||
import os
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
|
||||
import colossalai
|
||||
import pytest
|
||||
import torch
|
||||
import os.path as osp
|
||||
from pathlib import Path
|
||||
import torch.nn as nn
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
from torchvision import transforms
|
||||
from torch.optim import Adam
|
||||
import torch.nn as nn
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.amp import AMP_TYPE
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.utils import report_memory_usage, get_dataloader
|
||||
from colossalai.initialize import get_default_parser
|
||||
from torchvision.models import resnet18
|
||||
from colossalai.utils import free_port, get_dataloader, report_memory_usage
|
||||
from torch.optim import Adam
|
||||
from torchvision import transforms
|
||||
from torchvision.datasets import CIFAR10
|
||||
from functools import partial
|
||||
|
||||
from torchvision.models import resnet18
|
||||
|
||||
# Config
|
||||
BATCH_SIZE = 128
|
||||
|
@ -35,14 +31,14 @@ CONFIG = dict(
|
|||
)
|
||||
|
||||
|
||||
def run_engine(rank, world_size):
|
||||
def run_engine(rank, world_size, port):
|
||||
# init dist env
|
||||
colossalai.launch(
|
||||
config=CONFIG,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
host='localhost',
|
||||
port=29912,
|
||||
port=port,
|
||||
backend='nccl'
|
||||
)
|
||||
|
||||
|
@ -101,7 +97,7 @@ def run_engine(rank, world_size):
|
|||
@pytest.mark.dist
|
||||
def test_engine():
|
||||
world_size = 4
|
||||
run_func = partial(run_engine, world_size=world_size)
|
||||
run_func = partial(run_engine, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
|
|
|
@ -1,23 +1,20 @@
|
|||
import colossalai
|
||||
import os
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
|
||||
import colossalai
|
||||
import pytest
|
||||
import torch
|
||||
import os.path as osp
|
||||
from pathlib import Path
|
||||
import torch.nn as nn
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
from torchvision import transforms
|
||||
from torch.optim import Adam
|
||||
from colossalai.core import global_context as gpc
|
||||
import torch.nn as nn
|
||||
from colossalai.amp import AMP_TYPE
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.utils import report_memory_usage, get_dataloader
|
||||
from colossalai.initialize import get_default_parser
|
||||
from torchvision.models import resnet18
|
||||
from colossalai.utils import free_port, get_dataloader, report_memory_usage
|
||||
from torch.optim import Adam
|
||||
from torchvision import transforms
|
||||
from torchvision.datasets import CIFAR10
|
||||
from functools import partial
|
||||
|
||||
from torchvision.models import resnet18
|
||||
|
||||
# Config
|
||||
BATCH_SIZE = 128
|
||||
|
@ -36,14 +33,14 @@ CONFIG = dict(
|
|||
)
|
||||
|
||||
|
||||
def run_engine(rank, world_size):
|
||||
def run_engine(rank, world_size, port):
|
||||
# init dist env
|
||||
colossalai.launch(
|
||||
config=CONFIG,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
host='localhost',
|
||||
port=29913,
|
||||
port=port,
|
||||
backend='nccl'
|
||||
)
|
||||
|
||||
|
@ -102,7 +99,7 @@ def run_engine(rank, world_size):
|
|||
@pytest.mark.dist
|
||||
def test_engine():
|
||||
world_size = 4
|
||||
run_func = partial(run_engine, world_size=world_size)
|
||||
run_func = partial(run_engine, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
|
|
|
@ -1,13 +1,15 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.initialize import launch
|
||||
from functools import partial
|
||||
from colossalai.utils import free_port
|
||||
|
||||
from checks_1d.check_layer_1d import *
|
||||
|
||||
CONFIG = dict(
|
||||
|
@ -21,12 +23,12 @@ CONFIG = dict(
|
|||
)
|
||||
|
||||
|
||||
def check_layer(rank, world_size):
|
||||
def check_layer(rank, world_size, port):
|
||||
launch(config=CONFIG,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
host='localhost',
|
||||
port=29920,
|
||||
port=port,
|
||||
backend='nccl')
|
||||
|
||||
check_linear_col()
|
||||
|
@ -39,7 +41,7 @@ def check_layer(rank, world_size):
|
|||
@pytest.mark.dist
|
||||
def test_1d():
|
||||
world_size = 4
|
||||
run_func = partial(check_layer, world_size=world_size)
|
||||
run_func = partial(check_layer, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
|
|
|
@ -1,16 +1,17 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.utils import free_port
|
||||
|
||||
from checks_2d.check_layer_2d import *
|
||||
from checks_2d.check_operation_2d import *
|
||||
from functools import partial
|
||||
|
||||
|
||||
CONFIG = dict(
|
||||
parallel=dict(
|
||||
|
@ -34,12 +35,12 @@ def check_layer():
|
|||
check_layernorm()
|
||||
check_classifier()
|
||||
|
||||
def check_layer_and_operation(rank, world_size):
|
||||
def check_layer_and_operation(rank, world_size, port):
|
||||
launch(config=CONFIG,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
host='localhost',
|
||||
port=29921,
|
||||
port=port,
|
||||
backend='nccl')
|
||||
|
||||
# check_operations()
|
||||
|
@ -51,7 +52,7 @@ def check_layer_and_operation(rank, world_size):
|
|||
@pytest.mark.dist
|
||||
def test_2d():
|
||||
world_size = 4
|
||||
run_func = partial(check_layer_and_operation, world_size=world_size)
|
||||
run_func = partial(check_layer_and_operation, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
|
|
|
@ -1,13 +1,15 @@
|
|||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.initialize import launch
|
||||
from checks_2p5d.check_layer_2p5d import check_linear, check_layernorm, check_classifier
|
||||
from checks_2p5d.check_operation_2p5d import check_AB, check_ABT, check_ATB
|
||||
from functools import partial
|
||||
from colossalai.utils import free_port
|
||||
|
||||
from checks_2p5d.check_layer_2p5d import (check_classifier, check_layernorm,
|
||||
check_linear)
|
||||
from checks_2p5d.check_operation_2p5d import check_AB, check_ABT, check_ATB
|
||||
|
||||
CONFIG = dict(
|
||||
parallel=dict(
|
||||
|
@ -29,12 +31,12 @@ def check_layer():
|
|||
check_classifier()
|
||||
|
||||
|
||||
def check_layer_and_operation(rank, world_size):
|
||||
def check_layer_and_operation(rank, world_size, port):
|
||||
launch(config=CONFIG,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
host='localhost',
|
||||
port=29922,
|
||||
port=port,
|
||||
backend='nccl')
|
||||
|
||||
check_operations()
|
||||
|
@ -46,7 +48,7 @@ def check_layer_and_operation(rank, world_size):
|
|||
@pytest.mark.dist
|
||||
def test_2p5d():
|
||||
world_size = 4
|
||||
run_func = partial(check_layer_and_operation, world_size=world_size)
|
||||
run_func = partial(check_layer_and_operation, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
|
|
|
@ -7,6 +7,7 @@ import torch
|
|||
import torch.multiprocessing as mp
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.utils import free_port
|
||||
|
||||
from checks_3d.check_layer_3d import *
|
||||
|
||||
|
@ -27,8 +28,8 @@ def check_layer():
|
|||
# check_loss()
|
||||
|
||||
|
||||
def check_layer_and_operation(rank, world_size):
|
||||
launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=29923, backend='nccl')
|
||||
def check_layer_and_operation(rank, world_size, port):
|
||||
launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
check_layer()
|
||||
gpc.destroy()
|
||||
torch.cuda.empty_cache()
|
||||
|
@ -37,7 +38,7 @@ def check_layer_and_operation(rank, world_size):
|
|||
@pytest.mark.dist
|
||||
def test_3d():
|
||||
world_size = 8
|
||||
run_func = partial(check_layer_and_operation, world_size=world_size)
|
||||
run_func = partial(check_layer_and_operation, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
|
|
|
@ -4,10 +4,11 @@
|
|||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
from colossalai.initialize import launch, get_default_parser
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.logging import get_dist_logger
|
||||
from checks_seq.check_layer_seq import *
|
||||
from functools import partial
|
||||
from colossalai.utils import free_port
|
||||
|
||||
|
||||
CONFIG = dict(
|
||||
|
@ -22,13 +23,13 @@ def check_layer():
|
|||
check_selfattention()
|
||||
|
||||
|
||||
def run_check_sequence(rank, world_size):
|
||||
def run_check_sequence(rank, world_size, port):
|
||||
# init dist
|
||||
launch(config=CONFIG,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
host='localhost',
|
||||
port=29924,
|
||||
port=port,
|
||||
backend='nccl')
|
||||
logger = get_dist_logger()
|
||||
logger.info('Distributed environment is initialzied.', ranks=[0])
|
||||
|
@ -41,7 +42,7 @@ def run_check_sequence(rank, world_size):
|
|||
@pytest.mark.dist
|
||||
def test_sequence():
|
||||
world_size = 4
|
||||
run_func = partial(run_check_sequence, world_size=world_size)
|
||||
run_func = partial(run_check_sequence, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import os
|
||||
import model
|
||||
from pathlib import Path
|
||||
|
||||
BATCH_SIZE = 128
|
||||
|
|
|
@ -1,11 +1,12 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
from colossalai.communication import (recv_backward, recv_forward,
|
||||
recv_tensor_meta, send_backward,
|
||||
send_backward_recv_forward, send_forward,
|
||||
|
@ -15,8 +16,7 @@ from colossalai.context.parallel_mode import ParallelMode
|
|||
from colossalai.core import global_context as gpc
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.utils import get_current_device
|
||||
from functools import partial
|
||||
from colossalai.utils import free_port, get_current_device
|
||||
|
||||
BATCH_SIZE = 16
|
||||
SEQ_LENGTH = 64
|
||||
|
@ -123,13 +123,13 @@ def check_comm(size, rank, prev_rank, next_rank, up_group, down_group, logger):
|
|||
check_forward_backward(tensor, grad, rank, logger)
|
||||
|
||||
|
||||
def run_check(rank, world_size):
|
||||
def run_check(rank, world_size, port):
|
||||
launch(
|
||||
config=CONFIG,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
host='localhost',
|
||||
port=29932,
|
||||
port=port,
|
||||
backend='nccl'
|
||||
)
|
||||
logger = get_dist_logger()
|
||||
|
@ -154,7 +154,7 @@ def run_check(rank, world_size):
|
|||
@pytest.mark.dist
|
||||
def test_p2p():
|
||||
world_size = 4
|
||||
run_func = partial(run_check, world_size=world_size)
|
||||
run_func = partial(run_check, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
|
|
|
@ -3,25 +3,24 @@ import os.path as osp
|
|||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from colossalai.builder.pipeline import build_pipeline_model_from_cfg
|
||||
from colossalai.core import global_context
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.logging import get_dist_logger
|
||||
from functools import partial
|
||||
import model
|
||||
from colossalai.utils import free_port
|
||||
|
||||
DIR_PATH = osp.dirname(osp.realpath(__file__))
|
||||
CONFIG_PATH = osp.join(DIR_PATH, 'resnet_config.py')
|
||||
|
||||
|
||||
def run_partition(rank, world_size):
|
||||
def run_partition(rank, world_size, port):
|
||||
launch(config=CONFIG_PATH,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
host='localhost',
|
||||
port=29933,
|
||||
port=port,
|
||||
backend='nccl'
|
||||
)
|
||||
logger = get_dist_logger()
|
||||
|
@ -40,7 +39,7 @@ def run_partition(rank, world_size):
|
|||
@pytest.mark.dist
|
||||
def test_partition():
|
||||
world_size = 4
|
||||
run_func = partial(run_partition, world_size=world_size)
|
||||
run_func = partial(run_partition, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
|
|
|
@ -1,26 +1,23 @@
|
|||
# referenced from Megatron and used to testify communication
|
||||
|
||||
import colossalai
|
||||
import os
|
||||
import os.path as osp
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
|
||||
import colossalai
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import model
|
||||
|
||||
from colossalai.builder import build_pipeline_model_from_cfg
|
||||
from colossalai.communication import p2p as p2p_communication
|
||||
from colossalai.communication.utils import send_tensor_meta, recv_tensor_meta
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.utils import print_rank_0, get_current_device, get_dataloader
|
||||
from colossalai.engine.schedule import PipelineSchedule
|
||||
from torchvision.datasets import CIFAR10
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.utils import free_port, get_dataloader, print_rank_0
|
||||
from torchvision import transforms
|
||||
from pathlib import Path
|
||||
from functools import partial
|
||||
from torchvision.datasets import CIFAR10
|
||||
|
||||
import model
|
||||
|
||||
BATCH_SIZE = 32
|
||||
NUM_MICRO = 8
|
||||
|
@ -30,12 +27,12 @@ DIR_PATH = osp.dirname(osp.realpath(__file__))
|
|||
CONFIG_PATH = osp.join(DIR_PATH, './resnet_config.py')
|
||||
|
||||
|
||||
def run_schedule(rank, world_size):
|
||||
def run_schedule(rank, world_size, port):
|
||||
launch(config=CONFIG_PATH,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
host='localhost',
|
||||
port=29934,
|
||||
port=port,
|
||||
backend='nccl')
|
||||
|
||||
# build model
|
||||
|
@ -86,7 +83,7 @@ def run_schedule(rank, world_size):
|
|||
@pytest.mark.dist
|
||||
def test_pipeline_schedule():
|
||||
world_size = 4
|
||||
run_func = partial(run_schedule, world_size=world_size)
|
||||
run_func = partial(run_schedule, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
|
|
|
@ -11,7 +11,7 @@ from colossalai.amp.amp_type import AMP_TYPE
|
|||
from colossalai.core import global_context as gpc
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.trainer import Trainer
|
||||
from colossalai.utils import MultiTimer, get_dataloader
|
||||
from colossalai.utils import MultiTimer, free_port, get_dataloader
|
||||
from torch.optim import Adam
|
||||
from torchvision import transforms
|
||||
from torchvision.datasets import CIFAR10
|
||||
|
@ -26,8 +26,8 @@ CONFIG = dict(
|
|||
fp16=dict(mode=AMP_TYPE.TORCH))
|
||||
|
||||
|
||||
def run_trainer_no_pipeline(rank, world_size):
|
||||
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=29930, backend='nccl')
|
||||
def run_trainer_no_pipeline(rank, world_size, port):
|
||||
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
|
||||
# build model
|
||||
model = resnet18(num_classes=10)
|
||||
|
@ -88,7 +88,7 @@ def run_trainer_no_pipeline(rank, world_size):
|
|||
@pytest.mark.dist
|
||||
def test_trainer_no_pipeline():
|
||||
world_size = 4
|
||||
run_func = partial(run_trainer_no_pipeline, world_size=world_size)
|
||||
run_func = partial(run_trainer_no_pipeline, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
|
|
|
@ -12,7 +12,7 @@ from colossalai.core import global_context as gpc
|
|||
from colossalai.engine.schedule import PipelineSchedule
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.trainer import Trainer
|
||||
from colossalai.utils import MultiTimer, get_dataloader
|
||||
from colossalai.utils import MultiTimer, free_port, get_dataloader
|
||||
from torch.optim import Adam
|
||||
from torchvision import transforms
|
||||
from torchvision.datasets import CIFAR10
|
||||
|
@ -25,8 +25,8 @@ NUM_EPOCHS = 200
|
|||
CONFIG = dict(parallel=dict(pipeline=2, ), )
|
||||
|
||||
|
||||
def run_trainer_with_pipeline(rank, world_size):
|
||||
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=29931, backend='nccl')
|
||||
def run_trainer_with_pipeline(rank, world_size, port):
|
||||
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
|
||||
# build model
|
||||
model = resnet18(num_classes=10)
|
||||
|
@ -99,7 +99,7 @@ def run_trainer_with_pipeline(rank, world_size):
|
|||
@pytest.mark.dist
|
||||
def test_trainer_with_pipeline():
|
||||
world_size = 4
|
||||
run_func = partial(run_trainer_with_pipeline, world_size=world_size)
|
||||
run_func = partial(run_trainer_with_pipeline, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
|
|
|
@ -1,21 +1,19 @@
|
|||
import colossalai
|
||||
import os
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
|
||||
import colossalai
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from torchvision import transforms
|
||||
from torch.optim import Adam
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.utils import report_memory_usage, get_dataloader
|
||||
from colossalai.initialize import get_default_parser
|
||||
from torchvision.models import resnet18
|
||||
from colossalai.utils import free_port, get_dataloader
|
||||
from torch.optim import Adam
|
||||
from torchvision import transforms
|
||||
from torchvision.datasets import CIFAR10
|
||||
|
||||
from torchvision.models import resnet18
|
||||
|
||||
# Config
|
||||
BATCH_SIZE = 16
|
||||
|
@ -32,7 +30,7 @@ CONFIG = dict(
|
|||
)
|
||||
|
||||
|
||||
def run_no_pipeline(rank, world_size):
|
||||
def run_no_pipeline(rank, world_size, port):
|
||||
|
||||
# init dist env
|
||||
colossalai.launch(
|
||||
|
@ -40,7 +38,7 @@ def run_no_pipeline(rank, world_size):
|
|||
rank=rank,
|
||||
world_size=world_size,
|
||||
host='localhost',
|
||||
port=29500,
|
||||
port=port,
|
||||
backend='nccl'
|
||||
)
|
||||
|
||||
|
@ -110,7 +108,7 @@ def run_no_pipeline(rank, world_size):
|
|||
@pytest.mark.dist
|
||||
def test_engine():
|
||||
world_size = 4
|
||||
func = partial(run_no_pipeline, world_size=world_size)
|
||||
func = partial(run_no_pipeline, world_size=world_size, port=free_port())
|
||||
mp.spawn(func, nprocs=world_size)
|
||||
|
||||
|
||||
|
|
|
@ -2,18 +2,18 @@
|
|||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import os
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
|
||||
import colossalai
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.utils import get_dataloader
|
||||
from colossalai.utils import free_port, get_dataloader
|
||||
from torchvision import transforms
|
||||
from torchvision.models import resnet18
|
||||
from torchvision.datasets import CIFAR10
|
||||
from functools import partial
|
||||
from torchvision.models import resnet18
|
||||
|
||||
BATCH_SIZE = 16
|
||||
IMG_SIZE = 224
|
||||
|
@ -34,12 +34,12 @@ CONFIG = dict(
|
|||
)
|
||||
|
||||
|
||||
def run_dist(rank, world_size):
|
||||
def run_dist(rank, world_size, port):
|
||||
colossalai.launch(config=CONFIG,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
host='localhost',
|
||||
port=29940,
|
||||
port=port,
|
||||
backend='nccl')
|
||||
|
||||
# build model
|
||||
|
@ -94,7 +94,7 @@ def run_dist(rank, world_size):
|
|||
@pytest.mark.dist
|
||||
def test_zero_level_2():
|
||||
world_size = 4
|
||||
run_func = partial(run_dist, world_size=world_size)
|
||||
run_func = partial(run_dist, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
|
|
|
@ -2,18 +2,18 @@
|
|||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import os
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
|
||||
import colossalai
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.utils import get_dataloader
|
||||
from colossalai.utils import free_port, get_dataloader
|
||||
from torchvision import transforms
|
||||
from torchvision.models import resnet18
|
||||
from torchvision.datasets import CIFAR10
|
||||
from functools import partial
|
||||
from torchvision.models import resnet18
|
||||
|
||||
BATCH_SIZE = 16
|
||||
IMG_SIZE = 224
|
||||
|
@ -46,12 +46,12 @@ CONFIG = dict(
|
|||
)
|
||||
|
||||
|
||||
def run_dist(rank, world_size):
|
||||
def run_dist(rank, world_size, port):
|
||||
colossalai.launch(config=CONFIG,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
host='localhost',
|
||||
port=29941,
|
||||
port=port,
|
||||
backend='nccl')
|
||||
|
||||
# build model
|
||||
|
@ -106,7 +106,7 @@ def run_dist(rank, world_size):
|
|||
@pytest.mark.dist
|
||||
def test_zero_level_3():
|
||||
world_size = 4
|
||||
run_func = partial(run_dist, world_size=world_size)
|
||||
run_func = partial(run_dist, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
|
|
|
@ -13,7 +13,7 @@ import torch.multiprocessing as mp
|
|||
from colossalai.core import global_context as gpc
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.nn import CrossEntropyLoss
|
||||
from colossalai.utils import get_dataloader
|
||||
from colossalai.utils import free_port, get_dataloader
|
||||
from model_zoo.vit import vit_lite_depth7_patch4_32
|
||||
from torchvision import transforms
|
||||
from torchvision.datasets import CIFAR10
|
||||
|
@ -40,11 +40,11 @@ def train_epoch(engine, train_dataloader):
|
|||
return avg_loss
|
||||
|
||||
|
||||
def run_2d_parallel_vision_transformer_level_2(rank, world_size):
|
||||
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=29950, backend='nccl')
|
||||
def run_2d_parallel_vision_transformer_level_2(rank, world_size, port):
|
||||
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
|
||||
# build model
|
||||
model = vit_lite_depth7_patch4_32(tensor_parallel='2d')
|
||||
model = vit_lite_depth7_patch4_32()
|
||||
|
||||
# build dataloader# build dataloaders
|
||||
train_dataset = CIFAR10(root=Path(os.environ['DATA']),
|
||||
|
@ -62,7 +62,7 @@ def run_2d_parallel_vision_transformer_level_2(rank, world_size):
|
|||
|
||||
# build optimizer and loss
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
|
||||
criterion = CrossEntropyLoss(tensor_parallel='2d')
|
||||
criterion = CrossEntropyLoss()
|
||||
|
||||
engine, train_dataloader, *args = colossalai.initialize(model=model,
|
||||
optimizer=optimizer,
|
||||
|
@ -90,7 +90,7 @@ def run_2d_parallel_vision_transformer_level_2(rank, world_size):
|
|||
@pytest.mark.dist
|
||||
def test_2d_vit_zero_level_2():
|
||||
world_size = 8
|
||||
run_func = partial(run_2d_parallel_vision_transformer_level_2, world_size=world_size)
|
||||
run_func = partial(run_2d_parallel_vision_transformer_level_2, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
|
|
|
@ -13,7 +13,7 @@ import torch.multiprocessing as mp
|
|||
from colossalai.core import global_context as gpc
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.nn import CrossEntropyLoss
|
||||
from colossalai.utils import get_dataloader
|
||||
from colossalai.utils import free_port, get_dataloader
|
||||
from model_zoo.vit import vit_lite_depth7_patch4_32
|
||||
from torchvision import transforms
|
||||
from torchvision.datasets import CIFAR10
|
||||
|
@ -40,11 +40,11 @@ def train_epoch(engine, train_dataloader):
|
|||
return avg_loss
|
||||
|
||||
|
||||
def run_2d_parallel_vision_transformer_level_3(rank, world_size):
|
||||
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=29951, backend='nccl')
|
||||
def run_2d_parallel_vision_transformer_level_3(rank, world_size, port):
|
||||
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
|
||||
# build model
|
||||
model = vit_lite_depth7_patch4_32(tensor_parallel='2d')
|
||||
model = vit_lite_depth7_patch4_32()
|
||||
|
||||
# build dataloader# build dataloaders
|
||||
train_dataset = CIFAR10(root=Path(os.environ['DATA']),
|
||||
|
@ -62,7 +62,7 @@ def run_2d_parallel_vision_transformer_level_3(rank, world_size):
|
|||
|
||||
# build optimizer and loss
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
|
||||
criterion = CrossEntropyLoss(tensor_parallel='2d')
|
||||
criterion = CrossEntropyLoss()
|
||||
|
||||
engine, train_dataloader, *args = colossalai.initialize(model=model,
|
||||
optimizer=optimizer,
|
||||
|
@ -91,7 +91,7 @@ def run_2d_parallel_vision_transformer_level_3(rank, world_size):
|
|||
@pytest.mark.skip("Level 3 has unknown bug so skip this test for now")
|
||||
def test_3d_vit_zero_level_3():
|
||||
world_size = 8
|
||||
run_func = partial(run_2d_parallel_vision_transformer_level_3, world_size=world_size)
|
||||
run_func = partial(run_2d_parallel_vision_transformer_level_3, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue