mirror of https://github.com/hpcaitech/ColossalAI
Hotfix/Colossalai layers (#92)
* optimized 1d layer apis; reorganized nn.layer modules; fixed tests
* fixed 2.5d runtime issue
* reworked split batch, now called in trainer.schedule.load_batch

Co-authored-by: BoxiangW <45734921+BoxiangW@users.noreply.github.com>

parent 0fedef4f3c
commit 01a80cd86d
@@ -2,7 +2,7 @@ BATCH_SIZE = 512
 LEARNING_RATE = 2e-3
 WEIGHT_DECAY = 3e-2

-TENSOR_PARALLEL_SIZE = 4
+TENSOR_PARALLEL_SIZE = 2
 TENSOR_PARALLEL_MODE = '1d'

 NUM_EPOCHS = 200

@@ -72,13 +72,11 @@ def train_cifar():
 os.mkdir(log_path)
 logger.log_to_file(log_path)

-tp = gpc.config.parallel.tensor.mode
-
-model = vit_lite_depth7_patch4_32(tensor_parallel=tp)
+model = vit_lite_depth7_patch4_32()

 train_dataloader, test_dataloader = build_cifar(gpc.config.BATCH_SIZE // gpc.data_parallel_size)

-criterion = CrossEntropyLoss(label_smoothing=0.1, tensor_parallel=tp)
+criterion = CrossEntropyLoss(label_smoothing=0.1)

 optimizer = torch.optim.AdamW(model.parameters(), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY)

@@ -107,7 +105,7 @@ def train_cifar():
 LogMetricByStepHook(),
 # LogTimingByEpochHook(timer=timer, logger=logger),
 # LogMemoryByEpochHook(logger=logger),
-AccuracyHook(accuracy_func=Accuracy(tensor_parallel=tp)),
+AccuracyHook(accuracy_func=Accuracy()),
 LossHook(),
 ThroughputHook(),
 LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=False)
@@ -4,7 +4,7 @@ TOTAL_BATCH_SIZE = 4096
 LEARNING_RATE = 3e-3
 WEIGHT_DECAY = 0.3

-TENSOR_PARALLEL_SIZE = 4
+TENSOR_PARALLEL_SIZE = 2
 TENSOR_PARALLEL_MODE = '1d'

 NUM_EPOCHS = 300

@@ -159,14 +159,12 @@ def train_imagenet():
 os.mkdir(log_path)
 logger.log_to_file(log_path)

-tp = gpc.config.parallel.tensor.mode
-
-model = vit_small_patch16_224(tensor_parallel=tp, num_classes=100, init_method='jax')
+model = vit_small_patch16_224(num_classes=100, init_method='jax')

 train_dataloader = build_dali_train(gpc.config.BATCH_SIZE // gpc.data_parallel_size)
 test_dataloader = build_dali_test(gpc.config.BATCH_SIZE // gpc.data_parallel_size)

-criterion = CrossEntropyLoss(label_smoothing=0.1, tensor_parallel=tp)
+criterion = CrossEntropyLoss(label_smoothing=0.1)

 optimizer = torch.optim.AdamW(model.parameters(), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY)

@@ -192,7 +190,7 @@ def train_imagenet():
 LogMetricByStepHook(),
 # LogTimingByEpochHook(timer=timer, logger=logger),
 # LogMemoryByEpochHook(logger=logger),
-AccuracyHook(accuracy_func=Accuracy(tensor_parallel=tp)),
+AccuracyHook(accuracy_func=Accuracy()),
 LossHook(),
 ThroughputHook(),
 LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=True)
@@ -4,7 +4,7 @@ TOTAL_BATCH_SIZE = 4096
 LEARNING_RATE = 3e-3
 WEIGHT_DECAY = 0.3

-TENSOR_PARALLEL_SIZE = 4
+TENSOR_PARALLEL_SIZE = 2
 TENSOR_PARALLEL_MODE = '1d'

 NUM_EPOCHS = 300

@@ -159,14 +159,12 @@ def train_imagenet():
 os.mkdir(log_path)
 logger.log_to_file(log_path)

-tp = gpc.config.parallel.tensor.mode
-
-model = vit_small_patch16_224(tensor_parallel=tp, num_classes=1000, init_method='jax')
+model = vit_small_patch16_224(num_classes=1000, init_method='jax')

 train_dataloader = build_dali_train(gpc.config.BATCH_SIZE // gpc.data_parallel_size)
 test_dataloader = build_dali_test(gpc.config.BATCH_SIZE // gpc.data_parallel_size)

-criterion = CrossEntropyLoss(label_smoothing=0.1, tensor_parallel=tp)
+criterion = CrossEntropyLoss(label_smoothing=0.1)

 optimizer = torch.optim.AdamW(model.parameters(), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY)

@@ -192,7 +190,7 @@ def train_imagenet():
 LogMetricByStepHook(),
 # LogTimingByEpochHook(timer=timer, logger=logger),
 # LogMemoryByEpochHook(logger=logger),
-AccuracyHook(accuracy_func=Accuracy(tensor_parallel=tp)),
+AccuracyHook(accuracy_func=Accuracy()),
 LossHook(),
 ThroughputHook(),
 LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=True)
@@ -2,6 +2,7 @@
 # -*- encoding: utf-8 -*-

 ALLOWED_MODES = [None, '1d', '2d', '2.5d', '3d', 'sequence']
+TENSOR_PARALLEL_MODE = 'tensor_parallel_mode'

 # intializer
 INITIALIZER_MAPPING = {

@@ -16,6 +17,9 @@ INITIALIZER_MAPPING = {
 'sequence': 'Initializer_Sequence'
 }

+# 1D parallel
+PARALLEL_INPUT_1D = 'parallel_input_1d'
+
 # 2D paralllel
 SUMMA_DIM = 'SUMMA_DIM'
@@ -1,17 +1,18 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-

+import os
 import random
 from typing import Union

 import numpy as np
 import torch
 import torch.distributed as dist
-from colossalai.constants import ALLOWED_MODES, INITIALIZER_MAPPING
+from colossalai.constants import ALLOWED_MODES, INITIALIZER_MAPPING, TENSOR_PARALLEL_MODE
 from colossalai.context.config import Config
 from colossalai.logging import get_dist_logger
 from colossalai.registry import DIST_GROUP_INITIALIZER

 from .parallel_mode import ParallelMode
 from .random import add_seed, get_seeds, set_mode

@@ -386,6 +387,7 @@ class ParallelContext:
 if parallel_config is not None and 'tensor' in parallel_config and 'mode' in parallel_config['tensor']:
     tensor_parallel_mode = parallel_config['tensor']['mode']
     assert tensor_parallel_mode in ALLOWED_MODES, f"mode in the parallel config must be set to one of {ALLOWED_MODES}"
+    os.environ[TENSOR_PARALLEL_MODE] = str(tensor_parallel_mode)
 self.check_sanity()

 pg_init = []
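The mode written into os.environ here is what the reworked layer modules later in this commit read back through get_tensor_parallel_mode() (imported from ..utils). That helper's implementation is not part of this diff; a minimal sketch of the mechanism, under the assumption that it simply reads the environment variable back, would be:

import os

# Same key as colossalai.constants.TENSOR_PARALLEL_MODE in the hunk above.
TENSOR_PARALLEL_MODE = 'tensor_parallel_mode'


def get_tensor_parallel_mode() -> str:
    # Hypothetical sketch only: ParallelContext stored str(tensor_parallel_mode),
    # so a config value of None surfaces as the string 'None', which is exactly
    # what the new wrappers compare against (e.g. tensor_parallel == 'None').
    return os.environ.get(TENSOR_PARALLEL_MODE, 'None')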
@@ -1,12 +1,13 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
+import os
 import torch.distributed as dist

 from colossalai.context import Config
 from colossalai.registry import DIST_GROUP_INITIALIZER
 from .process_group_initializer import ProcessGroupInitializer
 from ..parallel_mode import ParallelMode
+from colossalai.constants import PARALLEL_INPUT_1D


 @DIST_GROUP_INITIALIZER.register_module

@@ -29,6 +30,7 @@ class Initializer_1D(ProcessGroupInitializer):
 process_group = None
 group_world_size = None
 mode = ParallelMode.PARALLEL_1D
+os.environ[PARALLEL_INPUT_1D] = ''

 for i in range(self.num_group):
 ranks = [i * self.tensor_parallel_size + j for j in range(self.tensor_parallel_size)]
@@ -10,7 +10,7 @@ from typing import Iterable, Union, List, Callable
 from .._base_engine import Engine
 from colossalai.logging import get_dist_logger
 from colossalai.utils import get_current_device
+from colossalai.nn.layer import split_batch

 class BaseSchedule(ABC):
 """A basic helper class to control the process of training or evaluation.

@@ -59,7 +59,11 @@ class BaseSchedule(ABC):
 else:
     data, label = batch_data

-data, label = self._to_list(data), self._to_list(label)
+if isinstance(label, (tuple, list)):
+    self.batch_size = label[0].size(0)
+else:
+    self.batch_size = label.size(0)
+data, label = self._to_list(split_batch(data)), self._to_list(split_batch(label))
 return self._move_to_device(data), self._move_to_device(label)

 def pre_processing(self, engine: Engine):
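A standalone sketch of what the reworked load_batch now does: record the global batch size from the labels, then hand both data and labels to split_batch. Illustrative only; the hypothetical fake_split_batch below stands in for the real split_batch added later in this commit, which looks up the column-group size and rank from the global context instead of taking them as arguments.

import torch

def fake_split_batch(t: torch.Tensor, world_size: int, rank: int) -> torch.Tensor:
    # Mimics split_tensor_2d/_2p5d/_3d: keep this rank's chunk of the batch dim.
    return torch.chunk(t, world_size, dim=0)[rank].contiguous()

data = torch.randn(8, 3, 32, 32)        # global batch of 8 on every rank
label = torch.randint(0, 10, (8,))

batch_size = label.size(0)               # recorded before splitting, as in load_batch
local_data = fake_split_batch(data, world_size=2, rank=0)
local_label = fake_split_batch(label, world_size=2, rank=0)
assert batch_size == 8 and local_data.size(0) == 4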
@@ -1,3 +1,9 @@
 from .colossalai_layer import *
-from .fused_bias_gelu import bias_gelu_impl
+from .parallel_1d import *
+from .parallel_2d import *
+from .parallel_2p5d import *
+from .parallel_3d import *
+from .parallel_sequence import *
+from .utils import *
+from .vanilla import *
 from .wrapper import *
@@ -1,231 +0,0 @@
-import math
-from typing import Callable, Optional
-
-from colossalai.utils import get_current_device
-from torch import dtype, nn
-from torch.nn.modules.activation import *
-from torch.nn.modules.adaptive import *
-from torch.nn.modules.batchnorm import *
-from torch.nn.modules.channelshuffle import *
-from torch.nn.modules.conv import *
-from torch.nn.modules.distance import *
-from torch.nn.modules.dropout import *
-from torch.nn.modules.flatten import *
-from torch.nn.modules.fold import *
-from torch.nn.modules.instancenorm import *
-from torch.nn.modules.linear import *
-from torch.nn.modules.normalization import *
-from torch.nn.modules.padding import *
-from torch.nn.modules.pixelshuffle import *
-from torch.nn.modules.pooling import *
-from torch.nn.modules.rnn import *
-from torch.nn.modules.sparse import *
-from torch.nn.modules.transformer import *
-from torch.nn.modules.upsampling import *
-
-from .. import init as init
-
-from .vanilla import *
-from .parallel_1d import *
-from .parallel_2d import *
-from .parallel_2p5d import *
-from .parallel_3d import *
-from .parallel_sequence import *
-
-_parallel_linear = {'1d_col': Linear1D_Col, '1d_row': Linear1D_Row, '2d': Linear2D, '2.5d': Linear2p5D, '3d': Linear3D}
-
-_parallel_classifier = {
-    None: VanillaClassifier,
-    '1d': VanillaClassifier,
-    '2d': Classifier2D,
-    '2.5d': Classifier2p5D,
-    '3d': Classifier3D
-}
-
-_parallel_layernorm = {'2d': LayerNorm2D, '2.5d': LayerNorm2p5D, '3d': LayerNorm3D}
-
-_parallel_embedding = {'3d': Embedding3D}
-
-_parallel_patchembedding = {
-    None: VanillaPatchEmbedding,
-    '1d': VanillaPatchEmbedding,
-    '2d': PatchEmbedding2D,
-    '2.5d': PatchEmbedding2p5D,
-    '3d': PatchEmbedding3D
-}
-
-
-class Linear(nn.Module):
-    def __init__(self,
-                 in_features: int,
-                 out_features: int,
-                 bias: bool = True,
-                 dtype: dtype = None,
-                 weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
-                 bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
-                 tensor_parallel: Optional[str] = None,
-                 **kwargs) -> None:
-        super().__init__()
-        if tensor_parallel is None:
-            self.layer = nn.Linear(in_features, out_features, bias=bias, device=get_current_device(), dtype=dtype)
-            weight_initializer(self.layer.weight, fan_in=in_features, fan_out=out_features)
-            if bias:
-                bias_initializer(self.layer.bias, fan_in=in_features)
-        else:
-            self.layer = _parallel_linear[tensor_parallel](
-                in_features,
-                out_features,
-                bias=bias,
-                dtype=dtype,
-                weight_initializer=weight_initializer,
-                bias_initializer=bias_initializer,
-                **kwargs,
-            )
-
-    @property
-    def weight(self):
-        return self.layer.weight
-
-    @property
-    def bias(self):
-        return self.layer.bias
-
-    def forward(self, *args):
-        return self.layer(*args)
-
-
-class LayerNorm(nn.Module):
-    def __init__(self, normalized_shape: int, eps=1e-05, dtype=None, tensor_parallel: Optional[str] = None) -> None:
-        super().__init__()
-        if tensor_parallel in [None, '1d']:
-            self.norm = nn.LayerNorm(normalized_shape, eps=eps, device=get_current_device(), dtype=dtype)
-        else:
-            self.norm = _parallel_layernorm[tensor_parallel](normalized_shape, eps=eps, dtype=dtype)
-
-    @property
-    def weight(self):
-        return self.norm.weight
-
-    @property
-    def bias(self):
-        return self.norm.bias
-
-    def forward(self, *args):
-        return self.norm(*args)
-
-
-class Embedding(nn.Module):
-    def __init__(self,
-                 num_embeddings: int,
-                 embedding_dim: int,
-                 padding_idx: int = None,
-                 dtype: dtype = None,
-                 weight_initializer: Callable = init.normal_(),
-                 tensor_parallel: Optional[str] = None,
-                 *args,
-                 **kwargs) -> None:
-        super().__init__()
-        if tensor_parallel in [None, '1d']:
-            self.embed = nn.Embedding(num_embeddings,
-                                      embedding_dim,
-                                      padding_idx=padding_idx,
-                                      device=get_current_device(),
-                                      dtype=dtype,
-                                      *args,
-                                      **kwargs)
-            weight_initializer(self.embed.weight, fan_in=num_embeddings, fan_out=embedding_dim)
-        else:
-            self.embed = _parallel_embedding[tensor_parallel](
-                num_embeddings,
-                embedding_dim,
-                padding_idx=padding_idx,
-                dtype=dtype,
-                weight_initializer=weight_initializer,
-                *args,
-                **kwargs,
-            )
-
-    @property
-    def weight(self):
-        return self.embed.weight
-
-    def forward(self, *args):
-        return self.embed(*args)
-
-
-class PatchEmbedding(nn.Module):
-    def __init__(self,
-                 img_size: int,
-                 patch_size: int,
-                 in_chans: int,
-                 embed_size: int,
-                 dtype: dtype = None,
-                 flatten: bool = True,
-                 weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
-                 bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
-                 position_embed_initializer: Callable = init.zeros_(),
-                 tensor_parallel: Optional[str] = None) -> None:
-        super().__init__()
-        self.embed = _parallel_patchembedding[tensor_parallel](
-            img_size,
-            patch_size,
-            in_chans,
-            embed_size,
-            dtype=dtype,
-            flatten=flatten,
-            weight_initializer=weight_initializer,
-            bias_initializer=bias_initializer,
-            position_embed_initializer=position_embed_initializer,
-        )
-
-    @property
-    def weight(self):
-        return self.embed.weight
-
-    @property
-    def bias(self):
-        return self.embed.bias
-
-    @property
-    def pos_embed(self):
-        return self.embed.pos_embed
-
-    @property
-    def cls_token(self):
-        return self.embed.cls_token
-
-    def forward(self, *args):
-        return self.embed(*args)
-
-
-class Classifier(nn.Module):
-    def __init__(self,
-                 in_features: int,
-                 num_classes: int,
-                 weight: nn.Parameter = None,
-                 bias: bool = True,
-                 dtype: dtype = None,
-                 weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
-                 bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
-                 tensor_parallel: Optional[str] = None) -> None:
-        super().__init__()
-        self.layer = _parallel_classifier[tensor_parallel](
-            in_features,
-            num_classes,
-            weight=weight,
-            bias=bias,
-            dtype=dtype,
-            weight_initializer=weight_initializer,
-            bias_initializer=bias_initializer,
-        )
-
-    @property
-    def weight(self):
-        return self.layer.weight
-
-    @property
-    def bias(self):
-        return self.layer.bias
-
-    def forward(self, *args):
-        return self.layer(*args)
@@ -0,0 +1,7 @@
+from ._utils import split_batch
+from .dropout import Dropout
+from .embedding import Embedding, PatchEmbedding
+from .linear import Classifier, Linear
+from .normalization import LayerNorm
+
+__all__ = ['Linear', 'Classifier', 'Embedding', 'PatchEmbedding', 'LayerNorm', 'Dropout', 'split_batch']
@@ -0,0 +1,19 @@
+from torch import Tensor
+
+from ..parallel_2d._operation import split_tensor_2d
+from ..parallel_2p5d._operation import split_tensor_2p5d
+from ..parallel_3d._operation import split_tensor_3d
+from ..utils import get_tensor_parallel_mode
+
+_parallel_split_batch = {'2d': split_tensor_2d, '2.5d': split_tensor_2p5d, '3d': split_tensor_3d}
+
+
+def split_batch(input_) -> Tensor:
+    tensor_parallel_mode = get_tensor_parallel_mode()
+    if tensor_parallel_mode in _parallel_split_batch:
+        if isinstance(input_, (tuple, list)):
+            return tuple(map(_parallel_split_batch[tensor_parallel_mode], input_))
+        else:
+            return _parallel_split_batch[tensor_parallel_mode](input_)
+    else:
+        return input_
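Illustrative usage of the new dispatcher (assumes colossalai has been launched with a '2d' tensor-parallel config so that get_tensor_parallel_mode() returns '2d' and the 2D process groups exist):

import torch
from colossalai.nn.layer import split_batch

x = torch.randn(16, 3, 32, 32)   # full batch on every rank
x_local = split_batch(x)          # each 2D column rank keeps its slice of dim 0
# Under None, '1d' or 'sequence' modes split_batch is a no-op and returns x unchanged.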
@@ -0,0 +1,23 @@
+from contextlib import nullcontext
+
+import torch.nn as nn
+from colossalai.context import ParallelMode, seed
+from colossalai.utils import conditional_context
+
+from ..parallel_1d import *
+from ..utils import get_tensor_parallel_mode
+
+
+class Dropout(nn.Module):
+    def __init__(self, p: float = 0.5, inplace: bool = False) -> None:
+        super().__init__()
+        self.tensor_parallel = get_tensor_parallel_mode()
+        if self.tensor_parallel == '1d':
+            self.drop = Dropout1D(p, inplace)
+        else:
+            self.drop = nn.Dropout(p, inplace)
+
+    def forward(self, *args):
+        cm = nullcontext() if self.tensor_parallel in ['None', '1d'] else seed(ParallelMode.TENSOR)
+        with cm:
+            return self.drop(*args)
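The wrapper above delegates to Dropout1D only for the '1d' mode; other parallel modes run torch's dropout inside the TENSOR seed context so the masks are drawn from the tensor-parallel RNG state. A usage sketch, illustrative only and assuming colossalai.launch has already initialized the parallel context so that get_tensor_parallel_mode() resolves:

import torch
from colossalai.nn.layer import Dropout   # re-exported via the colossalai_layer package above

drop = Dropout(p=0.1)                      # Dropout1D under '1d', nn.Dropout otherwise
y = drop(torch.randn(4, 8))                # non-1D modes execute inside seed(ParallelMode.TENSOR)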
@@ -0,0 +1,107 @@
+import math
+from typing import Callable, Optional
+
+from colossalai.utils import get_current_device
+from torch import dtype, nn
+
+from ... import init as init
+from ..parallel_1d import *
+from ..parallel_2d import *
+from ..parallel_2p5d import *
+from ..parallel_3d import *
+from ..utils import get_tensor_parallel_mode
+from ..vanilla import *
+
+_parallel_embedding = {'1d': Embedding1D, '2d': Embedding2D, '2.5d': Embedding2p5D, '3d': Embedding3D}
+
+_parallel_patchembedding = {
+    'None': VanillaPatchEmbedding,
+    '1d': VanillaPatchEmbedding,
+    '2d': PatchEmbedding2D,
+    '2.5d': PatchEmbedding2p5D,
+    '3d': PatchEmbedding3D
+}
+
+
+class Embedding(nn.Module):
+    def __init__(self,
+                 num_embeddings: int,
+                 embedding_dim: int,
+                 padding_idx: int = None,
+                 dtype: dtype = None,
+                 weight_initializer: Callable = init.normal_(),
+                 *args,
+                 **kwargs) -> None:
+        super().__init__()
+        tensor_parallel = get_tensor_parallel_mode()
+        if tensor_parallel == 'None':
+            self.embed = nn.Embedding(num_embeddings,
+                                      embedding_dim,
+                                      padding_idx=padding_idx,
+                                      device=get_current_device(),
+                                      dtype=dtype,
+                                      *args,
+                                      **kwargs)
+            weight_initializer(self.embed.weight, fan_in=num_embeddings, fan_out=embedding_dim)
+        else:
+            self.embed = _parallel_embedding[tensor_parallel](
+                num_embeddings,
+                embedding_dim,
+                padding_idx=padding_idx,
+                dtype=dtype,
+                weight_initializer=weight_initializer,
+                *args,
+                **kwargs,
+            )
+
+    @property
+    def weight(self):
+        return self.embed.weight
+
+    def forward(self, *args):
+        return self.embed(*args)
+
+
+class PatchEmbedding(nn.Module):
+    def __init__(self,
+                 img_size: int,
+                 patch_size: int,
+                 in_chans: int,
+                 embed_size: int,
+                 dtype: dtype = None,
+                 flatten: bool = True,
+                 weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
+                 bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
+                 position_embed_initializer: Callable = init.zeros_()) -> None:
+        super().__init__()
+        tensor_parallel = get_tensor_parallel_mode()
+        self.embed = _parallel_patchembedding[tensor_parallel](
+            img_size,
+            patch_size,
+            in_chans,
+            embed_size,
+            dtype=dtype,
+            flatten=flatten,
+            weight_initializer=weight_initializer,
+            bias_initializer=bias_initializer,
+            position_embed_initializer=position_embed_initializer,
+        )
+
+    @property
+    def weight(self):
+        return self.embed.weight
+
+    @property
+    def bias(self):
+        return self.embed.bias
+
+    @property
+    def pos_embed(self):
+        return self.embed.pos_embed
+
+    @property
+    def cls_token(self):
+        return self.embed.cls_token
+
+    def forward(self, *args):
+        return self.embed(*args)
@@ -0,0 +1,97 @@
+import math
+from typing import Callable, Optional
+
+from colossalai.nn.layer.parallel_1d.layers import Classifier1D
+from colossalai.utils import get_current_device
+from torch import dtype, nn
+
+from ... import init as init
+from ..parallel_1d import *
+from ..parallel_2d import *
+from ..parallel_2p5d import *
+from ..parallel_3d import *
+from ..utils import get_tensor_parallel_mode
+from ..vanilla import *
+
+_parallel_linear = {'1d': Linear1D, '2d': Linear2D, '2.5d': Linear2p5D, '3d': Linear3D}
+
+_parallel_classifier = {
+    'None': VanillaClassifier,
+    '1d': Classifier1D,
+    '2d': Classifier2D,
+    '2.5d': Classifier2p5D,
+    '3d': Classifier3D
+}
+
+
+class Linear(nn.Module):
+    def __init__(self,
+                 in_features: int,
+                 out_features: int,
+                 bias: bool = True,
+                 dtype: dtype = None,
+                 weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
+                 bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
+                 **kwargs) -> None:
+        super().__init__()
+        tensor_parallel = get_tensor_parallel_mode()
+        if tensor_parallel == 'None':
+            self.layer = nn.Linear(in_features, out_features, bias=bias, device=get_current_device(), dtype=dtype)
+            weight_initializer(self.layer.weight, fan_in=in_features, fan_out=out_features)
+            if bias:
+                bias_initializer(self.layer.bias, fan_in=in_features)
+        else:
+            self.layer = _parallel_linear[tensor_parallel](
+                in_features,
+                out_features,
+                bias=bias,
+                dtype=dtype,
+                weight_initializer=weight_initializer,
+                bias_initializer=bias_initializer,
+                **kwargs,
+            )
+
+    @property
+    def weight(self):
+        return self.layer.weight
+
+    @property
+    def bias(self):
+        return self.layer.bias
+
+    def forward(self, *args):
+        return self.layer(*args)
+
+
+class Classifier(nn.Module):
+    def __init__(
+        self,
+        in_features: int,
+        num_classes: int,
+        weight: nn.Parameter = None,
+        bias: bool = True,
+        dtype: dtype = None,
+        weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
+        bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)
+    ) -> None:
+        super().__init__()
+        self.layer = _parallel_classifier[get_tensor_parallel_mode()](
+            in_features,
+            num_classes,
+            weight=weight,
+            bias=bias,
+            dtype=dtype,
+            weight_initializer=weight_initializer,
+            bias_initializer=bias_initializer,
+        )
+
+    @property
+    def weight(self):
+        return self.layer.weight
+
+    @property
+    def bias(self):
+        return self.layer.bias
+
+    def forward(self, *args):
+        return self.layer(*args)
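Illustrative usage of the rewritten wrappers (assumes an initialized parallel context; the point is that model code no longer passes tensor_parallel= explicitly, matching the example-script changes earlier in this commit):

from colossalai.nn.layer import Linear, Classifier

fc = Linear(768, 3072)          # Linear1D / Linear2D / ... chosen from the stored mode
head = Classifier(768, 1000)    # Classifier1D / VanillaClassifier / ... likewise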
@@ -0,0 +1,35 @@
+from typing import Optional
+
+from colossalai.utils import get_current_device
+from torch import nn
+
+from ... import init as init
+from ..parallel_1d import *
+from ..parallel_2d import *
+from ..parallel_2p5d import *
+from ..parallel_3d import *
+from ..utils import get_tensor_parallel_mode
+from ..vanilla import *
+
+_parallel_layernorm = {'2d': LayerNorm2D, '2.5d': LayerNorm2p5D, '3d': LayerNorm3D}
+
+
+class LayerNorm(nn.Module):
+    def __init__(self, normalized_shape: int, eps=1e-05, dtype=None) -> None:
+        super().__init__()
+        tensor_parallel = get_tensor_parallel_mode()
+        if tensor_parallel in ['None', '1d']:
+            self.norm = nn.LayerNorm(normalized_shape, eps=eps, device=get_current_device(), dtype=dtype)
+        else:
+            self.norm = _parallel_layernorm[tensor_parallel](normalized_shape, eps=eps, dtype=dtype)
+
+    @property
+    def weight(self):
+        return self.norm.weight
+
+    @property
+    def bias(self):
+        return self.norm.bias
+
+    def forward(self, *args):
+        return self.norm(*args)
@@ -1,35 +0,0 @@
-# adapted from Megatron-LM
-# https://github.com/NVIDIA/Megatron-LM/blob/b31e1296354e979722627a6c4dedafe19b51fa97/megatron/model/fused_bias_gelu.py
-
-import torch
-
-
-@torch.jit.script
-def bias_gelu(bias, y):
-    x = bias + y
-    return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
-
-
-# gradient of tanh approximation of gelu
-# gradient of actual gelu is:
-# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
-@torch.jit.script
-def bias_gelu_back(g, bias, y):
-    x = bias + y
-    tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
-    # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
-    ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
-    return ff*g
-
-
-class GeLUFunction(torch.autograd.Function):
-    @staticmethod
-    # bias is an optional argument
-    def forward(ctx, input, bias):
-        ctx.save_for_backward(input, bias)
-        return bias_gelu(bias, input)
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        input, bias = ctx.saved_tensors
-        tmp = bias_gelu_back(grad_output, bias, input)
-        return tmp, tmp
-
-
-bias_gelu_impl = GeLUFunction.apply
@@ -1,4 +1,4 @@
-from .layers import Linear1D_Col, Linear1D_Row
+from .layers import Dropout1D, Embedding1D, Linear1D, Linear1D_Col, Linear1D_Row
 from .layers import MixedFusedLayerNorm1D as LayerNorm1D

-__all__ = ['Linear1D_Col', 'Linear1D_Row', 'LayerNorm1D']
+__all__ = ['Linear1D', 'Linear1D_Col', 'Linear1D_Row', 'LayerNorm1D', 'Embedding1D', 'Dropout1D']
@@ -1,12 +1,21 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
+import os

 import torch
 import torch.distributed as dist
+from colossalai.constants import PARALLEL_INPUT_1D
 from colossalai.core import global_context as gpc

-from .._common_utils import divide
+from ..utils import divide


+def set_parallel_input(input_parallel: bool):
+    os.environ[PARALLEL_INPUT_1D] = 'true' if input_parallel else ''
+
+
+def get_parallel_input():
+    return bool(os.environ[PARALLEL_INPUT_1D])
+
+
 def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank):
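A minimal sketch of the flag handshake these helpers implement (plain os.environ manipulation, no distributed setup required; the key matches colossalai.constants.PARALLEL_INPUT_1D added above):

import os

PARALLEL_INPUT_1D = 'parallel_input_1d'

os.environ[PARALLEL_INPUT_1D] = ''              # initializer_1d starts with it cleared
assert not bool(os.environ[PARALLEL_INPUT_1D])  # -> the next Linear1D is built column-parallel

os.environ[PARALLEL_INPUT_1D] = 'true'          # what set_parallel_input(True) writes
assert bool(os.environ[PARALLEL_INPUT_1D])      # -> the next Linear1D is built row-parallel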
@@ -3,10 +3,10 @@

 import math
 import numbers
+from contextlib import nullcontext
 from typing import Callable, Tuple

 import torch
-import torch.distributed as dist
 import torch.nn.functional as F
 from colossalai.communication import broadcast
 from colossalai.context import ParallelMode, seed
@@ -14,13 +14,122 @@ from colossalai.core import global_context as gpc
 from colossalai.nn import init as init
 from colossalai.registry import LAYERS
 from colossalai.utils import get_current_device
-from torch import Tensor
+from torch import Tensor, dtype
 from torch.nn.parameter import Parameter

-from .._common_utils import divide, set_tensor_parallel_attribute_by_partition
 from ..base_layer import ParallelLayer
+from ..utils import divide, set_tensor_parallel_attribute_by_partition
 from ._operation import FusedLayerNormAffineFunction1D
-from ._utils import (gather_forward_split_backward, reduce_grad, reduce_input, split_forward_gather_backward)
+from ._utils import (gather_forward_split_backward, get_parallel_input, reduce_grad, reduce_input, set_parallel_input,
+                     split_forward_gather_backward)
+
+
+@LAYERS.register_module
+class Linear1D(torch.nn.Module):
+    def __init__(self,
+                 in_features: int,
+                 out_features: int,
+                 bias: bool = True,
+                 dtype: torch.dtype = None,
+                 gather_output: bool = False,
+                 skip_bias_add: bool = False,
+                 weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
+                 bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
+        super().__init__()
+        parallel_input = get_parallel_input()
+        if not parallel_input:
+            self.layer = Linear1D_Col(in_features,
+                                      out_features,
+                                      bias=bias,
+                                      dtype=dtype,
+                                      gather_output=gather_output,
+                                      skip_bias_add=skip_bias_add,
+                                      weight_initializer=weight_initializer,
+                                      bias_initializer=bias_initializer)
+        else:
+            self.layer = Linear1D_Row(in_features,
+                                      out_features,
+                                      bias=bias,
+                                      dtype=dtype,
+                                      parallel_input=parallel_input,
+                                      skip_bias_add=skip_bias_add,
+                                      weight_initializer=weight_initializer,
+                                      bias_initializer=bias_initializer)
+
+    @property
+    def weight(self):
+        return self.layer.weight
+
+    @property
+    def bias(self):
+        return self.layer.bias
+
+    def forward(self, input_: Tensor) -> Tensor:
+        return self.layer(input_)
+
+
+@LAYERS.register_module
+class Classifier1D(ParallelLayer):
+    """RowLinear with given weight"""
+    def __init__(self,
+                 in_features: int,
+                 num_classes: int,
+                 weight: Parameter = None,
+                 bias: bool = True,
+                 dtype: dtype = None,
+                 weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
+                 bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
+        super().__init__()
+        self.in_features = in_features
+        self.num_classes = num_classes
+        self.parallel_input = get_parallel_input()
+
+        # Divide the weight matrix along the last dimension.
+        self.input_size_per_partition = divide(in_features, gpc.tensor_parallel_size)
+
+        # Parameters.
+        # Initialize weight.
+        factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
+        if weight is not None:
+            self.weight = weight
+            self.has_weight = False
+        else:
+            self.weight = Parameter(torch.empty(self.num_classes, self.input_size_per_partition, **factory_kwargs))
+            self.has_weight = True
+        if bias:
+            self.bias = Parameter(torch.empty(self.num_classes, **factory_kwargs))
+        else:
+            self.bias = None
+        with seed(ParallelMode.TENSOR):
+            self.reset_parameters(weight_initializer, bias_initializer)
+        self._set_tensor_parallel_attributes()
+        set_parallel_input(False)
+
+    def reset_parameters(self, weight_initializer, bias_initializer) -> None:
+        fan_in, fan_out = self.in_features, self.num_classes
+        if self.has_weight:
+            weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
+        if self.bias is not None:
+            bias_initializer(self.bias, fan_in=fan_in)
+            broadcast(self.bias, gpc.get_ranks_in_group(ParallelMode.PARALLEL_1D)[0], ParallelMode.PARALLEL_1D)
+
+    def _set_tensor_parallel_attributes(self):
+        if self.has_weight:
+            num_partition = gpc.get_world_size(ParallelMode.TENSOR)
+            set_tensor_parallel_attribute_by_partition(self.weight, num_partition)
+
+    def forward(self, input_: Tensor) -> Tensor:
+        # Set up backprop all-reduce.
+        if self.parallel_input:
+            input_ = input_
+        else:
+            input_ = split_forward_gather_backward(input_, ParallelMode.PARALLEL_1D, dim=-1)
+
+        output_parallel = F.linear(input_, self.weight)
+        output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D)
+
+        output = output + self.bias
+        return output
+
+
 @LAYERS.register_module
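Together with the set_parallel_input calls added in the next two hunks, Linear1D lets stacked layers alternate column- and row-parallelism automatically. An illustrative sketch, assuming the 1D process groups have been initialized; the MLP1D wrapper here is hypothetical:

import torch.nn as nn
from colossalai.nn.layer.parallel_1d import Linear1D


class MLP1D(nn.Module):
    def __init__(self, hidden: int = 768, ratio: int = 4):
        super().__init__()
        self.dense_1 = Linear1D(hidden, hidden * ratio)   # flag clear -> built as Linear1D_Col
        self.act = nn.GELU()
        self.dense_2 = Linear1D(hidden * ratio, hidden)   # flag set by dense_1 -> Linear1D_Row

    def forward(self, x):
        return self.dense_2(self.act(self.dense_1(x)))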
@@ -77,6 +186,7 @@ class Linear1D_Col(ParallelLayer):
 with seed(ParallelMode.TENSOR):
     self.reset_parameters(weight_initializer, bias_initializer)
 self._set_tensor_parallel_attributes()
+set_parallel_input(True)

 def reset_parameters(self, weight_initializer, bias_initializer) -> None:
     fan_in, fan_out = self.in_features, self.out_features
@@ -158,6 +268,7 @@ class Linear1D_Row(ParallelLayer):
 with seed(ParallelMode.TENSOR):
     self.reset_parameters(weight_initializer, bias_initializer)
 self._set_tensor_parallel_attributes()
+set_parallel_input(False)

 def reset_parameters(self, weight_initializer, bias_initializer) -> None:
     fan_in, fan_out = self.in_features, self.out_features
@@ -208,3 +319,68 @@ class MixedFusedLayerNorm1D(torch.nn.Module):

 def forward(self, input):
     return FusedLayerNormAffineFunction1D.apply(input, self.weight, self.bias, self.normalized_shape, self.eps)
+
+
+@LAYERS.register_module
+class Embedding1D(ParallelLayer):
+    def __init__(self,
+                 num_embeddings: int,
+                 embedding_dim: int,
+                 padding_idx: int = None,
+                 dtype: dtype = None,
+                 weight_initializer: Callable = init.normal_(),
+                 *args,
+                 **kwargs):
+        super().__init__()
+
+        self.num_embeddings = num_embeddings
+        self.embed_dim = embedding_dim
+        embed_dim_per_partition = divide(embedding_dim, gpc.tensor_parallel_size)
+
+        self.padding_idx = padding_idx
+        self.embed_args = args
+        self.embed_kwargs = kwargs
+
+        self.weight = Parameter(
+            torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype))
+
+        self.reset_parameters(weight_initializer)
+        self._set_tensor_parallel_attributes()
+        set_parallel_input(False)
+
+    def _set_tensor_parallel_attributes(self):
+        set_tensor_parallel_attribute_by_partition(self.weight, gpc.tensor_parallel_size)
+
+    def reset_parameters(self, weight_initializer) -> None:
+        with seed(ParallelMode.TENSOR):
+            fan_in, fan_out = self.num_embeddings, self.embed_dim
+            weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
+            self._fill_padding_idx_with_zero()
+
+    def _fill_padding_idx_with_zero(self) -> None:
+        if self.padding_idx is not None:
+            with torch.no_grad():
+                self.weight[self.padding_idx].fill_(0)
+
+    def forward(self, input_: Tensor) -> Tensor:
+
+        output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
+
+        output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1)
+
+        return output
+
+
+@LAYERS.register_module
+class Dropout1D(ParallelLayer):
+    def __init__(self, p: float = 0.5, inplace: bool = False):
+        super().__init__()
+        self.parallel_input = get_parallel_input()
+        self.p = p
+        self.inplace = inplace
+
+    def forward(self, input_: Tensor) -> Tensor:
+        cm = nullcontext() if not self.parallel_input else seed(ParallelMode.TENSOR)
+        with cm:
+            output = F.dropout(input_, self.p, self.training, self.inplace)
+        return output
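A worked shape example for Embedding1D (plain arithmetic, no live run): with embedding_dim=512 and tensor_parallel_size=4, each rank holds a (num_embeddings, 128) shard, and the forward gathers the last dimension back to 512.

num_embeddings, embedding_dim, tp_size = 1000, 512, 4
embed_dim_per_partition = embedding_dim // tp_size        # what divide() computes: 128
local_weight_shape = (num_embeddings, embed_dim_per_partition)
gathered_output_dim = embed_dim_per_partition * tp_size   # 512 after gather_forward_split_backward
assert gathered_output_dim == embedding_dim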
@@ -1,6 +1,6 @@
-from ._operation import reduce_by_batch_2d, split_batch_2d
+from ._operation import reduce_by_batch_2d, split_tensor_2d
 from .layers import Classifier2D, Embedding2D, LayerNorm2D, Linear2D, PatchEmbedding2D

 __all__ = [
-    'split_batch_2d', 'reduce_by_batch_2d', 'Linear2D', 'LayerNorm2D', 'Classifier2D', 'PatchEmbedding2D', 'Embedding2D'
+    'split_tensor_2d', 'reduce_by_batch_2d', 'Linear2D', 'LayerNorm2D', 'Classifier2D', 'PatchEmbedding2D', 'Embedding2D'
 ]
@@ -2,7 +2,7 @@ from typing import Any, Optional, Tuple

 import torch
 import torch.distributed as dist
-from colossalai.communication.collective import (all_gather, all_reduce, reduce_scatter)
+from colossalai.communication.collective import (all_gather, all_reduce, reduce, reduce_scatter)
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.utils import get_current_device

@@ -595,7 +595,9 @@ class SplitFirst(torch.autograd.Function):
         return grad, None, None


-def split_batch_2d(input_: Tensor, dim: int = 0) -> Tensor:
+def split_tensor_2d(input_: Tensor, dim: int = 0) -> Tensor:
+    if input_.size(dim) <= 1:
+        return input_
     return torch.chunk(input_, gpc.get_world_size(ParallelMode.PARALLEL_2D_COL),
                        dim=dim)[gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)].contiguous()
@@ -603,17 +605,28 @@ def split_batch_2d(input_: Tensor, dim: int = 0) -> Tensor:
 class reduce_by_batch_2d(torch.autograd.Function):
     """All-reduce the input from the model parallel region."""
     @staticmethod
-    def symbolic(graph, input_):
-        dist.all_reduce(input_, group=gpc.get_group(ParallelMode.PARALLEL_2D_COL))
-        return input_
+    def symbolic(graph, input_, reduce_mean: bool = False):
+        output = all_reduce(input_, ParallelMode.PARALLEL_2D_COL)
+        if reduce_mean:
+            reduce_size = gpc.get_world_size(ParallelMode.PARALLEL_2D_COL)
+            return output / reduce_size
+        return output

     @staticmethod
     @custom_fwd(cast_inputs=torch.float32)
-    def forward(ctx, input_):
-        dist.all_reduce(input_, group=gpc.get_group(ParallelMode.PARALLEL_2D_COL))
-        return input_.clone()
+    def forward(ctx, input_, reduce_mean: bool = False):
+        output = all_reduce(input_, ParallelMode.PARALLEL_2D_COL)
+        ctx.reduce_mean = reduce_mean
+        if reduce_mean:
+            reduce_size = gpc.get_world_size(ParallelMode.PARALLEL_2D_COL)
+            ctx.reduce_size = reduce_size
+            return output.clone() / reduce_size
+        return output.clone()

     @staticmethod
     @custom_bwd
-    def backward(ctx, grad_output):
-        return grad_output
+    def backward(ctx, output_grad):
+        if ctx.reduce_mean:
+            return output_grad / ctx.reduce_size, None
+        else:
+            return output_grad, None
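A worked numeric example of the new reduce_mean flag (plain arithmetic, no distributed run):

partials = [1.0, 2.0, 3.0, 4.0]   # per-rank partial values across a column group of size N = 4
N = len(partials)

all_reduced = sum(partials)        # what all_reduce produces on every rank: 10.0
mean = all_reduced / N             # returned when reduce_mean=True: 2.5

# In backward the incoming gradient is scaled by the same 1/N factor,
# matching d(mean)/d(partial_i) = 1/N.
grad_out = 1.0
grad_in = grad_out / N             # 0.25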
@@ -13,9 +13,9 @@ from colossalai.utils import get_current_device
 from torch import Tensor, dtype
 from torch.nn import Parameter

-from .._common_utils import (divide, set_tensor_parallel_attribute_by_partition, to_2tuple)
+from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
 from ..base_layer import ParallelLayer
-from ._operation import (Matmul_AB_2D, add_bias_2d, all_gather_weight_2d, classifier_2d, layernorm_2d, split_batch_2d)
+from ._operation import Matmul_AB_2D, add_bias_2d, all_gather_weight_2d, classifier_2d, layernorm_2d
 from ._utils import assert_summa_initialization, get_summa_dim_from_env

@@ -257,8 +257,6 @@ class PatchEmbedding2D(ParallelLayer):
 assert H == self.img_size[0] and W == self.img_size[1], \
     f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."

-input_ = split_batch_2d(input_)
-
 weight = all_gather_weight_2d.apply(self.weight, 0, self.summa_dim, ParallelMode.PARALLEL_2D_COL)
 bias = all_gather_weight_2d.apply(self.bias, 0, self.summa_dim, ParallelMode.PARALLEL_2D_COL)

@@ -318,8 +316,6 @@ class Embedding2D(ParallelLayer):
     self.weight[self.padding_idx].fill_(0)

 def forward(self, input_: Tensor) -> Tensor:
-    input_ = split_batch_2d(input_)
-
     weight = all_gather_weight_2d.apply(self.weight, -1, self.summa_dim, ParallelMode.PARALLEL_2D_COL)

     output = F.embedding(input_, weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
@@ -1,7 +1,7 @@
-from ._operation import reduce_by_batch_2p5d, split_batch_2p5d
+from ._operation import reduce_by_batch_2p5d, split_tensor_2p5d
 from .layers import Classifier2p5D, Embedding2p5D, LayerNorm2p5D, Linear2p5D, PatchEmbedding2p5D

 __all__ = [
-    'split_batch_2p5d', 'reduce_by_batch_2p5d', 'Linear2p5D', 'LayerNorm2p5D', 'Classifier2p5D', 'PatchEmbedding2p5D',
+    'split_tensor_2p5d', 'reduce_by_batch_2p5d', 'Linear2p5D', 'LayerNorm2p5D', 'Classifier2p5D', 'PatchEmbedding2p5D',
     'Embedding2p5D'
 ]
@@ -22,7 +22,7 @@ def get_parallel_rank(parallel_mode: ParallelMode):
     return gpc.get_local_rank(parallel_mode)


-def split_batch_2p5d(input_: Tensor, dim: int = 0) -> Tensor:
+def split_tensor_2p5d(input_: Tensor, dim: int = 0) -> Tensor:
     return torch.chunk(input_, gpc.get_world_size(ParallelMode.PARALLEL_2P5D_COL),
                        dim=dim)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)].contiguous()
@@ -120,30 +120,53 @@ class Matmul_AB_2p5D(torch.autograd.Function):
         ctx.save_for_backward(A, B)

         A_shape = A.shape
-        A = A.reshape((-1, A_shape[-1])).contiguous()
+        A = A.reshape((-1, A_shape[-1]))
         B_shape = B.shape
-        B = B.reshape((-1, B_shape[-1])).contiguous()
+        B = B.reshape((-1, B_shape[-1]))
         C_shape = (A.shape[0], B.shape[-1])
         C = torch.zeros(C_shape, dtype=A.dtype, device=get_current_device())

-        A_list = [torch.empty_like(A) for _ in range(gpc.get_world_size(row_parallel_mode) - 1)]
-        B_list = [torch.empty_like(B) for _ in range(gpc.get_world_size(col_parallel_mode) - 1)]
-        A_list.insert(gpc.get_local_rank(row_parallel_mode), A)
-        B_list.insert(gpc.get_local_rank(col_parallel_mode), B)
-        op_a = dist.all_gather(A_list, A, group=gpc.get_group(row_parallel_mode), async_op=True)
-        op_a.wait()
-        op_b = dist.all_gather(B_list, B, group=gpc.get_group(col_parallel_mode), async_op=True)
-        for op in [op_a, op_b]:
-            op.wait()
+        # use circular buffer to store the communication tensor
+        # 2 is enough for all cases
+        A_list = [torch.empty_like(A) for _ in range(2)]
+        B_list = [torch.empty_like(B) for _ in range(2)]
+
+        row_group = gpc.get_group(row_parallel_mode)
+        col_group = gpc.get_group(col_parallel_mode)
+
+        src_a = tesseract_dim * row_rank + tesseract_dim ** 2 * dep_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
+            pipeline_parallel_rank * tensor_parallel_size
+        src_b = col_rank + tesseract_dim ** 2 * dep_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
+            pipeline_parallel_rank * tensor_parallel_size
+
+        opa = [None] * 2
+        opb = [None] * 2
+
+        A_list[0].copy_(A)
+        B_list[0].copy_(B)
+        opa[0] = dist.broadcast(A_list[0], src=src_a, group=row_group, async_op=True)
+        opb[0] = dist.broadcast(B_list[0], src=src_b, group=col_group, async_op=True)
+        cur = 0

         for i in range(tesseract_dim):
-            src_a = i + tesseract_dim * row_rank
-            src_b = i + tesseract_dim * col_rank
-            src_a = src_a % tesseract_dim
-            src_b = src_b % tesseract_dim
-            A_temp = A_list[src_a]
-            B_temp = B_list[src_b]
-            torch.addmm(C, A_temp, B_temp, out=C)
+            if i != tesseract_dim - 1:
+                A_list[1 - cur].copy_(A)
+                opa[1 - cur] = dist.broadcast(A_list[1 - cur], src=src_a + 1, group=row_group, async_op=True)
+                B_list[1 - cur].copy_(B)
+                opb[1 - cur] = dist.broadcast(B_list[1 - cur],
+                                              src=src_b + tesseract_dim,
+                                              group=col_group,
+                                              async_op=True)
+
+            if opa[cur] is not None:
+                opa[cur].wait()
+            if opb[cur] is not None:
+                opb[cur].wait()
+
+            torch.addmm(C, A_list[cur], B_list[cur], out=C)
+            cur = 1 - cur
+            src_a += 1
+            src_b += tesseract_dim
         out = C.reshape(out_shape)

         if ctx:
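A generic sketch of the double-buffering pattern adopted above (assumes an initialized torch.distributed process group; this is not the 2.5D kernel itself, only the broadcast/compute overlap idea):

import torch
import torch.distributed as dist

def broadcast_compute_overlap(x, srcs, group, consume):
    """Broadcast block i+1 while consuming block i, using two buffers."""
    bufs = [torch.empty_like(x), torch.empty_like(x)]
    ops = [None, None]
    cur = 0
    bufs[0].copy_(x)
    ops[0] = dist.broadcast(bufs[0], src=srcs[0], group=group, async_op=True)
    for i in range(len(srcs)):
        if i != len(srcs) - 1:                      # prefetch the next block
            bufs[1 - cur].copy_(x)
            ops[1 - cur] = dist.broadcast(bufs[1 - cur], src=srcs[i + 1], group=group, async_op=True)
        ops[cur].wait()                             # wait only for the block needed now
        consume(bufs[cur])                          # e.g. a torch.addmm accumulation
        cur = 1 - cur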
@ -201,20 +224,55 @@ class Matmul_ABT_2p5D(torch.autograd.Function):
|
||||||
C_shape = (A.shape[0], B.shape[0])
|
C_shape = (A.shape[0], B.shape[0])
|
||||||
C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device())
|
C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device())
|
||||||
|
|
||||||
for i in range(tesseract_dim):
|
# use circular buffer to store the communication tensor
|
||||||
B_temp = B.clone()
|
# 2 is enough for all cases
|
||||||
src_b = col_rank + i * tesseract_dim + dep_rank * (
|
B_list = [torch.empty_like(B) for _ in range(2)]
|
||||||
tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
|
C_list = [torch.empty_like(C) for _ in range(2)]
|
||||||
pipeline_parallel_rank * tensor_parallel_size
|
|
||||||
dist.broadcast(B_temp, src=src_b, group=gpc.get_group(col_parallel_mode))
|
|
||||||
C_temp = torch.matmul(A, B_temp.transpose(0, 1))
|
|
||||||
src_c = i + row_rank * tesseract_dim + dep_rank * (
|
|
||||||
tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
|
|
||||||
pipeline_parallel_rank * tensor_parallel_size
|
|
||||||
dist.reduce(C_temp, dst=src_c, group=gpc.get_group(row_parallel_mode))
|
|
||||||
if i == col_rank:
|
|
||||||
C = C_temp.clone()
|
|
||||||
|
|
||||||
|
row_group = gpc.get_group(row_parallel_mode)
|
||||||
|
col_group = gpc.get_group(col_parallel_mode)
|
||||||
|
|
||||||
|
src_b = col_rank + tesseract_dim ** 2 * dep_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
|
||||||
|
pipeline_parallel_rank * tensor_parallel_size
|
||||||
|
src_c = tesseract_dim * row_rank + tesseract_dim ** 2 * dep_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
|
||||||
|
pipeline_parallel_rank * tensor_parallel_size
|
||||||
|
|
||||||
|
opb = [None] * 2
|
||||||
|
opr = [None] * 2
|
||||||
|
|
||||||
|
B_list[0].copy_(B)
|
||||||
|
opb[0] = dist.broadcast(B_list[0], src=src_b, group=col_group, async_op=True)
|
||||||
|
cur = 0
|
||||||
|
|
||||||
|
for i in range(tesseract_dim):
|
||||||
|
if i != tesseract_dim - 1:
|
||||||
|
B_list[1 - cur].copy_(B)
|
||||||
|
opb[1 - cur] = dist.broadcast(B_list[1 - cur],
|
||||||
|
src=src_b + tesseract_dim,
|
||||||
|
group=col_group,
|
||||||
|
async_op=True)
|
||||||
|
|
||||||
|
if opr[cur] is not None:
|
||||||
|
opr[cur].wait()
|
||||||
|
if i - 2 == col_rank:
|
||||||
|
C.copy_(C_list[cur])
|
||||||
|
|
||||||
|
if opb[cur] is not None:
|
||||||
|
opb[cur].wait()
|
||||||
|
|
||||||
|
torch.matmul(A, B_list[cur].transpose(0, 1), out=C_list[cur])
|
||||||
|
opr[cur] = dist.reduce(C_list[cur], dst=src_c, group=row_group, async_op=True)
|
||||||
|
cur = 1 - cur
|
||||||
|
src_b += tesseract_dim
|
||||||
|
src_c += 1
|
||||||
|
|
||||||
|
for op in opr:
|
||||||
|
op.wait()
|
||||||
|
|
||||||
|
if tesseract_dim - 2 == col_rank:
|
||||||
|
C.copy_(C_list[cur])
|
||||||
|
if tesseract_dim - 1 == col_rank:
|
||||||
|
C.copy_(C_list[1 - cur])
|
||||||
out = C.reshape(out_shape)
|
out = C.reshape(out_shape)
|
||||||
|
|
||||||
if ctx:
|
if ctx:
|
||||||
|
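
Matmul_ABT_2p5D follows the same double-buffer idea, but each step's partial product is pushed to its destination rank with an asynchronous dist.reduce while the next B panel is still in flight (hence the delayed C.copy_ checks and the drain after the loop). The kernel is correct because the full A @ B^T is just the sum of column-shard partial products, which is what the reduce accumulates. A small sketch of that identity with plain tensors instead of a process group:

    import torch

    def abt_partial_products(A_shards, B_shards):
        # Each "rank" holds one column shard of A and of B (same width);
        # summing A_k @ B_k^T over k reproduces the full A @ B^T.
        return sum(a @ b.t() for a, b in zip(A_shards, B_shards))

    A = torch.randn(6, 8)
    B = torch.randn(5, 8)
    A_shards = torch.chunk(A, 4, dim=1)
    B_shards = torch.chunk(B, 4, dim=1)
    assert torch.allclose(abt_partial_products(A_shards, B_shards), A @ B.t(), atol=1e-5)
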
@ -272,20 +330,52 @@ class Matmul_ATB_2p5D(torch.autograd.Function):
|
||||||
C_shape = (A.shape[-1], B.shape[-1])
|
C_shape = (A.shape[-1], B.shape[-1])
|
||||||
C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device())
|
C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device())
|
||||||
|
|
||||||
for i in range(tesseract_dim):
|
# use circular buffer to store the communication tensor
|
||||||
A_temp = A.clone()
|
# 2 is enough for all cases
|
||||||
src_a = i + row_rank * tesseract_dim + dep_rank * (
|
A_list = [torch.empty_like(A) for _ in range(2)]
|
||||||
tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
|
C_list = [torch.empty_like(C) for _ in range(2)]
|
||||||
pipeline_parallel_rank * tensor_parallel_size
|
|
||||||
dist.broadcast(A_temp, src=src_a, group=get_parallel_group(row_parallel_mode))
|
|
||||||
C_temp = torch.matmul(A_temp.transpose(0, 1), B)
|
|
||||||
src_c = col_rank + i * tesseract_dim + dep_rank * (
|
|
||||||
tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
|
|
||||||
pipeline_parallel_rank * tensor_parallel_size
|
|
||||||
dist.reduce(C_temp, dst=src_c, group=get_parallel_group(col_parallel_mode))
|
|
||||||
if i == row_rank:
|
|
||||||
C = C_temp.clone()
|
|
||||||
|
|
||||||
|
row_group = gpc.get_group(row_parallel_mode)
|
||||||
|
col_group = gpc.get_group(col_parallel_mode)
|
||||||
|
|
||||||
|
src_a = tesseract_dim * row_rank + tesseract_dim ** 2 * dep_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
|
||||||
|
pipeline_parallel_rank * tensor_parallel_size
|
||||||
|
src_c = col_rank + tesseract_dim ** 2 * dep_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
|
||||||
|
pipeline_parallel_rank * tensor_parallel_size
|
||||||
|
|
||||||
|
opa = [None] * 2
|
||||||
|
opr = [None] * 2
|
||||||
|
|
||||||
|
A_list[0].copy_(A)
|
||||||
|
opa[0] = dist.broadcast(A_list[0], src=src_a, group=row_group, async_op=True)
|
||||||
|
cur = 0
|
||||||
|
|
||||||
|
for i in range(tesseract_dim):
|
||||||
|
if i != tesseract_dim - 1:
|
||||||
|
A_list[1 - cur].copy_(A)
|
||||||
|
opa[1 - cur] = dist.broadcast(A_list[1 - cur], src=src_a + 1, group=row_group, async_op=True)
|
||||||
|
|
||||||
|
if opr[cur] is not None:
|
||||||
|
opr[cur].wait()
|
||||||
|
if i - 2 == row_rank:
|
||||||
|
C.copy_(C_list[cur])
|
||||||
|
|
||||||
|
if opa[cur] is not None:
|
||||||
|
opa[cur].wait()
|
||||||
|
|
||||||
|
torch.matmul(A_list[cur].transpose(0, 1), B, out=C_list[cur])
|
||||||
|
opr[cur] = dist.reduce(C_list[cur], dst=src_c, group=col_group, async_op=True)
|
||||||
|
cur = 1 - cur
|
||||||
|
src_a += 1
|
||||||
|
src_c += tesseract_dim
|
||||||
|
|
||||||
|
for op in opr:
|
||||||
|
op.wait()
|
||||||
|
|
||||||
|
if tesseract_dim - 2 == row_rank:
|
||||||
|
C.copy_(C_list[cur])
|
||||||
|
if tesseract_dim - 1 == row_rank:
|
||||||
|
C.copy_(C_list[1 - cur])
|
||||||
out = C.reshape(out_shape)
|
out = C.reshape(out_shape)
|
||||||
|
|
||||||
if ctx:
|
if ctx:
|
||||||
|
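
Matmul_ATB_2p5D is the mirror case: A is broadcast along one group and the partial C is reduced along the other, because A^T @ B decomposes into a sum of row-shard partial products. The same identity checked with plain tensors:

    import torch

    def atb_partial_products(A_shards, B_shards):
        # Each "rank" holds one row shard of A and of B; summing A_k^T @ B_k
        # over k reproduces the full A^T @ B that the async reduce accumulates.
        return sum(a.t() @ b for a, b in zip(A_shards, B_shards))

    A = torch.randn(8, 6)
    B = torch.randn(8, 5)
    A_shards = torch.chunk(A, 4, dim=0)
    B_shards = torch.chunk(B, 4, dim=0)
    assert torch.allclose(atb_partial_products(A_shards, B_shards), A.t() @ B, atol=1e-5)
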
@ -333,8 +423,7 @@ class Add_Bias_2p5D(torch.autograd.Function):
|
||||||
bias_temp = bias.clone()
|
bias_temp = bias.clone()
|
||||||
else:
|
else:
|
||||||
bias_temp = torch.zeros(output_size_per_partition, dtype=bias.dtype, device=get_current_device())
|
bias_temp = torch.zeros(output_size_per_partition, dtype=bias.dtype, device=get_current_device())
|
||||||
src_rank = col_rank + dep_rank * (
|
src_rank = col_rank + dep_rank * tesseract_dim ** 2 + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
|
||||||
tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
|
|
||||||
pipeline_parallel_rank * tensor_parallel_size
|
pipeline_parallel_rank * tensor_parallel_size
|
||||||
dist.broadcast(bias_temp, src=src_rank, group=get_parallel_group(col_parallel_mode))
|
dist.broadcast(bias_temp, src=src_rank, group=get_parallel_group(col_parallel_mode))
|
||||||
|
|
||||||
|
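
The Add_Bias_2p5D change above is only a reflow of the src_rank expression, but the rank arithmetic it wraps is easy to misread. A sketch with hypothetical sizes (every name and number below is illustrative, not taken from any config):

    # Hypothetical sizes, just to show how the broadcast source rank is composed.
    tesseract_dim, tesseract_dep = 2, 2
    pipeline_parallel_size = 1
    tensor_parallel_size = tesseract_dim ** 2 * tesseract_dep  # ranks per tensor-parallel group

    def bias_src_rank(col_rank, dep_rank, data_parallel_rank, pipeline_parallel_rank):
        # Global rank that owns the bias shard for a given (col, dep) coordinate.
        return (col_rank
                + dep_rank * tesseract_dim ** 2
                + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size
                + pipeline_parallel_rank * tensor_parallel_size)

    print(bias_src_rank(col_rank=1, dep_rank=1, data_parallel_rank=0, pipeline_parallel_rank=0))  # 5
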
@ -469,7 +558,9 @@ class SplitFirst(torch.autograd.Function):
|
||||||
return grad, None, None
|
return grad, None, None
|
||||||
|
|
||||||
|
|
||||||
def split_batch_2p5d(input_: Tensor, dim: int = 0) -> Tensor:
|
def split_tensor_2p5d(input_: Tensor, dim: int = 0) -> Tensor:
|
||||||
|
if input_.size(dim) <= 1:
|
||||||
|
return input_
|
||||||
return torch.chunk(input_, gpc.get_world_size(ParallelMode.PARALLEL_2P5D_COL),
|
return torch.chunk(input_, gpc.get_world_size(ParallelMode.PARALLEL_2P5D_COL),
|
||||||
dim=dim)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)].contiguous()
|
dim=dim)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)].contiguous()
|
||||||
|
|
||||||
|
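
split_batch_2p5d becomes split_tensor_2p5d and now returns the input untouched when the split dimension has size <= 1 (e.g. an already-scalar per-sample value). The chunk-and-index pattern, with the world size and rank passed explicitly instead of read from gpc, looks roughly like:

    import torch

    def split_along_dim(x, world_size, rank, dim=0):
        # Mirror of the chunk-and-index pattern used by split_tensor_2p5d.
        if x.size(dim) <= 1:  # nothing to split
            return x
        return torch.chunk(x, world_size, dim=dim)[rank].contiguous()

    x = torch.arange(8).reshape(8, 1)
    print(split_along_dim(x, world_size=4, rank=2).squeeze(1))  # tensor([4, 5])
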
@ -477,17 +568,28 @@ def split_batch_2p5d(input_: Tensor, dim: int = 0) -> Tensor:
|
||||||
class reduce_by_batch_2p5d(torch.autograd.Function):
|
class reduce_by_batch_2p5d(torch.autograd.Function):
|
||||||
"""All-reduce the input from the model parallel region."""
|
"""All-reduce the input from the model parallel region."""
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def symbolic(graph, input_):
|
def symbolic(graph, input_, reduce_mean: bool = False):
|
||||||
dist.all_reduce(input_, group=gpc.get_group(ParallelMode.PARALLEL_2P5D_COL))
|
output = all_reduce(input_, ParallelMode.PARALLEL_2P5D_COL)
|
||||||
return input_
|
if reduce_mean:
|
||||||
|
reduce_size = gpc.get_world_size(ParallelMode.PARALLEL_2P5D_COL)
|
||||||
|
return output / reduce_size
|
||||||
|
return output
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@custom_fwd(cast_inputs=torch.float32)
|
@custom_fwd(cast_inputs=torch.float32)
|
||||||
def forward(ctx, input_):
|
def forward(ctx, input_, reduce_mean: bool = False):
|
||||||
dist.all_reduce(input_, group=gpc.get_group(ParallelMode.PARALLEL_2P5D_COL))
|
output = all_reduce(input_, ParallelMode.PARALLEL_2P5D_COL)
|
||||||
return input_.clone()
|
ctx.reduce_mean = reduce_mean
|
||||||
|
if reduce_mean:
|
||||||
|
reduce_size = gpc.get_world_size(ParallelMode.PARALLEL_2P5D_COL)
|
||||||
|
ctx.reduce_size = reduce_size
|
||||||
|
return output.clone() / reduce_size
|
||||||
|
return output.clone()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@custom_bwd
|
@custom_bwd
|
||||||
def backward(ctx, grad_output):
|
def backward(ctx, output_grad):
|
||||||
return grad_output
|
if ctx.reduce_mean:
|
||||||
|
return output_grad / ctx.reduce_size, None
|
||||||
|
else:
|
||||||
|
return output_grad, None
|
||||||
|
|
|
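
reduce_by_batch_2p5d gains a reduce_mean flag: the forward divides the all-reduced value by the column-group size, and the backward divides the incoming gradient by the same factor. A single-process stand-in that keeps only the scaling logic (the all-reduce itself is omitted, and the group size is passed in explicitly):

    import torch

    class MeanOverGroup(torch.autograd.Function):
        # Stand-in for reduce_by_batch_2p5d(reduce_mean=True): forward divides the
        # (already all-reduced) value by the group size, backward scales the gradient back.

        @staticmethod
        def forward(ctx, summed, group_size):
            ctx.group_size = group_size
            return summed / group_size

        @staticmethod
        def backward(ctx, grad_output):
            return grad_output / ctx.group_size, None

    x = torch.tensor(8.0, requires_grad=True)
    y = MeanOverGroup.apply(x, 4)
    y.backward()
    print(y.item(), x.grad.item())  # 2.0 0.25
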
@ -13,10 +13,9 @@ from colossalai.utils import get_current_device
|
||||||
from torch import Tensor, dtype
|
from torch import Tensor, dtype
|
||||||
from torch.nn import Parameter
|
from torch.nn import Parameter
|
||||||
|
|
||||||
from .._common_utils import (divide, set_tensor_parallel_attribute_by_partition, to_2tuple)
|
|
||||||
from ..base_layer import ParallelLayer
|
from ..base_layer import ParallelLayer
|
||||||
from ._operation import (Add_Bias_2p5D, Matmul_AB_2p5D, all_gather_weight_2p5d, classifier_2p5d, layernorm_2p5d,
|
from ..utils import (divide, set_tensor_parallel_attribute_by_partition, to_2tuple)
|
||||||
split_batch_2p5d)
|
from ._operation import (Add_Bias_2p5D, Matmul_AB_2p5D, all_gather_weight_2p5d, classifier_2p5d, layernorm_2p5d)
|
||||||
from ._utils import (assert_tesseract_initialization, get_tesseract_dim_dep_from_env)
|
from ._utils import (assert_tesseract_initialization, get_tesseract_dim_dep_from_env)
|
||||||
|
|
||||||
|
|
||||||
|
@ -231,7 +230,7 @@ class PatchEmbedding2p5D(ParallelLayer):
|
||||||
self.num_patches = self.grid_size[0] * self.grid_size[1]
|
self.num_patches = self.grid_size[0] * self.grid_size[1]
|
||||||
self.flatten = flatten
|
self.flatten = flatten
|
||||||
self.embed_size = embed_size
|
self.embed_size = embed_size
|
||||||
self.embed_size_per_partition = embed_size // (self.tesseract_dep * self.tesseract_dim**2)
|
self.embed_size_per_partition = embed_size // self.tesseract_dim**2
|
||||||
|
|
||||||
with seed(ParallelMode.TENSOR):
|
with seed(ParallelMode.TENSOR):
|
||||||
self.weight = Parameter(
|
self.weight = Parameter(
|
||||||
|
@ -251,10 +250,10 @@ class PatchEmbedding2p5D(ParallelLayer):
|
||||||
self._set_tensor_parallel_attribute()
|
self._set_tensor_parallel_attribute()
|
||||||
|
|
||||||
def _set_tensor_parallel_attribute(self):
|
def _set_tensor_parallel_attribute(self):
|
||||||
set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dep * self.tesseract_dim**2)
|
set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dim**2)
|
||||||
set_tensor_parallel_attribute_by_partition(self.bias, self.tesseract_dep * self.tesseract_dim**2)
|
set_tensor_parallel_attribute_by_partition(self.bias, self.tesseract_dim**2)
|
||||||
set_tensor_parallel_attribute_by_partition(self.cls_token, self.tesseract_dep * self.tesseract_dim**2)
|
set_tensor_parallel_attribute_by_partition(self.cls_token, self.tesseract_dim**2)
|
||||||
set_tensor_parallel_attribute_by_partition(self.pos_embed, self.tesseract_dep * self.tesseract_dim**2)
|
set_tensor_parallel_attribute_by_partition(self.pos_embed, self.tesseract_dim**2)
|
||||||
|
|
||||||
def reset_parameters(self, weight_initializer, bias_initializer, position_embed_initializer):
|
def reset_parameters(self, weight_initializer, bias_initializer, position_embed_initializer):
|
||||||
with seed(ParallelMode.TENSOR):
|
with seed(ParallelMode.TENSOR):
|
||||||
|
@ -269,8 +268,6 @@ class PatchEmbedding2p5D(ParallelLayer):
|
||||||
assert H == self.img_size[0] and W == self.img_size[1], \
|
assert H == self.img_size[0] and W == self.img_size[1], \
|
||||||
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
|
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
|
||||||
|
|
||||||
input_ = split_batch_2p5d(input_)
|
|
||||||
|
|
||||||
weight = all_gather_weight_2p5d.apply(self.weight, 0, self.tesseract_dim, ParallelMode.PARALLEL_2P5D_COL)
|
weight = all_gather_weight_2p5d.apply(self.weight, 0, self.tesseract_dim, ParallelMode.PARALLEL_2P5D_COL)
|
||||||
bias = all_gather_weight_2p5d.apply(self.bias, 0, self.tesseract_dim, ParallelMode.PARALLEL_2P5D_COL)
|
bias = all_gather_weight_2p5d.apply(self.bias, 0, self.tesseract_dim, ParallelMode.PARALLEL_2P5D_COL)
|
||||||
|
|
||||||
|
@ -303,7 +300,7 @@ class Embedding2p5D(ParallelLayer):
|
||||||
self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env()
|
self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env()
|
||||||
self.num_embeddings = num_embeddings
|
self.num_embeddings = num_embeddings
|
||||||
self.embed_dim = embedding_dim
|
self.embed_dim = embedding_dim
|
||||||
embed_dim_per_partition = embedding_dim // (self.tesseract_dep * self.tesseract_dim**2)
|
embed_dim_per_partition = embedding_dim // self.tesseract_dim**2
|
||||||
|
|
||||||
self.padding_idx = padding_idx
|
self.padding_idx = padding_idx
|
||||||
self.embed_args = args
|
self.embed_args = args
|
||||||
|
@ -316,7 +313,7 @@ class Embedding2p5D(ParallelLayer):
|
||||||
self._set_tensor_parallel_attributes()
|
self._set_tensor_parallel_attributes()
|
||||||
|
|
||||||
def _set_tensor_parallel_attributes(self):
|
def _set_tensor_parallel_attributes(self):
|
||||||
set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dep * self.tesseract_dim**2)
|
set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dim**2)
|
||||||
|
|
||||||
def reset_parameters(self, weight_initializer) -> None:
|
def reset_parameters(self, weight_initializer) -> None:
|
||||||
with seed(ParallelMode.TENSOR):
|
with seed(ParallelMode.TENSOR):
|
||||||
|
@ -330,8 +327,6 @@ class Embedding2p5D(ParallelLayer):
|
||||||
self.weight[self.padding_idx].fill_(0)
|
self.weight[self.padding_idx].fill_(0)
|
||||||
|
|
||||||
def forward(self, input_: Tensor) -> Tensor:
|
def forward(self, input_: Tensor) -> Tensor:
|
||||||
input_ = split_batch_2p5d(input_)
|
|
||||||
|
|
||||||
weight = all_gather_weight_2p5d.apply(self.weight, -1, self.tesseract_dim, ParallelMode.PARALLEL_2P5D_COL)
|
weight = all_gather_weight_2p5d.apply(self.weight, -1, self.tesseract_dim, ParallelMode.PARALLEL_2P5D_COL)
|
||||||
|
|
||||||
output = F.embedding(input_, weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
|
output = F.embedding(input_, weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
|
||||||
|
@ -359,7 +354,7 @@ class Classifier2p5D(ParallelLayer):
|
||||||
self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env()
|
self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env()
|
||||||
|
|
||||||
# partitioning dimension
|
# partitioning dimension
|
||||||
self.input_size_per_partition = divide(self.in_features, self.tesseract_dep * self.tesseract_dim**2)
|
self.input_size_per_partition = divide(self.in_features, self.tesseract_dim**2)
|
||||||
|
|
||||||
if weight is not None:
|
if weight is not None:
|
||||||
self.weight = weight
|
self.weight = weight
|
||||||
|
@ -378,7 +373,7 @@ class Classifier2p5D(ParallelLayer):
|
||||||
|
|
||||||
def _set_tensor_parallel_attributes(self):
|
def _set_tensor_parallel_attributes(self):
|
||||||
if self.has_weight:
|
if self.has_weight:
|
||||||
set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dep * self.tesseract_dim**2)
|
set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dim**2)
|
||||||
|
|
||||||
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
|
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
|
||||||
with seed(ParallelMode.TENSOR):
|
with seed(ParallelMode.TENSOR):
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
from ._operation import reduce_by_batch_3d, split_batch_3d
|
from ._operation import reduce_by_batch_3d, split_tensor_3d
|
||||||
from .layers import Classifier3D, Embedding3D, LayerNorm3D, Linear3D, PatchEmbedding3D
|
from .layers import Classifier3D, Embedding3D, LayerNorm3D, Linear3D, PatchEmbedding3D
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'reduce_by_batch_3d', 'split_batch_3d', 'Linear3D', 'LayerNorm3D', 'PatchEmbedding3D', 'Classifier3D', 'Embedding3D'
|
'reduce_by_batch_3d', 'split_tensor_3d', 'Linear3D', 'LayerNorm3D', 'PatchEmbedding3D', 'Classifier3D', 'Embedding3D'
|
||||||
]
|
]
|
||||||
|
|
|
@ -175,10 +175,12 @@ class layernorm_3d(torch.autograd.Function):
|
||||||
return input_grad, weight_grad, bias_grad, None, None, None, None, None
|
return input_grad, weight_grad, bias_grad, None, None, None, None, None
|
||||||
|
|
||||||
|
|
||||||
def split_batch_3d(input_: Tensor,
|
def split_tensor_3d(input_: Tensor,
|
||||||
input_parallel_mode: ParallelMode,
|
dim: int = 0,
|
||||||
weight_parallel_mode: ParallelMode,
|
input_parallel_mode: ParallelMode = ParallelMode.PARALLEL_3D_INPUT,
|
||||||
dim: int = 0) -> Tensor:
|
weight_parallel_mode: ParallelMode = ParallelMode.PARALLEL_3D_WEIGHT) -> Tensor:
|
||||||
|
if input_.size(dim) <= 1:
|
||||||
|
return input_
|
||||||
output = torch.chunk(input_, gpc.get_world_size(weight_parallel_mode),
|
output = torch.chunk(input_, gpc.get_world_size(weight_parallel_mode),
|
||||||
dim=dim)[gpc.get_local_rank(weight_parallel_mode)].contiguous()
|
dim=dim)[gpc.get_local_rank(weight_parallel_mode)].contiguous()
|
||||||
output = torch.chunk(output, gpc.get_world_size(input_parallel_mode),
|
output = torch.chunk(output, gpc.get_world_size(input_parallel_mode),
|
||||||
|
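
split_tensor_3d keeps the double-chunk behaviour but now takes dim first, defaults the two parallel modes, and skips tensors that are too small to split. Stripped of gpc, the splitting is just two chunk-and-index passes:

    import torch

    def split_twice(x, weight_ws, weight_rank, input_ws, input_rank, dim=0):
        # Mirrors split_tensor_3d: first split across the weight-parallel group,
        # then split the surviving shard again across the input-parallel group.
        if x.size(dim) <= 1:
            return x
        x = torch.chunk(x, weight_ws, dim=dim)[weight_rank].contiguous()
        return torch.chunk(x, input_ws, dim=dim)[input_rank].contiguous()

    x = torch.arange(16)
    print(split_twice(x, weight_ws=2, weight_rank=1, input_ws=2, input_rank=0))  # tensor([ 8,  9, 10, 11])
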
@ -189,15 +191,27 @@ def split_batch_3d(input_: Tensor,
|
||||||
class reduce_by_batch_3d(torch.autograd.Function):
|
class reduce_by_batch_3d(torch.autograd.Function):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@custom_fwd(cast_inputs=torch.float32)
|
@custom_fwd(cast_inputs=torch.float32)
|
||||||
def forward(ctx, input_: Tensor, input_parallel_mode: ParallelMode, weight_parallel_mode: ParallelMode) -> Tensor:
|
def forward(ctx,
|
||||||
|
input_: Tensor,
|
||||||
|
input_parallel_mode: ParallelMode,
|
||||||
|
weight_parallel_mode: ParallelMode,
|
||||||
|
reduce_mean: bool = False) -> Tensor:
|
||||||
output = all_reduce(input_, input_parallel_mode)
|
output = all_reduce(input_, input_parallel_mode)
|
||||||
output = all_reduce(output, weight_parallel_mode)
|
output = all_reduce(output, weight_parallel_mode)
|
||||||
|
ctx.reduce_mean = reduce_mean
|
||||||
|
if reduce_mean:
|
||||||
|
reduce_size = gpc.get_world_size(input_parallel_mode) * gpc.get_world_size(weight_parallel_mode)
|
||||||
|
ctx.reduce_size = reduce_size
|
||||||
|
return output.clone() / reduce_size
|
||||||
return output.clone()
|
return output.clone()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@custom_bwd
|
@custom_bwd
|
||||||
def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]:
|
def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]:
|
||||||
return output_grad, None, None
|
if ctx.reduce_mean:
|
||||||
|
return output_grad / ctx.reduce_size, None, None, None
|
||||||
|
else:
|
||||||
|
return output_grad, None, None, None
|
||||||
|
|
||||||
|
|
||||||
class broadcast_weight_3d_from_diagonal(torch.autograd.Function):
|
class broadcast_weight_3d_from_diagonal(torch.autograd.Function):
|
||||||
|
|
|
@ -17,9 +17,9 @@ from colossalai.utils import get_current_device
|
||||||
from torch import Tensor, dtype
|
from torch import Tensor, dtype
|
||||||
from torch.nn import Parameter
|
from torch.nn import Parameter
|
||||||
|
|
||||||
from .._common_utils import (divide, set_tensor_parallel_attribute_by_partition, to_2tuple)
|
from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
|
||||||
from ._operation import *
|
from ._operation import *
|
||||||
from ._utils import (get_depth_from_env, get_last_group, get_parallel_mode_from_env, swap_in_out_group)
|
from ._utils import get_depth_from_env, get_last_group, get_parallel_mode_from_env, swap_in_out_group
|
||||||
|
|
||||||
|
|
||||||
@LAYERS.register_module
|
@LAYERS.register_module
|
||||||
|
@ -241,8 +241,6 @@ class PatchEmbedding3D(ParallelLayer):
|
||||||
self.pos_embed.register_hook(self._sync_grad_hook)
|
self.pos_embed.register_hook(self._sync_grad_hook)
|
||||||
|
|
||||||
def forward(self, input_: Tensor) -> Tensor:
|
def forward(self, input_: Tensor) -> Tensor:
|
||||||
input_ = split_batch_3d(input_, self.input_parallel_mode, self.weight_parallel_mode)
|
|
||||||
|
|
||||||
weight = broadcast_weight_3d_from_diagonal.apply(self.weight, self.input_parallel_mode,
|
weight = broadcast_weight_3d_from_diagonal.apply(self.weight, self.input_parallel_mode,
|
||||||
self.weight_parallel_mode, self.output_parallel_mode)
|
self.weight_parallel_mode, self.output_parallel_mode)
|
||||||
output = F.conv2d(input_, weight, self.bias, stride=self.patch_size)
|
output = F.conv2d(input_, weight, self.bias, stride=self.patch_size)
|
||||||
|
@ -302,8 +300,6 @@ class Embedding3D(ParallelLayer):
|
||||||
self.weight[self.padding_idx].fill_(0)
|
self.weight[self.padding_idx].fill_(0)
|
||||||
|
|
||||||
def forward(self, input_: Tensor) -> Tensor:
|
def forward(self, input_: Tensor) -> Tensor:
|
||||||
input_ = split_batch_3d(input_, self.input_parallel_mode, self.weight_parallel_mode)
|
|
||||||
|
|
||||||
weight = broadcast_weight_3d_from_diagonal.apply(self.weight, self.input_parallel_mode,
|
weight = broadcast_weight_3d_from_diagonal.apply(self.weight, self.input_parallel_mode,
|
||||||
self.weight_parallel_mode, self.output_parallel_mode)
|
self.weight_parallel_mode, self.output_parallel_mode)
|
||||||
output = F.embedding(input_, weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
|
output = F.embedding(input_, weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
|
||||||
|
|
|
@ -0,0 +1,7 @@
|
||||||
|
from .common import (ACT2FN, CheckpointModule, _ntuple, divide, get_tensor_parallel_mode,
|
||||||
|
set_tensor_parallel_attribute_by_partition, set_tensor_parallel_attribute_by_size, to_2tuple)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
'CheckpointModule', 'divide', 'ACT2FN', 'set_tensor_parallel_attribute_by_size',
|
||||||
|
'set_tensor_parallel_attribute_by_partition', 'get_tensor_parallel_mode', '_ntuple', 'to_2tuple'
|
||||||
|
]
|
|
@ -2,11 +2,12 @@
|
||||||
# -*- encoding: utf-8 -*-
|
# -*- encoding: utf-8 -*-
|
||||||
|
|
||||||
import collections.abc
|
import collections.abc
|
||||||
|
import os
|
||||||
from itertools import repeat
|
from itertools import repeat
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from colossalai.constants import IS_TENSOR_PARALLEL, NUM_PARTITIONS
|
from colossalai.constants import (IS_TENSOR_PARALLEL, NUM_PARTITIONS, TENSOR_PARALLEL_MODE)
|
||||||
from colossalai.utils import checkpoint
|
from colossalai.utils import checkpoint
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
|
|
||||||
|
@ -59,6 +60,10 @@ def set_tensor_parallel_attribute_by_partition(param, num_partitions):
|
||||||
setattr(param, NUM_PARTITIONS, num_partitions)
|
setattr(param, NUM_PARTITIONS, num_partitions)
|
||||||
|
|
||||||
|
|
||||||
|
def get_tensor_parallel_mode():
|
||||||
|
return os.environ[TENSOR_PARALLEL_MODE]
|
||||||
|
|
||||||
|
|
||||||
# From PyTorch internals
|
# From PyTorch internals
|
||||||
|
|
||||||
|
|
|
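
The new get_tensor_parallel_mode helper simply reads the mode string that initialization exports into the environment, which is what lets losses and metrics drop their tensor_parallel arguments. A sketch with a tolerant default; the constant name mirrors colossalai.constants.TENSOR_PARALLEL_MODE, and the default argument is an addition for illustration, not part of the library:

    import os

    TENSOR_PARALLEL_MODE = 'TENSOR_PARALLEL_MODE'  # assumed value of the constant

    def get_tensor_parallel_mode_sketch(default='None'):
        # Read the tensor-parallel mode exported by the launcher; falls back to a
        # default instead of raising KeyError when the variable is unset.
        return os.environ.get(TENSOR_PARALLEL_MODE, default)

    os.environ[TENSOR_PARALLEL_MODE] = '2.5d'
    assert get_tensor_parallel_mode_sketch() == '2.5d'
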
@ -9,7 +9,7 @@ from colossalai.utils import get_current_device
|
||||||
from torch import Tensor, dtype
|
from torch import Tensor, dtype
|
||||||
from torch import nn as nn
|
from torch import nn as nn
|
||||||
|
|
||||||
from .._common_utils import to_2tuple
|
from ..utils import to_2tuple
|
||||||
|
|
||||||
|
|
||||||
def drop_path(x, drop_prob: float = 0., training: bool = False):
|
def drop_path(x, drop_prob: float = 0., training: bool = False):
|
||||||
|
|
|
@ -2,6 +2,7 @@ from torch import nn
|
||||||
from torch.nn.modules.loss import *
|
from torch.nn.modules.loss import *
|
||||||
from torch.nn.modules.loss import _Loss
|
from torch.nn.modules.loss import _Loss
|
||||||
|
|
||||||
|
from colossalai.nn.layer.utils import get_tensor_parallel_mode
|
||||||
from .loss_2d import CrossEntropyLoss2D
|
from .loss_2d import CrossEntropyLoss2D
|
||||||
from .loss_2p5d import CrossEntropyLoss2p5D
|
from .loss_2p5d import CrossEntropyLoss2p5D
|
||||||
from .loss_3d import CrossEntropyLoss3D
|
from .loss_3d import CrossEntropyLoss3D
|
||||||
|
@ -14,9 +15,10 @@ _parallel_cross_entropy = {
|
||||||
|
|
||||||
|
|
||||||
class CrossEntropyLoss(_Loss):
|
class CrossEntropyLoss(_Loss):
|
||||||
def __init__(self, reduction: bool = True, tensor_parallel: str = None, *args, **kwargs):
|
def __init__(self, reduction: bool = True, *args, **kwargs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if tensor_parallel in [None, '1d']:
|
tensor_parallel = get_tensor_parallel_mode()
|
||||||
|
if tensor_parallel in ['None', '1d']:
|
||||||
reduction = 'mean' if reduction else 'none'
|
reduction = 'mean' if reduction else 'none'
|
||||||
self.loss = nn.CrossEntropyLoss(reduction=reduction, *args, **kwargs)
|
self.loss = nn.CrossEntropyLoss(reduction=reduction, *args, **kwargs)
|
||||||
else:
|
else:
|
||||||
|
|
|
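
CrossEntropyLoss drops its tensor_parallel argument and resolves the mode itself, falling back to torch's loss for 'None'/'1d' and otherwise picking a parallel implementation from a registry. A minimal sketch of that dispatch, with the mode getter injected so it can run stand-alone; the empty registry below is a placeholder for the real mapping to the 2D/2.5D/3D losses:

    import torch
    from torch import nn

    _parallel_cross_entropy = {}  # placeholder registry

    class CrossEntropyLossSketch(nn.Module):
        # Dispatching wrapper: the caller no longer passes tensor_parallel;
        # the mode comes from a getter (injected here for testability).
        def __init__(self, mode_getter, reduction: bool = True, **kwargs):
            super().__init__()
            mode = mode_getter()
            if mode in ('None', '1d'):
                self.loss = nn.CrossEntropyLoss(reduction='mean' if reduction else 'none', **kwargs)
            else:
                self.loss = _parallel_cross_entropy[mode](reduction=reduction, **kwargs)

        def forward(self, logits, targets):
            return self.loss(logits, targets)

    loss_fn = CrossEntropyLossSketch(mode_getter=lambda: 'None')
    print(loss_fn(torch.randn(4, 10), torch.randint(0, 10, (4,))))
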
@ -1,4 +1,4 @@
|
||||||
from colossalai.nn.layer.parallel_2d import reduce_by_batch_2d, split_batch_2d
|
from colossalai.nn.layer.parallel_2d import reduce_by_batch_2d
|
||||||
from colossalai.nn.layer.parallel_2d._utils import assert_summa_initialization
|
from colossalai.nn.layer.parallel_2d._utils import assert_summa_initialization
|
||||||
from colossalai.registry import LOSSES
|
from colossalai.registry import LOSSES
|
||||||
from torch.nn.functional import cross_entropy
|
from torch.nn.functional import cross_entropy
|
||||||
|
@ -20,11 +20,8 @@ class CrossEntropyLoss2D(_Loss):
|
||||||
self.loss_kwargs = kwargs
|
self.loss_kwargs = kwargs
|
||||||
|
|
||||||
def forward(self, logits, targets):
|
def forward(self, logits, targets):
|
||||||
batch_size = targets.size(0)
|
loss = cross_entropy(logits, targets, reduction='none', *self.loss_args, **self.loss_kwargs)
|
||||||
targets = split_batch_2d(targets)
|
|
||||||
loss = cross_entropy(logits, targets, reduction='sum', *self.loss_args, **self.loss_kwargs)
|
|
||||||
if self.reduction_mean:
|
if self.reduction_mean:
|
||||||
loss = loss.sum()
|
loss = loss.mean()
|
||||||
loss = reduce_by_batch_2d.apply(loss)
|
loss = reduce_by_batch_2d.apply(loss, True)
|
||||||
loss /= batch_size
|
|
||||||
return loss
|
return loss
|
||||||
|
|
|
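
CrossEntropyLoss2D no longer splits the targets or divides by the batch size by hand: it takes per-sample losses, means them locally, and lets reduce_by_batch_2d(loss, True) average across the column group. This is equivalent because the batch shards have equal size, so the mean of the local means equals the global mean:

    import torch
    from torch.nn.functional import cross_entropy

    logits = torch.randn(8, 10)
    targets = torch.randint(0, 10, (8,))
    global_mean = cross_entropy(logits, targets)  # reduction='mean' by default

    # Four equal shards stand in for the four column-parallel ranks.
    shard_means = [cross_entropy(l, t)
                   for l, t in zip(torch.chunk(logits, 4), torch.chunk(targets, 4))]
    assert torch.allclose(torch.stack(shard_means).mean(), global_mean, atol=1e-5)
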
@ -1,4 +1,4 @@
|
||||||
from colossalai.nn.layer.parallel_2p5d import reduce_by_batch_2p5d, split_batch_2p5d
|
from colossalai.nn.layer.parallel_2p5d import reduce_by_batch_2p5d
|
||||||
from colossalai.nn.layer.parallel_2p5d._utils import assert_tesseract_initialization
|
from colossalai.nn.layer.parallel_2p5d._utils import assert_tesseract_initialization
|
||||||
from colossalai.registry import LOSSES
|
from colossalai.registry import LOSSES
|
||||||
from torch.nn.functional import cross_entropy
|
from torch.nn.functional import cross_entropy
|
||||||
|
@ -19,11 +19,8 @@ class CrossEntropyLoss2p5D(_Loss):
|
||||||
self.loss_kwargs = kwargs
|
self.loss_kwargs = kwargs
|
||||||
|
|
||||||
def forward(self, logits, targets):
|
def forward(self, logits, targets):
|
||||||
batch_size = targets.size(0)
|
loss = cross_entropy(logits, targets, reduction='none', *self.loss_args, **self.loss_kwargs)
|
||||||
targets = split_batch_2p5d(targets)
|
|
||||||
loss = cross_entropy(logits, targets, reduction='sum', *self.loss_args, **self.loss_kwargs)
|
|
||||||
if self.reduction_mean:
|
if self.reduction_mean:
|
||||||
loss = loss.sum()
|
loss = loss.mean()
|
||||||
loss = reduce_by_batch_2p5d.apply(loss)
|
loss = reduce_by_batch_2p5d.apply(loss, True)
|
||||||
loss /= batch_size
|
|
||||||
return loss
|
return loss
|
||||||
|
|
|
@ -1,11 +1,10 @@
|
||||||
from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D
|
from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D
|
||||||
from colossalai.nn.layer.parallel_3d import reduce_by_batch_3d, split_batch_3d
|
from colossalai.nn.layer.parallel_3d import reduce_by_batch_3d
|
||||||
from colossalai.nn.layer.parallel_3d._utils import get_parallel_mode_from_env
|
from colossalai.nn.layer.parallel_3d._utils import get_parallel_mode_from_env
|
||||||
from colossalai.registry import LOSSES
|
from colossalai.registry import LOSSES
|
||||||
from torch.nn.functional import cross_entropy
|
from torch.nn.functional import cross_entropy
|
||||||
from torch.nn.modules.loss import _Loss
|
from torch.nn.modules.loss import _Loss
|
||||||
|
|
||||||
|
|
||||||
@LOSSES.register_module
|
@LOSSES.register_module
|
||||||
class CrossEntropyLoss3D(_Loss):
|
class CrossEntropyLoss3D(_Loss):
|
||||||
"""Cross entropy loss for 3D parallelism
|
"""Cross entropy loss for 3D parallelism
|
||||||
|
@ -28,11 +27,8 @@ class CrossEntropyLoss3D(_Loss):
|
||||||
self.loss_kwargs = kwargs
|
self.loss_kwargs = kwargs
|
||||||
|
|
||||||
def forward(self, logits, targets):
|
def forward(self, logits, targets):
|
||||||
batch_size = targets.size(0)
|
loss = cross_entropy(logits, targets, reduction='none', *self.loss_args, **self.loss_kwargs)
|
||||||
targets = split_batch_3d(targets, self.input_parallel_mode, self.weight_parallel_mode)
|
|
||||||
loss = cross_entropy(logits, targets, reduction='sum', *self.loss_args, **self.loss_kwargs)
|
|
||||||
if self.reduction_mean:
|
if self.reduction_mean:
|
||||||
loss = loss.sum()
|
loss = loss.mean()
|
||||||
loss = reduce_by_batch_3d.apply(loss, self.input_parallel_mode, self.weight_parallel_mode)
|
loss = reduce_by_batch_3d.apply(loss, self.input_parallel_mode, self.weight_parallel_mode, True)
|
||||||
loss /= batch_size
|
|
||||||
return loss
|
return loss
|
||||||
|
|
|
@ -4,6 +4,7 @@ from ._utils import calc_acc
|
||||||
from .accuracy_2d import Accuracy2D
|
from .accuracy_2d import Accuracy2D
|
||||||
from .accuracy_2p5d import Accuracy2p5D
|
from .accuracy_2p5d import Accuracy2p5D
|
||||||
from .accuracy_3d import Accuracy3D
|
from .accuracy_3d import Accuracy3D
|
||||||
|
from colossalai.nn.layer.utils import get_tensor_parallel_mode
|
||||||
|
|
||||||
_parallel_accuracy = {
|
_parallel_accuracy = {
|
||||||
'2d': Accuracy2D,
|
'2d': Accuracy2D,
|
||||||
|
@ -13,9 +14,10 @@ _parallel_accuracy = {
|
||||||
|
|
||||||
|
|
||||||
class Accuracy(nn.Module):
|
class Accuracy(nn.Module):
|
||||||
def __init__(self, tensor_parallel: str = None):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if tensor_parallel in [None, '1d']:
|
tensor_parallel = get_tensor_parallel_mode()
|
||||||
|
if tensor_parallel in ['None', '1d']:
|
||||||
self.acc = calc_acc
|
self.acc = calc_acc
|
||||||
else:
|
else:
|
||||||
self.acc = _parallel_accuracy[tensor_parallel]()
|
self.acc = _parallel_accuracy[tensor_parallel]()
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
import torch
|
import torch
|
||||||
from colossalai.nn.layer.parallel_2d import reduce_by_batch_2d, split_batch_2d
|
from colossalai.nn.layer.parallel_2d import reduce_by_batch_2d
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from ._utils import calc_acc
|
from ._utils import calc_acc
|
||||||
|
@ -11,7 +11,6 @@ class Accuracy2D(nn.Module):
|
||||||
|
|
||||||
def forward(self, logits, targets):
|
def forward(self, logits, targets):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
targets = split_batch_2d(targets)
|
|
||||||
correct = calc_acc(logits, targets)
|
correct = calc_acc(logits, targets)
|
||||||
correct = reduce_by_batch_2d.apply(correct)
|
correct = reduce_by_batch_2d.apply(correct)
|
||||||
return correct
|
return correct
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
import torch
|
import torch
|
||||||
from colossalai.nn.layer.parallel_2p5d import reduce_by_batch_2p5d, split_batch_2p5d
|
from colossalai.nn.layer.parallel_2p5d import reduce_by_batch_2p5d
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from ._utils import calc_acc
|
from ._utils import calc_acc
|
||||||
|
@ -11,7 +11,6 @@ class Accuracy2p5D(nn.Module):
|
||||||
|
|
||||||
def forward(self, logits, targets):
|
def forward(self, logits, targets):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
targets = split_batch_2p5d(targets)
|
|
||||||
correct = calc_acc(logits, targets)
|
correct = calc_acc(logits, targets)
|
||||||
correct = reduce_by_batch_2p5d.apply(correct)
|
correct = reduce_by_batch_2p5d.apply(correct)
|
||||||
return correct
|
return correct
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
import torch
|
import torch
|
||||||
from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D
|
from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D
|
||||||
from colossalai.nn.layer.parallel_3d import reduce_by_batch_3d, split_batch_3d
|
from colossalai.nn.layer.parallel_3d import reduce_by_batch_3d
|
||||||
from colossalai.nn.layer.parallel_3d._utils import get_parallel_mode_from_env
|
from colossalai.nn.layer.parallel_3d._utils import get_parallel_mode_from_env
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
|
@ -15,7 +15,6 @@ class Accuracy3D(nn.Module):
|
||||||
|
|
||||||
def forward(self, logits, targets):
|
def forward(self, logits, targets):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
targets = split_batch_3d(targets, self.input_parallel_mode, self.weight_parallel_mode)
|
|
||||||
correct = calc_acc(logits, targets)
|
correct = calc_acc(logits, targets)
|
||||||
correct = reduce_by_batch_3d.apply(correct, self.input_parallel_mode, self.weight_parallel_mode)
|
correct = reduce_by_batch_3d.apply(correct, self.input_parallel_mode, self.weight_parallel_mode)
|
||||||
return correct
|
return correct
|
||||||
|
|
|
@ -173,7 +173,7 @@ class AccuracyMetric(Metric):
|
||||||
self.accumulated_sum.zero_()
|
self.accumulated_sum.zero_()
|
||||||
self.accumulated_correct.zero_()
|
self.accumulated_correct.zero_()
|
||||||
|
|
||||||
def update(self, logits, targets) -> None:
|
def update(self, logits, targets, batch_size) -> None:
|
||||||
"""Updates last step accuracy and accumulated accuracy with current logits
|
"""Updates last step accuracy and accumulated accuracy with current logits
|
||||||
and labels. It expects the output has logits and labels.
|
and labels. It expects the output has logits and labels.
|
||||||
|
|
||||||
|
@ -187,7 +187,7 @@ class AccuracyMetric(Metric):
|
||||||
# update
|
# update
|
||||||
correct = self.acc(logits, targets)
|
correct = self.acc(logits, targets)
|
||||||
|
|
||||||
self.last_step_sum.fill_(targets.size(0))
|
self.last_step_sum.fill_(batch_size)
|
||||||
self.last_step_correct.fill_(correct)
|
self.last_step_correct.fill_(correct)
|
||||||
self.accumulated_sum += self.last_step_sum
|
self.accumulated_sum += self.last_step_sum
|
||||||
self.accumulated_correct += self.last_step_correct
|
self.accumulated_correct += self.last_step_correct
|
||||||
|
@ -296,7 +296,8 @@ class AccuracyHook(MetricHook):
|
||||||
|
|
||||||
def after_test_iter(self, trainer, logits, targets, *args):
|
def after_test_iter(self, trainer, logits, targets, *args):
|
||||||
if self._is_stage_to_compute:
|
if self._is_stage_to_compute:
|
||||||
self.metric.update(logits, targets)
|
batch_size = trainer.schedule.batch_size
|
||||||
|
self.metric.update(logits, targets, batch_size)
|
||||||
|
|
||||||
|
|
||||||
class ThroughputMetric(Metric):
|
class ThroughputMetric(Metric):
|
||||||
|
@ -313,10 +314,8 @@ class ThroughputMetric(Metric):
|
||||||
self.last_step_num_samples.zero_()
|
self.last_step_num_samples.zero_()
|
||||||
self.last_step_used_time.zero_()
|
self.last_step_used_time.zero_()
|
||||||
|
|
||||||
def update(self, tensor, time) -> None:
|
def update(self, num_samples, time) -> None:
|
||||||
if isinstance(tensor, (list, tuple)):
|
self.last_step_num_samples.fill_(num_samples)
|
||||||
tensor = tensor[0]
|
|
||||||
self.last_step_num_samples.fill_(tensor.size(0))
|
|
||||||
self.last_step_used_time.fill_(time)
|
self.last_step_used_time.fill_(time)
|
||||||
self.accumulated_num_samples += self.last_step_num_samples
|
self.accumulated_num_samples += self.last_step_num_samples
|
||||||
self.accumulated_used_time += self.last_step_used_time
|
self.accumulated_used_time += self.last_step_used_time
|
||||||
|
@ -354,11 +353,11 @@ class ThroughputHook(MetricHook):
|
||||||
def before_train_epoch(self, trainer):
|
def before_train_epoch(self, trainer):
|
||||||
self.metric.reset()
|
self.metric.reset()
|
||||||
|
|
||||||
def after_train_iter(self, trainer, logits, targets, *args):
|
def after_train_iter(self, trainer, *args):
|
||||||
self.metric.update(targets, trainer._timer.get_timer('Train-step').get_elapsed_time())
|
self.metric.update(trainer.schedule.batch_size, trainer._timer.get_timer('Train-step').get_elapsed_time())
|
||||||
|
|
||||||
def before_test(self, trainer):
|
def before_test(self, trainer):
|
||||||
self.metric.reset()
|
self.metric.reset()
|
||||||
|
|
||||||
def after_test_iter(self, trainer, logits, targets, *args):
|
def after_test_iter(self, trainer, *args):
|
||||||
self.metric.update(targets, trainer._timer.get_timer('Test-step').get_elapsed_time())
|
self.metric.update(trainer.schedule.batch_size, trainer._timer.get_timer('Test-step').get_elapsed_time())
|
||||||
|
|
|
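
AccuracyMetric.update and ThroughputMetric.update now receive an explicit sample count (trainer.schedule.batch_size) instead of reading targets.size(0), since the batch is split inside the schedule before the hooks ever see it. A stripped-down accumulator showing the same bookkeeping, with no trainer or timer attached:

    import torch

    class ThroughputAccumulator:
        # Accumulates samples/second the way the reworked ThroughputMetric does,
        # fed an explicit sample count instead of a targets tensor.
        def __init__(self):
            self.num_samples = torch.zeros(1)
            self.used_time = torch.zeros(1)

        def update(self, num_samples: int, time_s: float) -> None:
            self.num_samples += num_samples
            self.used_time += time_s

        def value(self) -> float:
            return (self.num_samples / (self.used_time + 1e-12)).item()

    m = ThroughputAccumulator()
    m.update(256, 0.5)
    m.update(256, 0.5)
    print(round(m.value()))  # 512 samples/sec
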
@ -1,27 +1,19 @@
|
||||||
from .activation_checkpoint import checkpoint
|
from .activation_checkpoint import checkpoint
|
||||||
from .common import (print_rank_0, sync_model_param_in_dp, is_dp_rank_0,
|
from .common import (clip_grad_norm_fp32, conditional_context, copy_tensor_parallel_attributes, count_zeros_fp32,
|
||||||
is_tp_rank_0, is_no_pp_or_last_stage, is_using_ddp,
|
free_port, is_dp_rank_0, is_model_parallel_parameter, is_no_pp_or_last_stage, is_tp_rank_0,
|
||||||
is_using_pp, conditional_context, is_model_parallel_parameter,
|
is_using_ddp, is_using_pp, multi_tensor_applier, param_is_not_tensor_parallel_duplicate,
|
||||||
clip_grad_norm_fp32, count_zeros_fp32, copy_tensor_parallel_attributes,
|
print_rank_0, switch_virtual_pipeline_parallel_rank, sync_model_param_in_dp)
|
||||||
param_is_not_tensor_parallel_duplicate, switch_virtual_pipeline_parallel_rank)
|
from .cuda import empty_cache, get_current_device, set_to_cuda, synchronize
|
||||||
from .cuda import get_current_device, synchronize, empty_cache, set_to_cuda
|
from .data_sampler import DataParallelSampler, get_dataloader
|
||||||
|
from .gradient_accumulation import accumulate_gradient
|
||||||
from .memory import report_memory_usage
|
from .memory import report_memory_usage
|
||||||
from .timer import MultiTimer, Timer
|
from .timer import MultiTimer, Timer
|
||||||
from .multi_tensor_apply import multi_tensor_applier
|
|
||||||
from .gradient_accumulation import accumulate_gradient
|
|
||||||
from .data_sampler import DataParallelSampler, get_dataloader
|
|
||||||
|
|
||||||
__all__ = ['checkpoint',
|
__all__ = [
|
||||||
'print_rank_0', 'sync_model_param_in_dp', 'is_dp_rank_0',
|
'checkpoint', 'free_port', 'print_rank_0', 'sync_model_param_in_dp', 'is_dp_rank_0', 'is_tp_rank_0',
|
||||||
'is_tp_rank_0', 'is_no_pp_or_last_stage', 'is_using_ddp',
|
'is_no_pp_or_last_stage', 'is_using_ddp', 'is_using_pp', 'conditional_context', 'is_model_parallel_parameter',
|
||||||
'is_using_pp', 'conditional_context', 'is_model_parallel_parameter',
|
'clip_grad_norm_fp32', 'count_zeros_fp32', 'copy_tensor_parallel_attributes',
|
||||||
'clip_grad_norm_fp32', 'count_zeros_fp32', 'copy_tensor_parallel_attributes',
|
'param_is_not_tensor_parallel_duplicate', 'get_current_device', 'synchronize', 'empty_cache', 'set_to_cuda',
|
||||||
'param_is_not_tensor_parallel_duplicate',
|
'report_memory_usage', 'Timer', 'MultiTimer', 'multi_tensor_applier', 'accumulate_gradient', 'DataParallelSampler',
|
||||||
'get_current_device', 'synchronize', 'empty_cache', 'set_to_cuda',
|
'get_dataloader', 'switch_virtual_pipeline_parallel_rank'
|
||||||
'report_memory_usage',
|
]
|
||||||
'Timer', 'MultiTimer',
|
|
||||||
'multi_tensor_applier',
|
|
||||||
'accumulate_gradient',
|
|
||||||
'DataParallelSampler', 'get_dataloader',
|
|
||||||
'switch_virtual_pipeline_parallel_rank'
|
|
||||||
]
|
|
||||||
|
|
|
@ -1,5 +1,7 @@
|
||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
# -*- encoding: utf-8 -*-
|
# -*- encoding: utf-8 -*-
|
||||||
|
import random
|
||||||
|
import socket
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch._six import inf
|
from torch._six import inf
|
||||||
|
@ -9,16 +11,15 @@ try:
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
import torch.distributed as dist
|
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from colossalai.context.parallel_mode import ParallelMode
|
|
||||||
from colossalai.core import global_context as gpc
|
|
||||||
from .multi_tensor_apply import multi_tensor_applier
|
|
||||||
from colossalai.constants import IS_TENSOR_PARALLEL, TENSOR_PARALLEL_ATTRIBUTES, NUM_PARTITIONS
|
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
from colossalai.constants import IS_TENSOR_PARALLEL, NUM_PARTITIONS, TENSOR_PARALLEL_ATTRIBUTES
|
||||||
from colossalai.context.parallel_mode import ParallelMode
|
from colossalai.context.parallel_mode import ParallelMode
|
||||||
from colossalai.core import global_context as gpc
|
from colossalai.core import global_context as gpc
|
||||||
|
|
||||||
|
from .multi_tensor_apply import multi_tensor_applier
|
||||||
|
|
||||||
|
|
||||||
def print_rank_0(msg: str, logger=None):
|
def print_rank_0(msg: str, logger=None):
|
||||||
'''Print messages and save logs(optional). This is executed only if you are the rank-0 gpu.
|
'''Print messages and save logs(optional). This is executed only if you are the rank-0 gpu.
|
||||||
|
@ -33,6 +34,18 @@ def print_rank_0(msg: str, logger=None):
|
||||||
logger.info(msg)
|
logger.info(msg)
|
||||||
|
|
||||||
|
|
||||||
|
def free_port():
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
sock = socket.socket()
|
||||||
|
port = random.randint(20000, 65000)
|
||||||
|
sock.bind(('localhost', port))
|
||||||
|
sock.close()
|
||||||
|
return port
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
|
||||||
|
|
||||||
def sync_model_param_in_dp(model):
|
def sync_model_param_in_dp(model):
|
||||||
'''Make sure data parameters are consistent during Data Parallel Mode
|
'''Make sure data parameters are consistent during Data Parallel Mode
|
||||||
|
|
||||||
|
|
|
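
free_port probes random ports until bind() succeeds, so every spawned test gets its own rendezvous port instead of a hard-coded 29900/30010. A near-verbatim sketch that catches OSError rather than a bare Exception; note there is still a small window between closing the probe socket and the launcher binding the port, which is acceptable for tests:

    import random
    import socket

    def free_port_sketch(low=20000, high=65000):
        # Keep trying random ports until one binds successfully on localhost.
        while True:
            try:
                sock = socket.socket()
                port = random.randint(low, high)
                sock.bind(('localhost', port))
                sock.close()
                return port
            except OSError:
                continue

    print(free_port_sketch())
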
@ -3,9 +3,8 @@ from typing import Callable
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from colossalai import nn as col_nn
|
from colossalai import nn as col_nn
|
||||||
from colossalai.context import ParallelMode, seed
|
from colossalai.nn.layer.utils import CheckpointModule
|
||||||
from colossalai.registry import LAYERS, MODELS
|
from colossalai.registry import LAYERS, MODELS
|
||||||
from colossalai.utils import checkpoint
|
|
||||||
from torch import dtype, nn
|
from torch import dtype, nn
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
@ -72,8 +71,7 @@ class ViTEmbedding(nn.Module):
|
||||||
dropout: float,
|
dropout: float,
|
||||||
dtype: dtype = None,
|
dtype: dtype = None,
|
||||||
flatten: bool = True,
|
flatten: bool = True,
|
||||||
init_method: str = 'torch',
|
init_method: str = 'torch'):
|
||||||
tensor_parallel: str = None):
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.patch_embed = col_nn.PatchEmbedding(img_size,
|
self.patch_embed = col_nn.PatchEmbedding(img_size,
|
||||||
patch_size,
|
patch_size,
|
||||||
|
@ -81,19 +79,17 @@ class ViTEmbedding(nn.Module):
|
||||||
embedding_dim,
|
embedding_dim,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
flatten=flatten,
|
flatten=flatten,
|
||||||
tensor_parallel=tensor_parallel,
|
|
||||||
**_init_rules[init_method]['embed'])
|
**_init_rules[init_method]['embed'])
|
||||||
self.dropout = nn.Dropout(dropout)
|
self.dropout = col_nn.Dropout(dropout)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = self.patch_embed(x)
|
x = self.patch_embed(x)
|
||||||
with seed(ParallelMode.TENSOR):
|
x = self.dropout(x)
|
||||||
x = self.dropout(x)
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
@LAYERS.register_module
|
@LAYERS.register_module
|
||||||
class ViTSelfAttention(nn.Module):
|
class ViTSelfAttention(CheckpointModule):
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
dim: int,
|
dim: int,
|
||||||
num_heads: int,
|
num_heads: int,
|
||||||
|
@ -102,27 +98,17 @@ class ViTSelfAttention(nn.Module):
|
||||||
bias: bool = True,
|
bias: bool = True,
|
||||||
dtype: dtype = None,
|
dtype: dtype = None,
|
||||||
checkpoint: bool = False,
|
checkpoint: bool = False,
|
||||||
init_method: str = 'torch',
|
init_method: str = 'torch'):
|
||||||
tensor_parallel: str = None):
|
super().__init__(checkpoint)
|
||||||
super().__init__()
|
|
||||||
self.attention_head_size = dim // num_heads
|
self.attention_head_size = dim // num_heads
|
||||||
self.checkpoint = checkpoint
|
|
||||||
self.tensor_parallel = tensor_parallel
|
|
||||||
|
|
||||||
self.query_key_value = col_nn.Linear(dim,
|
self.query_key_value = col_nn.Linear(dim,
|
||||||
3 * dim,
|
3 * dim,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
bias=bias,
|
bias=bias,
|
||||||
tensor_parallel='1d_col' if tensor_parallel == '1d' else tensor_parallel,
|
|
||||||
**_init_rules[init_method]['transformer'])
|
**_init_rules[init_method]['transformer'])
|
||||||
self.attention_dropout = nn.Dropout(attention_dropout)
|
self.attention_dropout = col_nn.Dropout(attention_dropout)
|
||||||
self.dense = col_nn.Linear(dim,
|
self.dense = col_nn.Linear(dim, dim, dtype=dtype, bias=True, **_init_rules[init_method]['transformer'])
|
||||||
dim,
|
self.dropout = col_nn.Dropout(dropout)
|
||||||
dtype=dtype,
|
|
||||||
bias=True,
|
|
||||||
tensor_parallel='1d_row' if tensor_parallel == '1d' else tensor_parallel,
|
|
||||||
**_init_rules[init_method]['transformer'])
|
|
||||||
self.dropout = nn.Dropout(dropout)
|
|
||||||
self.softmax = nn.Softmax(dim=-1)
|
self.softmax = nn.Softmax(dim=-1)
|
||||||
|
|
||||||
def _forward(self, x):
|
def _forward(self, x):
|
||||||
|
@ -138,8 +124,7 @@ class ViTSelfAttention(nn.Module):
|
||||||
x = torch.matmul(q, k.transpose(-1, -2))
|
x = torch.matmul(q, k.transpose(-1, -2))
|
||||||
x = x / math.sqrt(self.attention_head_size)
|
x = x / math.sqrt(self.attention_head_size)
|
||||||
x = self.softmax(x)
|
x = self.softmax(x)
|
||||||
with seed(ParallelMode.TENSOR):
|
x = self.attention_dropout(x)
|
||||||
x = self.attention_dropout(x)
|
|
||||||
|
|
||||||
x = torch.matmul(x, v)
|
x = torch.matmul(x, v)
|
||||||
x = x.transpose(1, 2)
|
x = x.transpose(1, 2)
|
||||||
|
@ -147,26 +132,13 @@ class ViTSelfAttention(nn.Module):
|
||||||
x = x.reshape(new_context_layer_shape)
|
x = x.reshape(new_context_layer_shape)
|
||||||
|
|
||||||
x = self.dense(x)
|
x = self.dense(x)
|
||||||
if self.tensor_parallel == '1d':
|
x = self.dropout(x)
|
||||||
x = self.dropout(x)
|
|
||||||
else:
|
|
||||||
with seed(ParallelMode.TENSOR):
|
|
||||||
x = self.dropout(x)
|
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def _checkpoint_forward(self, x):
|
|
||||||
return checkpoint(self._forward, x)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
if self.checkpoint:
|
|
||||||
return self._checkpoint_forward(x)
|
|
||||||
else:
|
|
||||||
return self._forward(x)
|
|
||||||
|
|
||||||
|
|
||||||
@LAYERS.register_module
|
@LAYERS.register_module
|
||||||
class ViTMLP(nn.Module):
|
class ViTMLP(CheckpointModule):
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
dim: int,
|
dim: int,
|
||||||
mlp_ratio: int,
|
mlp_ratio: int,
|
||||||
|
@ -175,50 +147,30 @@ class ViTMLP(nn.Module):
|
||||||
dtype: dtype = None,
|
dtype: dtype = None,
|
||||||
bias: bool = True,
|
bias: bool = True,
|
||||||
checkpoint: bool = False,
|
checkpoint: bool = False,
|
||||||
init_method: str = 'torch',
|
init_method: str = 'torch'):
|
||||||
tensor_parallel: str = None):
|
super().__init__(checkpoint)
|
||||||
super().__init__()
|
|
||||||
self.checkpoint = checkpoint
|
|
||||||
self.tensor_parallel = tensor_parallel
|
|
||||||
|
|
||||||
self.dense_1 = col_nn.Linear(dim,
|
self.dense_1 = col_nn.Linear(dim,
|
||||||
mlp_ratio * dim,
|
mlp_ratio * dim,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
bias=bias,
|
bias=bias,
|
||||||
tensor_parallel='1d_col' if tensor_parallel == '1d' else tensor_parallel,
|
|
||||||
**_init_rules[init_method]['transformer'])
|
**_init_rules[init_method]['transformer'])
|
||||||
self.activation = activation
|
self.activation = activation
|
||||||
|
self.dropout_1 = col_nn.Dropout(dropout)
|
||||||
self.dense_2 = col_nn.Linear(mlp_ratio * dim,
|
self.dense_2 = col_nn.Linear(mlp_ratio * dim,
|
||||||
dim,
|
dim,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
bias=bias,
|
bias=bias,
|
||||||
tensor_parallel='1d_row' if tensor_parallel == '1d' else tensor_parallel,
|
|
||||||
**_init_rules[init_method]['transformer'])
|
**_init_rules[init_method]['transformer'])
|
||||||
self.dropout = nn.Dropout(dropout)
|
self.dropout_2 = col_nn.Dropout(dropout)
|
||||||
|
|
||||||
def _forward(self, x):
|
def _forward(self, x):
|
||||||
x = self.dense_1(x)
|
x = self.dense_1(x)
|
||||||
x = self.activation(x)
|
x = self.activation(x)
|
||||||
with seed(ParallelMode.TENSOR):
|
x = self.dropout_1(x)
|
||||||
x = self.dropout(x)
|
|
||||||
x = self.dense_2(x)
|
x = self.dense_2(x)
|
||||||
if self.tensor_parallel == '1d':
|
x = self.dropout_2(x)
|
||||||
x = self.dropout(x)
|
|
||||||
else:
|
|
||||||
with seed(ParallelMode.TENSOR):
|
|
||||||
x = self.dropout(x)
|
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def _checkpoint_forward(self, x):
|
|
||||||
return checkpoint(self._forward, x)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
if self.checkpoint:
|
|
||||||
return self._checkpoint_forward(x)
|
|
||||||
else:
|
|
||||||
return self._forward(x)
|
|
||||||
|
|
||||||
|
|
||||||
@LAYERS.register_module
|
@LAYERS.register_module
|
||||||
class ViTHead(nn.Module):
|
class ViTHead(nn.Module):
|
||||||
|
@ -228,19 +180,14 @@ class ViTHead(nn.Module):
|
||||||
representation_size: int = None,
|
representation_size: int = None,
|
||||||
dtype: dtype = None,
|
dtype: dtype = None,
|
||||||
bias: bool = True,
|
bias: bool = True,
|
||||||
init_method: str = 'torch',
|
init_method: str = 'torch'):
|
||||||
tensor_parallel: str = None):
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if representation_size:
|
if representation_size:
|
||||||
tensor_parallel_kwargs = {'tensor_parallel': '1d_col' if tensor_parallel == '1d' else tensor_parallel}
|
|
||||||
if tensor_parallel == '1d':
|
|
||||||
tensor_parallel_kwargs['gather_output'] = True
|
|
||||||
self.representation = col_nn.Linear(dim,
|
self.representation = col_nn.Linear(dim,
|
||||||
representation_size,
|
representation_size,
|
||||||
bias=bias,
|
bias=bias,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
**_init_rules[init_method]['head'],
|
**_init_rules[init_method]['head'])
|
||||||
**tensor_parallel_kwargs)
|
|
||||||
else:
|
else:
|
||||||
self.representation = None
|
self.representation = None
|
||||||
representation_size = dim
|
representation_size = dim
|
||||||
|
@ -249,7 +196,6 @@ class ViTHead(nn.Module):
|
||||||
num_classes,
|
num_classes,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
bias=bias,
|
bias=bias,
|
||||||
tensor_parallel=tensor_parallel,
|
|
||||||
**_init_rules[init_method]['head'])
|
**_init_rules[init_method]['head'])
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
@ -273,10 +219,9 @@ class ViTBlock(nn.Module):
|
||||||
dtype: dtype = None,
|
dtype: dtype = None,
|
||||||
bias: bool = True,
|
bias: bool = True,
|
||||||
checkpoint: bool = False,
|
checkpoint: bool = False,
|
||||||
init_method: str = 'torch',
|
init_method: str = 'torch'):
|
||||||
tensor_parallel: str = None):
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.norm1 = col_nn.LayerNorm(normalized_shape=dim, eps=1e-6, dtype=dtype, tensor_parallel=tensor_parallel)
|
self.norm1 = col_nn.LayerNorm(normalized_shape=dim, eps=1e-6, dtype=dtype)
|
||||||
self.attn = ViTSelfAttention(dim=dim,
|
self.attn = ViTSelfAttention(dim=dim,
|
||||||
num_heads=num_heads,
|
num_heads=num_heads,
|
||||||
attention_dropout=attention_dropout,
|
attention_dropout=attention_dropout,
|
||||||
|
@ -284,10 +229,9 @@ class ViTBlock(nn.Module):
|
||||||
bias=bias,
|
bias=bias,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
checkpoint=checkpoint,
|
checkpoint=checkpoint,
|
||||||
init_method=init_method,
|
init_method=init_method)
|
||||||
tensor_parallel=tensor_parallel)
|
|
||||||
self.drop_path = col_nn.DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
self.drop_path = col_nn.DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||||
self.norm2 = col_nn.LayerNorm(normalized_shape=dim, eps=1e-6, dtype=dtype, tensor_parallel=tensor_parallel)
|
self.norm2 = col_nn.LayerNorm(normalized_shape=dim, eps=1e-6, dtype=dtype)
|
||||||
self.mlp = ViTMLP(dim=dim,
|
self.mlp = ViTMLP(dim=dim,
|
||||||
mlp_ratio=mlp_ratio,
|
mlp_ratio=mlp_ratio,
|
||||||
activation=activation,
|
activation=activation,
|
||||||
|
@ -295,8 +239,7 @@ class ViTBlock(nn.Module):
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
bias=bias,
|
bias=bias,
|
||||||
checkpoint=checkpoint,
|
checkpoint=checkpoint,
|
||||||
init_method=init_method,
|
init_method=init_method)
|
||||||
tensor_parallel=tensor_parallel)
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = x + self.drop_path(self.attn(self.norm1(x)))
|
x = x + self.drop_path(self.attn(self.norm1(x)))
|
||||||
|
@ -323,20 +266,16 @@ class VisionTransformer(nn.Module):
|
||||||
dtype: dtype = None,
|
dtype: dtype = None,
|
||||||
bias: bool = True,
|
bias: bool = True,
|
||||||
checkpoint: bool = False,
|
checkpoint: bool = False,
|
||||||
init_method: str = 'torch',
|
init_method: str = 'torch'):
|
||||||
tensor_parallel: str = None):
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
embed = ViTEmbedding(
|
embed = ViTEmbedding(img_size=img_size,
|
||||||
img_size=img_size,
|
patch_size=patch_size,
|
||||||
patch_size=patch_size,
|
in_chans=in_chans,
|
||||||
in_chans=in_chans,
|
embedding_dim=dim,
|
||||||
embedding_dim=dim,
|
dropout=dropout,
|
||||||
dropout=dropout,
|
dtype=dtype,
|
||||||
dtype=dtype,
|
init_method=init_method)
|
||||||
init_method=init_method,
|
|
||||||
tensor_parallel=tensor_parallel,
|
|
||||||
)
|
|
||||||
|
|
||||||
# stochastic depth decay rule
|
# stochastic depth decay rule
|
||||||
dpr = [x.item() for x in torch.linspace(0, drop_path, depth)]
|
dpr = [x.item() for x in torch.linspace(0, drop_path, depth)]
|
||||||
|
@ -353,26 +292,17 @@ class VisionTransformer(nn.Module):
|
||||||
bias=bias,
|
bias=bias,
|
||||||
checkpoint=checkpoint,
|
checkpoint=checkpoint,
|
||||||
init_method=init_method,
|
init_method=init_method,
|
||||||
tensor_parallel=tensor_parallel,
|
|
||||||
) for i in range(depth)
|
) for i in range(depth)
|
||||||
]
|
]
|
||||||
|
|
||||||
norm = col_nn.LayerNorm(
|
norm = col_nn.LayerNorm(normalized_shape=dim, eps=1e-6, dtype=dtype)
|
||||||
normalized_shape=dim,
|
|
||||||
eps=1e-6,
|
|
||||||
dtype=dtype,
|
|
||||||
tensor_parallel=tensor_parallel,
|
|
||||||
)
|
|
||||||
|
|
||||||
head = ViTHead(
|
head = ViTHead(dim=dim,
|
||||||
dim=dim,
|
num_classes=num_classes,
|
||||||
num_classes=num_classes,
|
representation_size=representation_size,
|
||||||
representation_size=representation_size,
|
dtype=dtype,
|
||||||
dtype=dtype,
|
bias=bias,
|
||||||
bias=bias,
|
init_method=init_method)
|
||||||
init_method=init_method,
|
|
||||||
tensor_parallel=tensor_parallel,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.layers = nn.Sequential(
|
self.layers = nn.Sequential(
|
||||||
embed,
|
embed,
|
||||||
|
|
|
@@ -1,4 +1,3 @@
-import time
 from functools import partial

 import pytest
@@ -9,7 +8,7 @@ from colossalai.communication import all_gather, all_reduce, reduce_scatter
 from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.initialize import launch
-from colossalai.utils import get_current_device
+from colossalai.utils import free_port, get_current_device

 CONFIG = dict(parallel=dict(data=8, pipeline=1, tensor=dict(mode=None, size=1)))

@@ -49,8 +48,8 @@ def check_all_reduce():
     torch.cuda.synchronize()


-def check_layer(rank, world_size):
-    launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=30010, backend='nccl')
+def check_layer(rank, world_size, port):
+    launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

     assert dist.get_rank() == gpc.get_global_rank()
     print('Rank {} / {}'.format(dist.get_rank(), dist.get_world_size()))
@@ -66,7 +65,7 @@ def check_layer(rank, world_size):
 @pytest.mark.dist
 def test_comm():
     world_size = 4
-    run_func = partial(check_layer, world_size=world_size)
+    run_func = partial(check_layer, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)
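The test changes from here on all share one mechanic: instead of a hard-coded rendezvous port per file (30010, 29900, 29910, ...), each test now asks `colossalai.utils.free_port()` for an unused TCP port, which avoids "address already in use" failures when several distributed tests run on the same machine. The helper's implementation is not part of this diff; a common way to write such a utility (an assumption, not necessarily how `colossalai.utils` does it) is to bind a socket to port 0 and return whatever the OS assigns.

# Hedged sketch of a free-port helper; the real colossalai.utils.free_port
# is not shown in this commit, this is only the usual pattern.
import socket

def free_port() -> int:
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(('localhost', 0))    # port 0 lets the OS pick any unused port
        return s.getsockname()[1]   # return the port that was actually assigned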
@@ -1,15 +1,16 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-

+from functools import partial
+from pathlib import Path
+
 import pytest
 import torch
 import torch.multiprocessing as mp

 from colossalai import launch
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
-from functools import partial
-from pathlib import Path
+from colossalai.utils import free_port

 CONFIG_PATH = Path(__file__).parent.joinpath('configs/parallel_2d_init.py').absolute()

@@ -87,7 +88,7 @@ def test_2d_init():
     test_fn = partial(init_2d,
                       world_size=world_size,
                       backend='gloo',
-                      port='29900',
+                      port=free_port(),
                       host='localhost'
                       )
     mp.spawn(test_fn, nprocs=world_size)

@@ -7,10 +7,10 @@ from pathlib import Path
 import pytest
 import torch
 import torch.multiprocessing as mp

 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.initialize import launch
-
+from colossalai.utils import free_port

 CONFIG_PATH = Path(__file__).parent.joinpath('configs/parallel_2p5d_init.py').absolute()

@@ -111,7 +111,7 @@ def test_2halfd_init():
     test_fn = partial(init_2halfd,
                       world_size=world_size,
                       backend='gloo',
-                      port='29901',
+                      port=free_port(),
                       host='localhost'
                       )
     mp.spawn(test_fn, nprocs=world_size)

@@ -7,11 +7,10 @@ from pathlib import Path
 import pytest
 import torch
 import torch.multiprocessing as mp
-

 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.initialize import launch
+from colossalai.utils import free_port

 CONFIG_PATH = Path(__file__).parent.joinpath('configs/parallel_3d_init.py').absolute()

@@ -104,7 +103,7 @@ def test_3d_init():
     test_fn = partial(init_3d,
                       world_size=world_size,
                       backend='gloo',
-                      port='29902',
+                      port=free_port(),
                       host='localhost'
                       )
     mp.spawn(test_fn, nprocs=world_size)
@@ -13,7 +13,7 @@ from colossalai.logging import get_dist_logger
 from colossalai.nn import Accuracy, LinearWarmupLR
 from colossalai.nn.loss import CrossEntropyLoss
 from colossalai.trainer import Trainer, hooks
-from colossalai.utils import MultiTimer, get_dataloader
+from colossalai.utils import MultiTimer, free_port, get_dataloader
 from colossalai.utils.gradient_accumulation import GradAccumLrSchedulerByStep
 from model_zoo.vit import vit_tiny_patch4_32
 from torchvision import transforms
@@ -27,12 +27,12 @@ CONFIG = dict(parallel=dict(pipeline=2, tensor=dict(size=2, mode='1d')),
               gradient_accumulation=2)


-def run_trainer(rank, world_size):
-    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=30000, backend='nccl')
+def run_trainer(rank, world_size, port):
+    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

     logger = get_dist_logger()

-    model = vit_tiny_patch4_32(tensor_parallel='1d')
+    model = vit_tiny_patch4_32()
     pipe_model = build_pipeline_model(model.layers, num_chunks=1)

     # build dataloaders
@@ -54,7 +54,7 @@ def run_trainer(rank, world_size):
     test_dataloader = get_dataloader(dataset=test_dataset, batch_size=BATCH_SIZE, pin_memory=True)

     # build criterion
-    criterion = CrossEntropyLoss(tensor_parallel='1d')
+    criterion = CrossEntropyLoss()

     # optimizer
     optimizer = torch.optim.Adam(pipe_model.parameters(), lr=0.001, weight_decay=0)
@@ -78,7 +78,6 @@ def run_trainer(rank, world_size):
     hook_list = [
         hooks.LossHook(),
         hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=False),
-        hooks.AccuracyHook(accuracy_func=Accuracy(tensor_parallel='1d')),
         hooks.LogMetricByEpochHook(logger),
     ]

@@ -95,7 +94,7 @@ def run_trainer(rank, world_size):
 # @pytest.mark.skip("This test requires more than 8 GPUs, you should invoke this test script using test.sh provided manually")
 def test_hybrid_parallel():
     world_size = 8
-    run_func = partial(run_trainer, world_size=world_size)
+    run_func = partial(run_trainer, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)
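The launcher pattern above repeats in every remaining test of this commit, so it is worth spelling out once: `free_port()` is evaluated a single time in the parent process when `functools.partial` builds `run_func`, so every rank spawned by `mp.spawn` receives the same port and they all rendezvous on it. A condensed sketch of that shared harness, with the worker body elided and `run_worker` a hypothetical stand-in for the per-test run functions:

# Shared harness pattern across these tests (sketch only).
from functools import partial

import torch.multiprocessing as mp
from colossalai.utils import free_port

def run_worker(rank, world_size, port):
    # each real test calls colossalai.launch(config=..., rank=rank,
    # world_size=world_size, host='localhost', port=port, backend='nccl') here
    pass

world_size = 4
# free_port() runs once here, so every spawned rank gets the same port value.
run_func = partial(run_worker, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)   # mp.spawn supplies the rank as the first argument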
@@ -1,25 +1,23 @@
 # !/usr/bin/env python
 # -*- encoding: utf-8 -*-

-import colossalai
 import os
+from functools import partial
+from pathlib import Path
+
+import colossalai
 import pytest
 import torch
-import os.path as osp
-from pathlib import Path
-import torch.nn as nn
 import torch.multiprocessing as mp
-from torchvision import transforms
-from torch.optim import Adam
-from colossalai.core import global_context as gpc
+import torch.nn as nn
 from colossalai.amp import AMP_TYPE
+from colossalai.core import global_context as gpc
 from colossalai.logging import get_dist_logger
-from colossalai.utils import report_memory_usage, get_dataloader
-from torchvision.models import resnet18
+from colossalai.utils import free_port, get_dataloader, report_memory_usage
+from torch.optim import Adam
+from torchvision import transforms
 from torchvision.datasets import CIFAR10
-from functools import partial
+from torchvision.models import resnet18

 # Config
 BATCH_SIZE = 128
@@ -38,14 +36,14 @@ CONFIG = dict(
 )


-def run_engine(rank, world_size):
+def run_engine(rank, world_size, port):
     # init dist env
     colossalai.launch(
         config=CONFIG,
         rank=rank,
         world_size=world_size,
         host='localhost',
-        port=29910,
+        port=port,
         backend='nccl'
     )

@@ -104,7 +102,7 @@ def run_engine(rank, world_size):
 @pytest.mark.dist
 def test_engine():
     world_size = 4
-    run_func = partial(run_engine, world_size=world_size)
+    run_func = partial(run_engine, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)

@@ -1,23 +1,20 @@
-import colossalai
 import os
+from functools import partial
+from pathlib import Path
+
+import colossalai
 import pytest
 import torch
-import os.path as osp
-from pathlib import Path
-import torch.nn as nn
 import torch.multiprocessing as mp
-from torchvision import transforms
-from torch.optim import Adam
-from colossalai.core import global_context as gpc
+import torch.nn as nn
 from colossalai.amp import AMP_TYPE
+from colossalai.core import global_context as gpc
 from colossalai.logging import get_dist_logger
-from colossalai.utils import report_memory_usage, get_dataloader
-from colossalai.initialize import get_default_parser
-from torchvision.models import resnet18
+from colossalai.utils import free_port, get_dataloader, report_memory_usage
+from torch.optim import Adam
+from torchvision import transforms
 from torchvision.datasets import CIFAR10
-from functools import partial
+from torchvision.models import resnet18

 # Config
 BATCH_SIZE = 128
@@ -38,14 +35,14 @@ CONFIG = dict(
 )


-def run_engine(rank, world_size):
+def run_engine(rank, world_size, port):
     # init dist env
     colossalai.launch(
         config=CONFIG,
         rank=rank,
         world_size=world_size,
         host='localhost',
-        port=29911,
+        port=port,
         backend='nccl'
     )

@@ -104,7 +101,7 @@ def run_engine(rank, world_size):
 @pytest.mark.dist
 def test_engine():
     world_size = 4
-    run_func = partial(run_engine, world_size=world_size)
+    run_func = partial(run_engine, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)

@@ -1,23 +1,19 @@
-import colossalai
 import os
+from functools import partial
+from pathlib import Path
+
+import colossalai
 import pytest
 import torch
-import os.path as osp
-from pathlib import Path
-import torch.nn as nn
 import torch.multiprocessing as mp
-from torchvision import transforms
-from torch.optim import Adam
+import torch.nn as nn
 from colossalai.core import global_context as gpc
-from colossalai.amp import AMP_TYPE
 from colossalai.logging import get_dist_logger
-from colossalai.utils import report_memory_usage, get_dataloader
-from colossalai.initialize import get_default_parser
-from torchvision.models import resnet18
+from colossalai.utils import free_port, get_dataloader, report_memory_usage
+from torch.optim import Adam
+from torchvision import transforms
 from torchvision.datasets import CIFAR10
-from functools import partial
+from torchvision.models import resnet18

 # Config
 BATCH_SIZE = 128
@@ -35,14 +31,14 @@ CONFIG = dict(
 )


-def run_engine(rank, world_size):
+def run_engine(rank, world_size, port):
     # init dist env
     colossalai.launch(
         config=CONFIG,
         rank=rank,
         world_size=world_size,
         host='localhost',
-        port=29912,
+        port=port,
         backend='nccl'
     )

@@ -101,7 +97,7 @@ def run_engine(rank, world_size):
 @pytest.mark.dist
 def test_engine():
     world_size = 4
-    run_func = partial(run_engine, world_size=world_size)
+    run_func = partial(run_engine, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)

@@ -1,23 +1,20 @@
-import colossalai
 import os
+from functools import partial
+from pathlib import Path
+
+import colossalai
 import pytest
 import torch
-import os.path as osp
-from pathlib import Path
-import torch.nn as nn
 import torch.multiprocessing as mp
-from torchvision import transforms
-from torch.optim import Adam
-from colossalai.core import global_context as gpc
+import torch.nn as nn
 from colossalai.amp import AMP_TYPE
+from colossalai.core import global_context as gpc
 from colossalai.logging import get_dist_logger
-from colossalai.utils import report_memory_usage, get_dataloader
-from colossalai.initialize import get_default_parser
-from torchvision.models import resnet18
+from colossalai.utils import free_port, get_dataloader, report_memory_usage
+from torch.optim import Adam
+from torchvision import transforms
 from torchvision.datasets import CIFAR10
-from functools import partial
+from torchvision.models import resnet18

 # Config
 BATCH_SIZE = 128
@@ -36,14 +33,14 @@ CONFIG = dict(
 )


-def run_engine(rank, world_size):
+def run_engine(rank, world_size, port):
     # init dist env
     colossalai.launch(
         config=CONFIG,
         rank=rank,
         world_size=world_size,
         host='localhost',
-        port=29913,
+        port=port,
         backend='nccl'
     )

@@ -102,7 +99,7 @@ def run_engine(rank, world_size):
 @pytest.mark.dist
 def test_engine():
     world_size = 4
-    run_func = partial(run_engine, world_size=world_size)
+    run_func = partial(run_engine, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)
@@ -1,13 +1,15 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-

+from functools import partial
+
 import pytest
 import torch
 import torch.multiprocessing as mp

 from colossalai.core import global_context as gpc
 from colossalai.initialize import launch
-from functools import partial
+from colossalai.utils import free_port

 from checks_1d.check_layer_1d import *

 CONFIG = dict(
@@ -21,12 +23,12 @@ CONFIG = dict(
 )


-def check_layer(rank, world_size):
+def check_layer(rank, world_size, port):
     launch(config=CONFIG,
            rank=rank,
            world_size=world_size,
            host='localhost',
-           port=29920,
+           port=port,
            backend='nccl')

     check_linear_col()
@@ -39,7 +41,7 @@ def check_layer(rank, world_size):
 @pytest.mark.dist
 def test_1d():
     world_size = 4
-    run_func = partial(check_layer, world_size=world_size)
+    run_func = partial(check_layer, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)

@@ -1,16 +1,17 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-

+from functools import partial
+
 import pytest
 import torch
 import torch.multiprocessing as mp

 from colossalai.core import global_context as gpc
 from colossalai.initialize import launch
+from colossalai.utils import free_port

 from checks_2d.check_layer_2d import *
 from checks_2d.check_operation_2d import *
-from functools import partial

 CONFIG = dict(
     parallel=dict(
@@ -34,12 +35,12 @@ def check_layer():
     check_layernorm()
     check_classifier()

-def check_layer_and_operation(rank, world_size):
+def check_layer_and_operation(rank, world_size, port):
     launch(config=CONFIG,
            rank=rank,
            world_size=world_size,
            host='localhost',
-           port=29921,
+           port=port,
            backend='nccl')

     # check_operations()
@@ -51,7 +52,7 @@ def check_layer_and_operation(rank, world_size):
 @pytest.mark.dist
 def test_2d():
     world_size = 4
-    run_func = partial(check_layer_and_operation, world_size=world_size)
+    run_func = partial(check_layer_and_operation, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)

@@ -1,13 +1,15 @@
+from functools import partial
+
 import pytest
 import torch
 import torch.multiprocessing as mp

 from colossalai.core import global_context as gpc
 from colossalai.initialize import launch
-from checks_2p5d.check_layer_2p5d import check_linear, check_layernorm, check_classifier
-from checks_2p5d.check_operation_2p5d import check_AB, check_ABT, check_ATB
-from functools import partial
+from colossalai.utils import free_port
+
+from checks_2p5d.check_layer_2p5d import (check_classifier, check_layernorm,
+                                          check_linear)
+from checks_2p5d.check_operation_2p5d import check_AB, check_ABT, check_ATB

 CONFIG = dict(
     parallel=dict(
@@ -29,12 +31,12 @@ def check_layer():
     check_classifier()


-def check_layer_and_operation(rank, world_size):
+def check_layer_and_operation(rank, world_size, port):
     launch(config=CONFIG,
            rank=rank,
            world_size=world_size,
            host='localhost',
-           port=29922,
+           port=port,
            backend='nccl')

     check_operations()
@@ -46,7 +48,7 @@ def check_layer_and_operation(rank, world_size):
 @pytest.mark.dist
 def test_2p5d():
     world_size = 4
-    run_func = partial(check_layer_and_operation, world_size=world_size)
+    run_func = partial(check_layer_and_operation, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)

@@ -7,6 +7,7 @@ import torch
 import torch.multiprocessing as mp
 from colossalai.core import global_context as gpc
 from colossalai.initialize import launch
+from colossalai.utils import free_port

 from checks_3d.check_layer_3d import *

@@ -27,8 +28,8 @@ def check_layer():
     # check_loss()


-def check_layer_and_operation(rank, world_size):
-    launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=29923, backend='nccl')
+def check_layer_and_operation(rank, world_size, port):
+    launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     check_layer()
     gpc.destroy()
     torch.cuda.empty_cache()
@@ -37,7 +38,7 @@ def check_layer_and_operation(rank, world_size):
 @pytest.mark.dist
 def test_3d():
     world_size = 8
-    run_func = partial(check_layer_and_operation, world_size=world_size)
+    run_func = partial(check_layer_and_operation, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)

@@ -4,10 +4,11 @@
 import pytest
 import torch
 import torch.multiprocessing as mp
-from colossalai.initialize import launch, get_default_parser
+from colossalai.initialize import launch
 from colossalai.logging import get_dist_logger
 from checks_seq.check_layer_seq import *
 from functools import partial
+from colossalai.utils import free_port


 CONFIG = dict(
@@ -22,13 +23,13 @@ def check_layer():
     check_selfattention()


-def run_check_sequence(rank, world_size):
+def run_check_sequence(rank, world_size, port):
     # init dist
     launch(config=CONFIG,
            rank=rank,
            world_size=world_size,
            host='localhost',
-           port=29924,
+           port=port,
            backend='nccl')
     logger = get_dist_logger()
     logger.info('Distributed environment is initialzied.', ranks=[0])
@@ -41,7 +42,7 @@ def run_check_sequence(rank, world_size):
 @pytest.mark.dist
 def test_sequence():
     world_size = 4
-    run_func = partial(run_check_sequence, world_size=world_size)
+    run_func = partial(run_check_sequence, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)
@@ -1,4 +1,5 @@
 import os
+import model
 from pathlib import Path

 BATCH_SIZE = 128

@@ -1,11 +1,12 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-

+from functools import partial
+
 import pytest
 import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp

 from colossalai.communication import (recv_backward, recv_forward,
                                       recv_tensor_meta, send_backward,
                                       send_backward_recv_forward, send_forward,
@@ -15,8 +16,7 @@ from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.initialize import launch
 from colossalai.logging import get_dist_logger
-from colossalai.utils import get_current_device
-from functools import partial
+from colossalai.utils import free_port, get_current_device

 BATCH_SIZE = 16
 SEQ_LENGTH = 64
@@ -123,13 +123,13 @@ def check_comm(size, rank, prev_rank, next_rank, up_group, down_group, logger):
     check_forward_backward(tensor, grad, rank, logger)


-def run_check(rank, world_size):
+def run_check(rank, world_size, port):
     launch(
         config=CONFIG,
         rank=rank,
         world_size=world_size,
         host='localhost',
-        port=29932,
+        port=port,
         backend='nccl'
     )
     logger = get_dist_logger()
@@ -154,7 +154,7 @@ def run_check(rank, world_size):
 @pytest.mark.dist
 def test_p2p():
     world_size = 4
-    run_func = partial(run_check, world_size=world_size)
+    run_func = partial(run_check, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)

@@ -3,25 +3,24 @@ import os.path as osp
 import pytest
 import torch
 import torch.multiprocessing as mp
-from torch.utils.data import DataLoader

 from colossalai.builder.pipeline import build_pipeline_model_from_cfg
 from colossalai.core import global_context
 from colossalai.initialize import launch
 from colossalai.logging import get_dist_logger
 from functools import partial
-import model
+from colossalai.utils import free_port

 DIR_PATH = osp.dirname(osp.realpath(__file__))
 CONFIG_PATH = osp.join(DIR_PATH, 'resnet_config.py')


-def run_partition(rank, world_size):
+def run_partition(rank, world_size, port):
     launch(config=CONFIG_PATH,
            rank=rank,
            world_size=world_size,
            host='localhost',
-           port=29933,
+           port=port,
            backend='nccl'
            )
     logger = get_dist_logger()
@@ -40,7 +39,7 @@ def run_partition(rank, world_size):
 @pytest.mark.dist
 def test_partition():
     world_size = 4
-    run_func = partial(run_partition, world_size=world_size)
+    run_func = partial(run_partition, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)

@@ -1,26 +1,23 @@
 # referenced from Megatron and used to testify communication

-import colossalai
 import os
 import os.path as osp
+from functools import partial
+from pathlib import Path
+
+import colossalai
 import pytest
 import torch
 import torch.multiprocessing as mp
-import model

 from colossalai.builder import build_pipeline_model_from_cfg
-from colossalai.communication import p2p as p2p_communication
-from colossalai.communication.utils import send_tensor_meta, recv_tensor_meta
-from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.initialize import launch
-from colossalai.utils import print_rank_0, get_current_device, get_dataloader
 from colossalai.engine.schedule import PipelineSchedule
-from torchvision.datasets import CIFAR10
+from colossalai.initialize import launch
+from colossalai.utils import free_port, get_dataloader, print_rank_0
 from torchvision import transforms
-from pathlib import Path
-from functools import partial
+from torchvision.datasets import CIFAR10
+
+import model

 BATCH_SIZE = 32
 NUM_MICRO = 8
@@ -30,12 +27,12 @@ DIR_PATH = osp.dirname(osp.realpath(__file__))
 CONFIG_PATH = osp.join(DIR_PATH, './resnet_config.py')


-def run_schedule(rank, world_size):
+def run_schedule(rank, world_size, port):
     launch(config=CONFIG_PATH,
            rank=rank,
            world_size=world_size,
            host='localhost',
-           port=29934,
+           port=port,
            backend='nccl')

     # build model
@@ -86,7 +83,7 @@ def run_schedule(rank, world_size):
 @pytest.mark.dist
 def test_pipeline_schedule():
     world_size = 4
-    run_func = partial(run_schedule, world_size=world_size)
+    run_func = partial(run_schedule, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)
@@ -11,7 +11,7 @@ from colossalai.amp.amp_type import AMP_TYPE
 from colossalai.core import global_context as gpc
 from colossalai.logging import get_dist_logger
 from colossalai.trainer import Trainer
-from colossalai.utils import MultiTimer, get_dataloader
+from colossalai.utils import MultiTimer, free_port, get_dataloader
 from torch.optim import Adam
 from torchvision import transforms
 from torchvision.datasets import CIFAR10
@@ -26,8 +26,8 @@ CONFIG = dict(
     fp16=dict(mode=AMP_TYPE.TORCH))


-def run_trainer_no_pipeline(rank, world_size):
-    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=29930, backend='nccl')
+def run_trainer_no_pipeline(rank, world_size, port):
+    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

     # build model
     model = resnet18(num_classes=10)
@@ -88,7 +88,7 @@ def run_trainer_no_pipeline(rank, world_size):
 @pytest.mark.dist
 def test_trainer_no_pipeline():
     world_size = 4
-    run_func = partial(run_trainer_no_pipeline, world_size=world_size)
+    run_func = partial(run_trainer_no_pipeline, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)

@@ -12,7 +12,7 @@ from colossalai.core import global_context as gpc
 from colossalai.engine.schedule import PipelineSchedule
 from colossalai.logging import get_dist_logger
 from colossalai.trainer import Trainer
-from colossalai.utils import MultiTimer, get_dataloader
+from colossalai.utils import MultiTimer, free_port, get_dataloader
 from torch.optim import Adam
 from torchvision import transforms
 from torchvision.datasets import CIFAR10
@@ -25,8 +25,8 @@ NUM_EPOCHS = 200
 CONFIG = dict(parallel=dict(pipeline=2, ), )


-def run_trainer_with_pipeline(rank, world_size):
-    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=29931, backend='nccl')
+def run_trainer_with_pipeline(rank, world_size, port):
+    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

     # build model
     model = resnet18(num_classes=10)
@@ -99,7 +99,7 @@ def run_trainer_with_pipeline(rank, world_size):
 @pytest.mark.dist
 def test_trainer_with_pipeline():
     world_size = 4
-    run_func = partial(run_trainer_with_pipeline, world_size=world_size)
+    run_func = partial(run_trainer_with_pipeline, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)

@@ -1,21 +1,19 @@
-import colossalai
 import os
+from functools import partial
+from pathlib import Path
+
+import colossalai
 import pytest
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn

-from functools import partial
-from pathlib import Path
-from torchvision import transforms
-from torch.optim import Adam
 from colossalai.core import global_context as gpc
 from colossalai.logging import get_dist_logger
-from colossalai.utils import report_memory_usage, get_dataloader
-from colossalai.initialize import get_default_parser
-from torchvision.models import resnet18
+from colossalai.utils import free_port, get_dataloader
+from torch.optim import Adam
+from torchvision import transforms
 from torchvision.datasets import CIFAR10
+from torchvision.models import resnet18

 # Config
 BATCH_SIZE = 16
@@ -32,7 +30,7 @@ CONFIG = dict(
 )


-def run_no_pipeline(rank, world_size):
+def run_no_pipeline(rank, world_size, port):

     # init dist env
     colossalai.launch(
@@ -40,7 +38,7 @@ def run_no_pipeline(rank, world_size):
         rank=rank,
         world_size=world_size,
         host='localhost',
-        port=29500,
+        port=port,
         backend='nccl'
     )

@@ -110,7 +108,7 @@ def run_no_pipeline(rank, world_size):
 @pytest.mark.dist
 def test_engine():
     world_size = 4
-    func = partial(run_no_pipeline, world_size=world_size)
+    func = partial(run_no_pipeline, world_size=world_size, port=free_port())
     mp.spawn(func, nprocs=world_size)
@@ -2,18 +2,18 @@
 # -*- encoding: utf-8 -*-

 import os
-import pytest
-import torch
-import torch.multiprocessing as mp
+from functools import partial
 from pathlib import Path

 import colossalai
+import pytest
+import torch
+import torch.multiprocessing as mp
 from colossalai.core import global_context as gpc
-from colossalai.utils import get_dataloader
+from colossalai.utils import free_port, get_dataloader
 from torchvision import transforms
-from torchvision.models import resnet18
 from torchvision.datasets import CIFAR10
-from functools import partial
+from torchvision.models import resnet18

 BATCH_SIZE = 16
 IMG_SIZE = 224
@@ -34,12 +34,12 @@ CONFIG = dict(
 )


-def run_dist(rank, world_size):
+def run_dist(rank, world_size, port):
     colossalai.launch(config=CONFIG,
                       rank=rank,
                       world_size=world_size,
                       host='localhost',
-                      port=29940,
+                      port=port,
                       backend='nccl')

     # build model
@@ -94,7 +94,7 @@ def run_dist(rank, world_size):
 @pytest.mark.dist
 def test_zero_level_2():
     world_size = 4
-    run_func = partial(run_dist, world_size=world_size)
+    run_func = partial(run_dist, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)

@@ -2,18 +2,18 @@
 # -*- encoding: utf-8 -*-

 import os
-import pytest
-import torch
-import torch.multiprocessing as mp
+from functools import partial
 from pathlib import Path

 import colossalai
+import pytest
+import torch
+import torch.multiprocessing as mp
 from colossalai.core import global_context as gpc
-from colossalai.utils import get_dataloader
+from colossalai.utils import free_port, get_dataloader
 from torchvision import transforms
-from torchvision.models import resnet18
 from torchvision.datasets import CIFAR10
-from functools import partial
+from torchvision.models import resnet18

 BATCH_SIZE = 16
 IMG_SIZE = 224
@@ -46,12 +46,12 @@ CONFIG = dict(
 )


-def run_dist(rank, world_size):
+def run_dist(rank, world_size, port):
     colossalai.launch(config=CONFIG,
                       rank=rank,
                       world_size=world_size,
                       host='localhost',
-                      port=29941,
+                      port=port,
                       backend='nccl')

     # build model
@@ -106,7 +106,7 @@ def run_dist(rank, world_size):
 @pytest.mark.dist
 def test_zero_level_3():
     world_size = 4
-    run_func = partial(run_dist, world_size=world_size)
+    run_func = partial(run_dist, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)

@@ -13,7 +13,7 @@ import torch.multiprocessing as mp
 from colossalai.core import global_context as gpc
 from colossalai.logging import get_dist_logger
 from colossalai.nn import CrossEntropyLoss
-from colossalai.utils import get_dataloader
+from colossalai.utils import free_port, get_dataloader
 from model_zoo.vit import vit_lite_depth7_patch4_32
 from torchvision import transforms
 from torchvision.datasets import CIFAR10
@@ -40,11 +40,11 @@ def train_epoch(engine, train_dataloader):
     return avg_loss


-def run_2d_parallel_vision_transformer_level_2(rank, world_size):
-    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=29950, backend='nccl')
+def run_2d_parallel_vision_transformer_level_2(rank, world_size, port):
+    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

     # build model
-    model = vit_lite_depth7_patch4_32(tensor_parallel='2d')
+    model = vit_lite_depth7_patch4_32()

     # build dataloader# build dataloaders
     train_dataset = CIFAR10(root=Path(os.environ['DATA']),
@@ -62,7 +62,7 @@ def run_2d_parallel_vision_transformer_level_2(rank, world_size):

     # build optimizer and loss
     optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
-    criterion = CrossEntropyLoss(tensor_parallel='2d')
+    criterion = CrossEntropyLoss()

     engine, train_dataloader, *args = colossalai.initialize(model=model,
                                                             optimizer=optimizer,
@@ -90,7 +90,7 @@ def run_2d_parallel_vision_transformer_level_2(rank, world_size):
 @pytest.mark.dist
 def test_2d_vit_zero_level_2():
     world_size = 8
-    run_func = partial(run_2d_parallel_vision_transformer_level_2, world_size=world_size)
+    run_func = partial(run_2d_parallel_vision_transformer_level_2, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)

@@ -13,7 +13,7 @@ import torch.multiprocessing as mp
 from colossalai.core import global_context as gpc
 from colossalai.logging import get_dist_logger
 from colossalai.nn import CrossEntropyLoss
-from colossalai.utils import get_dataloader
+from colossalai.utils import free_port, get_dataloader
 from model_zoo.vit import vit_lite_depth7_patch4_32
 from torchvision import transforms
 from torchvision.datasets import CIFAR10
@@ -40,11 +40,11 @@ def train_epoch(engine, train_dataloader):
     return avg_loss


-def run_2d_parallel_vision_transformer_level_3(rank, world_size):
-    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=29951, backend='nccl')
+def run_2d_parallel_vision_transformer_level_3(rank, world_size, port):
+    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

     # build model
-    model = vit_lite_depth7_patch4_32(tensor_parallel='2d')
+    model = vit_lite_depth7_patch4_32()

     # build dataloader# build dataloaders
     train_dataset = CIFAR10(root=Path(os.environ['DATA']),
@@ -62,7 +62,7 @@ def run_2d_parallel_vision_transformer_level_3(rank, world_size):

     # build optimizer and loss
     optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
-    criterion = CrossEntropyLoss(tensor_parallel='2d')
+    criterion = CrossEntropyLoss()

     engine, train_dataloader, *args = colossalai.initialize(model=model,
                                                             optimizer=optimizer,
@@ -91,7 +91,7 @@ def run_2d_parallel_vision_transformer_level_3(rank, world_size):
 @pytest.mark.skip("Level 3 has unknown bug so skip this test for now")
 def test_3d_vit_zero_level_3():
     world_size = 8
-    run_func = partial(run_2d_parallel_vision_transformer_level_3, world_size=world_size)
+    run_func = partial(run_2d_parallel_vision_transformer_level_3, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)